diff --git a/doc/throttle-debounce.png b/doc/throttle-debounce.png new file mode 100644 index 000000000..9d26208a7 Binary files /dev/null and b/doc/throttle-debounce.png differ diff --git a/src/main/java/cz/cvut/kbss/termit/config/AppConfig.java b/src/main/java/cz/cvut/kbss/termit/config/AppConfig.java index 9a6ec39d7..835dd77e5 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/AppConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/AppConfig.java @@ -19,14 +19,17 @@ import cz.cvut.kbss.termit.util.AsyncExceptionHandler; import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler; +import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.EnableAspectJAutoProxy; import org.springframework.context.annotation.EnableMBeanExport; +import org.springframework.context.annotation.ImportResource; import org.springframework.context.annotation.aspectj.EnableSpringConfigured; import org.springframework.retry.annotation.EnableRetry; import org.springframework.scheduling.annotation.AsyncConfigurer; import org.springframework.scheduling.annotation.EnableAsync; import org.springframework.scheduling.annotation.EnableScheduling; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; @Configuration @EnableMBeanExport @@ -35,10 +38,24 @@ @EnableAsync @EnableScheduling @EnableRetry +@ImportResource("classpath*:spring-aop.xml") public class AppConfig implements AsyncConfigurer { @Override public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() { return new AsyncExceptionHandler(); } + + /** + * This thread pool is responsible for executing long-running tasks in the application. + */ + @Bean(destroyMethod = "destroy") + public ThreadPoolTaskScheduler longRunningTaskScheduler(cz.cvut.kbss.termit.util.Configuration config) { + ThreadPoolTaskScheduler threadPoolTaskScheduler = new ThreadPoolTaskScheduler(); + threadPoolTaskScheduler.setPoolSize(config.getAsyncThreadCount()); + threadPoolTaskScheduler.setThreadNamePrefix("TermItScheduler-"); + threadPoolTaskScheduler.setWaitForTasksToCompleteOnShutdown(true); + threadPoolTaskScheduler.setRemoveOnCancelPolicy(true); + return threadPoolTaskScheduler; + } } diff --git a/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java b/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java index 61781c11a..eaa03f2e8 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/WebAppConfig.java @@ -62,19 +62,13 @@ @Configuration public class WebAppConfig implements WebMvcConfigurer { - private final cz.cvut.kbss.termit.util.Configuration.Repository config; + private final cz.cvut.kbss.termit.util.Configuration config; @Value("${application.version:development}") private String version; public WebAppConfig(cz.cvut.kbss.termit.util.Configuration config) { - this.config = config.getRepository(); - } - - @Bean(name = "objectMapper") - @Primary - public ObjectMapper objectMapper() { - return createJsonObjectMapper(); + this.config = config; } /** @@ -99,11 +93,6 @@ public static ObjectMapper createJsonObjectMapper() { return objectMapper; } - @Bean(name = "jsonLdMapper") - public ObjectMapper jsonLdObjectMapper() { - return createJsonLdObjectMapper(); - } - /** * Creates an {@link ObjectMapper} for processing JSON-LD using the JB4JSON-LD library. *

@@ -119,9 +108,21 @@ public static ObjectMapper createJsonLdObjectMapper() { jsonLdModule.configure(cz.cvut.kbss.jsonld.ConfigParam.SCAN_PACKAGE, "cz.cvut.kbss.termit"); jsonLdModule.configure(SerializationConstants.FORM, SerializationConstants.FORM_COMPACT_WITH_CONTEXT); mapper.registerModule(jsonLdModule); + mapper.registerModule(new JavaTimeModule()); return mapper; } + @Bean(name = "objectMapper") + @Primary + public ObjectMapper objectMapper() { + return createJsonObjectMapper(); + } + + @Bean(name = "jsonLdMapper") + public ObjectMapper jsonLdObjectMapper() { + return createJsonLdObjectMapper(); + } + /** * Register the proxy for SPARQL endpoint. * @@ -133,10 +134,11 @@ public ServletWrappingController sparqlEndpointController() throws Exception { controller.setServletClass(AdjustedUriTemplateProxyServlet.class); controller.setBeanName("sparqlEndpointProxyServlet"); final Properties p = new Properties(); - p.setProperty("targetUri", config.getUrl()); + final cz.cvut.kbss.termit.util.Configuration.Repository repository = config.getRepository(); + p.setProperty("targetUri", repository.getUrl()); p.setProperty("log", "false"); - p.setProperty(ConfigParam.REPO_USERNAME.toString(), config.getUsername() != null ? config.getUsername() : ""); - p.setProperty(ConfigParam.REPO_PASSWORD.toString(), config.getPassword() != null ? config.getPassword() : ""); + p.setProperty(ConfigParam.REPO_USERNAME.toString(), repository.getUsername() != null ? repository.getUsername() : ""); + p.setProperty(ConfigParam.REPO_PASSWORD.toString(), repository.getPassword() != null ? repository.getPassword() : ""); controller.setInitParameters(p); controller.afterPropertiesSet(); return controller; @@ -147,7 +149,7 @@ public SimpleUrlHandlerMapping sparqlQueryControllerMapping() throws Exception { SimpleUrlHandlerMapping mapping = new SimpleUrlHandlerMapping(); mapping.setOrder(0); final Map urlMap = Collections.singletonMap(Constants.REST_MAPPING_PATH + "/query", - sparqlEndpointController()); + sparqlEndpointController()); mapping.setUrlMap(urlMap); return mapping; } @@ -193,10 +195,10 @@ public FilterRegistrationBean mdcFilter() { @Bean public OpenAPI customOpenAPI() { return new OpenAPI().components(new Components().addSecuritySchemes("bearer-key", - new SecurityScheme().type( - SecurityScheme.Type.HTTP) - .scheme("bearer") - .bearerFormat("JWT"))) + new SecurityScheme().type( + SecurityScheme.Type.HTTP) + .scheme("bearer") + .bearerFormat("JWT"))) .info(new Info().title("TermIt REST API").description("TermIt REST API definition.") .version(version)); } diff --git a/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java b/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java index ceefa1273..3f5cdb08d 100644 --- a/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java +++ b/src/main/java/cz/cvut/kbss/termit/config/WebSocketMessageBrokerConfig.java @@ -4,19 +4,17 @@ import cz.cvut.kbss.termit.util.Constants; import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler; import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler; -import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler; -import org.jetbrains.annotations.NotNull; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Lazy; import org.springframework.core.Ordered; import org.springframework.core.annotation.Order; +import org.springframework.lang.NonNull; import org.springframework.messaging.Message; import org.springframework.messaging.converter.MappingJackson2MessageConverter; import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.StringMessageConverter; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; -import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; @@ -88,17 +86,12 @@ public void addArgumentResolvers(List argumentRes * @see Spring security source */ @Override - public void configureClientInboundChannel(@NotNull ChannelRegistration registration) { + public void configureClientInboundChannel(@NonNull ChannelRegistration registration) { AuthorizationChannelInterceptor interceptor = new AuthorizationChannelInterceptor(messageAuthorizationManager); interceptor.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(context)); registration.interceptors(webSocketJwtAuthorizationInterceptor, new SecurityContextChannelInterceptor(), interceptor); } - @Override - public void addReturnValueHandlers(List returnValueHandlers) { - returnValueHandlers.add(new WebSocketMessageWithHeadersValueHandler(simpMessagingTemplate)); - } - @Override public void registerStompEndpoints(StompEndpointRegistry registry) { registry.addEndpoint("/ws").setAllowedOrigins(allowedOrigins.split(",")); diff --git a/src/main/java/cz/cvut/kbss/termit/event/FileTextAnalysisFinishedEvent.java b/src/main/java/cz/cvut/kbss/termit/event/FileTextAnalysisFinishedEvent.java new file mode 100644 index 000000000..d8d7caa40 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/event/FileTextAnalysisFinishedEvent.java @@ -0,0 +1,23 @@ +package cz.cvut.kbss.termit.event; + +import cz.cvut.kbss.termit.model.resource.File; +import org.springframework.lang.NonNull; + +import java.net.URI; + +/** + * Indicates that text analysis of a file was finished + */ +public class FileTextAnalysisFinishedEvent extends VocabularyEvent { + + private final URI fileUri; + + public FileTextAnalysisFinishedEvent(Object source, @NonNull File file) { + super(source, file.getDocument().getVocabulary()); + this.fileUri = file.getUri(); + } + + public URI getFileUri() { + return fileUri; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/event/LongRunningTaskChangedEvent.java b/src/main/java/cz/cvut/kbss/termit/event/LongRunningTaskChangedEvent.java new file mode 100644 index 000000000..fd3cf7af1 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/event/LongRunningTaskChangedEvent.java @@ -0,0 +1,22 @@ +package cz.cvut.kbss.termit.event; + +import cz.cvut.kbss.termit.util.longrunning.LongRunningTaskStatus; +import org.springframework.context.ApplicationEvent; +import org.springframework.lang.NonNull; + +/** + * Indicates a status change of a long-running task. + */ +public class LongRunningTaskChangedEvent extends ApplicationEvent { + + private final LongRunningTaskStatus status; + + public LongRunningTaskChangedEvent(@NonNull Object source, final @NonNull LongRunningTaskStatus status) { + super(source); + this.status = status; + } + + public @NonNull LongRunningTaskStatus getStatus() { + return status; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/event/TermDefinitionTextAnalysisFinishedEvent.java b/src/main/java/cz/cvut/kbss/termit/event/TermDefinitionTextAnalysisFinishedEvent.java new file mode 100644 index 000000000..748d7a075 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/event/TermDefinitionTextAnalysisFinishedEvent.java @@ -0,0 +1,22 @@ +package cz.cvut.kbss.termit.event; + +import cz.cvut.kbss.termit.model.AbstractTerm; +import org.springframework.lang.NonNull; + +import java.net.URI; + +/** + * Indicates that a text analysis of a term definition was finished + */ +public class TermDefinitionTextAnalysisFinishedEvent extends VocabularyEvent { + private final URI termUri; + + public TermDefinitionTextAnalysisFinishedEvent(@NonNull Object source, @NonNull AbstractTerm term) { + super(source, term.getVocabulary()); + this.termUri = term.getUri(); + } + + public URI getTermUri() { + return termUri; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/event/VocabularyContentModified.java b/src/main/java/cz/cvut/kbss/termit/event/VocabularyContentModifiedEvent.java similarity index 73% rename from src/main/java/cz/cvut/kbss/termit/event/VocabularyContentModified.java rename to src/main/java/cz/cvut/kbss/termit/event/VocabularyContentModifiedEvent.java index 8eed780fd..324c1d45c 100644 --- a/src/main/java/cz/cvut/kbss/termit/event/VocabularyContentModified.java +++ b/src/main/java/cz/cvut/kbss/termit/event/VocabularyContentModifiedEvent.java @@ -17,7 +17,7 @@ */ package cz.cvut.kbss.termit.event; -import org.springframework.context.ApplicationEvent; +import org.springframework.lang.NonNull; import java.net.URI; @@ -26,16 +26,9 @@ *

* This typically means a term is added, removed or modified. Modification of vocabulary metadata themselves is not considered here. */ -public class VocabularyContentModified extends ApplicationEvent { +public class VocabularyContentModifiedEvent extends VocabularyEvent { - private final URI vocabularyIri; - - public VocabularyContentModified(Object source, URI vocabularyIri) { - super(source); - this.vocabularyIri = vocabularyIri; - } - - public URI getVocabularyIri() { - return vocabularyIri; + public VocabularyContentModifiedEvent(@NonNull Object source, @NonNull URI vocabularyIri) { + super(source, vocabularyIri); } } diff --git a/src/main/java/cz/cvut/kbss/termit/event/VocabularyCreatedEvent.java b/src/main/java/cz/cvut/kbss/termit/event/VocabularyCreatedEvent.java index e1da1aeab..704169105 100644 --- a/src/main/java/cz/cvut/kbss/termit/event/VocabularyCreatedEvent.java +++ b/src/main/java/cz/cvut/kbss/termit/event/VocabularyCreatedEvent.java @@ -17,14 +17,16 @@ */ package cz.cvut.kbss.termit.event; -import org.springframework.context.ApplicationEvent; +import org.springframework.lang.NonNull; + +import java.net.URI; /** * Indicates that a vocabulary has been created. */ -public class VocabularyCreatedEvent extends ApplicationEvent { +public class VocabularyCreatedEvent extends VocabularyEvent { - public VocabularyCreatedEvent(Object source) { - super(source); + public VocabularyCreatedEvent(@NonNull Object source, @NonNull URI vocabularyIri) { + super(source, vocabularyIri); } } diff --git a/src/main/java/cz/cvut/kbss/termit/event/VocabularyEvent.java b/src/main/java/cz/cvut/kbss/termit/event/VocabularyEvent.java new file mode 100644 index 000000000..133afe2f5 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/event/VocabularyEvent.java @@ -0,0 +1,28 @@ +package cz.cvut.kbss.termit.event; + +import org.springframework.context.ApplicationEvent; +import org.springframework.lang.NonNull; + +import java.net.URI; +import java.util.Objects; + +/** + * Base class for vocabulary related events + */ +public abstract class VocabularyEvent extends ApplicationEvent { + protected final URI vocabularyIri; + + protected VocabularyEvent(@NonNull Object source, @NonNull URI vocabularyIri) { + super(source); + Objects.requireNonNull(vocabularyIri); + this.vocabularyIri = vocabularyIri; + } + + /** + * The identifier of the vocabulary to which this event is bound + * @return vocabulary IRI + */ + public URI getVocabularyIri() { + return vocabularyIri; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/event/VocabularyValidationFinishedEvent.java b/src/main/java/cz/cvut/kbss/termit/event/VocabularyValidationFinishedEvent.java new file mode 100644 index 000000000..a5af0bbe8 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/event/VocabularyValidationFinishedEvent.java @@ -0,0 +1,49 @@ +package cz.cvut.kbss.termit.event; + +import cz.cvut.kbss.termit.model.validation.ValidationResult; +import org.springframework.lang.NonNull; + +import java.net.URI; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +/** + * Indicates that validation for a set of vocabularies was finished. + */ +public class VocabularyValidationFinishedEvent extends VocabularyEvent { + + /** + * Vocabulary closure of {@link #vocabularyIri}. + * IRIs of vocabularies that are imported by {@link #vocabularyIri} and were part of the validation. + */ + private final List vocabularyIris; + + private final List validationResults; + + /** + * @param source the source of the event + * @param originVocabularyIri Vocabulary closure of {@link #vocabularyIri}. + * @param vocabularyIris IRI of the vocabulary on which the validation was triggered. + * @param validationResults results of the validation + */ + public VocabularyValidationFinishedEvent(@NonNull Object source, @NonNull URI originVocabularyIri, + @NonNull Collection vocabularyIris, + @NonNull List validationResults) { + super(source, originVocabularyIri); + // defensive copy + this.vocabularyIris = new ArrayList<>(vocabularyIris); + this.validationResults = new ArrayList<>(validationResults); + } + + @NonNull + public List getVocabularyIris() { + return Collections.unmodifiableList(vocabularyIris); + } + + @NonNull + public List getValidationResults() { + return Collections.unmodifiableList(validationResults); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/event/VocabularyWillBeRemovedEvent.java b/src/main/java/cz/cvut/kbss/termit/event/VocabularyWillBeRemovedEvent.java index 3fed1f16e..0e0b6503a 100644 --- a/src/main/java/cz/cvut/kbss/termit/event/VocabularyWillBeRemovedEvent.java +++ b/src/main/java/cz/cvut/kbss/termit/event/VocabularyWillBeRemovedEvent.java @@ -1,21 +1,15 @@ package cz.cvut.kbss.termit.event; -import org.springframework.context.ApplicationEvent; +import org.springframework.lang.NonNull; import java.net.URI; /** * Indicates that a Vocabulary will be removed */ -public class VocabularyWillBeRemovedEvent extends ApplicationEvent { - private final URI vocabulary; +public class VocabularyWillBeRemovedEvent extends VocabularyEvent { - public VocabularyWillBeRemovedEvent(Object source, URI vocabulary) { - super(source); - this.vocabulary = vocabulary; - } - - public URI getVocabulary() { - return vocabulary; + public VocabularyWillBeRemovedEvent(@NonNull Object source, @NonNull URI vocabularyIri) { + super(source, vocabularyIri); } } diff --git a/src/main/java/cz/cvut/kbss/termit/exception/ThrottleAspectException.java b/src/main/java/cz/cvut/kbss/termit/exception/ThrottleAspectException.java new file mode 100644 index 000000000..2f8270bf7 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/exception/ThrottleAspectException.java @@ -0,0 +1,17 @@ +package cz.cvut.kbss.termit.exception; + +/** + * Indicates wrong usage of {@link cz.cvut.kbss.termit.util.throttle.Throttle} annotation. + * + * @see cz.cvut.kbss.termit.util.throttle.ThrottleAspect + */ +public class ThrottleAspectException extends TermItException { + + public ThrottleAspectException(String message) { + super(message); + } + + public ThrottleAspectException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/model/Vocabulary.java b/src/main/java/cz/cvut/kbss/termit/model/Vocabulary.java index 06c8acd58..2198e44f3 100644 --- a/src/main/java/cz/cvut/kbss/termit/model/Vocabulary.java +++ b/src/main/java/cz/cvut/kbss/termit/model/Vocabulary.java @@ -18,6 +18,7 @@ package cz.cvut.kbss.termit.model; import com.fasterxml.jackson.annotation.JsonIgnore; +import cz.cvut.kbss.jopa.exception.LazyLoadingException; import cz.cvut.kbss.jopa.model.MultilingualString; import cz.cvut.kbss.jopa.model.annotations.CascadeType; import cz.cvut.kbss.jopa.model.annotations.FetchType; @@ -236,13 +237,20 @@ public int hashCode() { @Override public String toString() { - return "Vocabulary{" + - getLabel() + - " " + Utils.uriToString(getUri()) + - ", glossary=" + glossary + - (importedVocabularies != null ? - ", importedVocabularies = [" + importedVocabularies.stream().map(Utils::uriToString).collect( - Collectors.joining(", ")) + "]" : "") + - '}'; + String result = "Vocabulary{"+ + getLabel() + " " + + Utils.uriToString(getUri()); + try { + result += ", glossary=" + glossary; + if (importedVocabularies != null) { + result +=", importedVocabularies = [" + + importedVocabularies.stream().map(Utils::uriToString) + .collect(Collectors.joining(", ")) + "]"; + } + } catch (LazyLoadingException e) { + // persistent context not available + } + + return result; } } diff --git a/src/main/java/cz/cvut/kbss/termit/model/validation/ValidationResult.java b/src/main/java/cz/cvut/kbss/termit/model/validation/ValidationResult.java index 331bc461b..ab4f30f9d 100644 --- a/src/main/java/cz/cvut/kbss/termit/model/validation/ValidationResult.java +++ b/src/main/java/cz/cvut/kbss/termit/model/validation/ValidationResult.java @@ -26,12 +26,13 @@ import cz.cvut.kbss.termit.model.Term; import org.topbraid.shacl.vocabulary.SH; +import java.io.Serializable; import java.net.URI; import java.util.Objects; @NonEntity @OWLClass(iri = SH.BASE_URI + "ValidationResult") -public class ValidationResult { +public class ValidationResult implements Serializable { @Id(generated = true) private URI id; diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermDao.java b/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermDao.java index 34ded7f67..50d138f16 100644 --- a/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermDao.java +++ b/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermDao.java @@ -28,7 +28,7 @@ import cz.cvut.kbss.termit.event.AssetPersistEvent; import cz.cvut.kbss.termit.event.AssetUpdateEvent; import cz.cvut.kbss.termit.event.EvictCacheEvent; -import cz.cvut.kbss.termit.event.VocabularyContentModified; +import cz.cvut.kbss.termit.event.VocabularyContentModifiedEvent; import cz.cvut.kbss.termit.exception.PersistenceException; import cz.cvut.kbss.termit.model.AbstractTerm; import cz.cvut.kbss.termit.model.Term; @@ -174,7 +174,7 @@ public void persist(Term entity, Vocabulary vocabulary) { entity.setVocabulary(null); // This is inferred em.persist(entity, descriptorFactory.termDescriptor(vocabulary)); evictCachedSubTerms(Collections.emptySet(), entity.getParentTerms()); - eventPublisher.publishEvent(new VocabularyContentModified(this, vocabulary.getUri())); + eventPublisher.publishEvent(new VocabularyContentModifiedEvent(this, vocabulary.getUri())); eventPublisher.publishEvent(new AssetPersistEvent(this, entity)); } catch (RuntimeException e) { throw new PersistenceException(e); @@ -194,7 +194,7 @@ public Term update(Term entity) { eventPublisher.publishEvent(new AssetUpdateEvent(this, entity)); evictCachedSubTerms(original.getParentTerms(), entity.getParentTerms()); final Term result = em.merge(entity, descriptorFactory.termDescriptor(entity)); - eventPublisher.publishEvent(new VocabularyContentModified(this, original.getVocabulary())); + eventPublisher.publishEvent(new VocabularyContentModifiedEvent(this, original.getVocabulary())); return result; } catch (RuntimeException e) { throw new PersistenceException(e); @@ -790,7 +790,7 @@ public List findAllUnused(Vocabulary vocabulary) { public void remove(Term entity) { super.remove(entity); evictCachedSubTerms(entity.getParentTerms(), Collections.emptySet()); - eventPublisher.publishEvent(new VocabularyContentModified(this, entity.getVocabulary())); + eventPublisher.publishEvent(new VocabularyContentModifiedEvent(this, entity.getVocabulary())); } @Override diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDao.java b/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDao.java index 373991a1e..1b01a46d3 100644 --- a/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDao.java +++ b/src/main/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDao.java @@ -29,9 +29,9 @@ import cz.cvut.kbss.termit.model.Asset; import cz.cvut.kbss.termit.model.Term; import cz.cvut.kbss.termit.model.assignment.TermOccurrence; -import cz.cvut.kbss.termit.persistence.dao.util.ScheduledContextRemover; import cz.cvut.kbss.termit.persistence.dao.util.SparqlResultToTermOccurrenceMapper; import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.Utils; import cz.cvut.kbss.termit.util.Vocabulary; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,12 +80,9 @@ public class TermOccurrenceDao extends BaseDao { private final Configuration.Persistence config; - private final ScheduledContextRemover contextRemover; - - public TermOccurrenceDao(EntityManager em, Configuration config, ScheduledContextRemover contextRemover) { + public TermOccurrenceDao(EntityManager em, Configuration config) { super(TermOccurrence.class, em); this.config = config.getPersistence(); - this.contextRemover = contextRemover; } /** @@ -258,12 +255,12 @@ public void removeAll(Asset target) { Objects.requireNonNull(target); final URI sourceContext = TermOccurrence.resolveContext(target.getUri()); - final URI targetContext = URI.create(sourceContext + "-for-removal-" + System.currentTimeMillis()); - em.createNativeQuery("MOVE GRAPH ?g TO ?targetContext") - .setParameter("g", sourceContext) - .setParameter("targetContext", targetContext) - .executeUpdate(); - contextRemover.scheduleForRemoval(targetContext); + LOG.debug("Removing all occurrences from {}", sourceContext); + em.createNativeQuery("DROP GRAPH ?context") + .setParameter("context", sourceContext) + .executeUpdate(); + LOG.atDebug().setMessage("Removed all occurrences from {}") + .addArgument(() -> Utils.uriToString(sourceContext)).log(); } /** diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDao.java b/src/main/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDao.java index 1f04fe548..21e5233f4 100644 --- a/src/main/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDao.java +++ b/src/main/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDao.java @@ -45,6 +45,7 @@ import cz.cvut.kbss.termit.service.snapshot.SnapshotProvider; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.Utils; +import cz.cvut.kbss.termit.util.throttle.CacheableFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -144,17 +145,17 @@ public Vocabulary getReference(URI id) { /** * Gets identifiers of all vocabularies imported by the specified vocabulary, including transitively imported ones. * - * @param entity Base vocabulary, whose imports should be retrieved + * @param vocabularyIri Identifier of base vocabulary, whose imports should be retrieved * @return Collection of (transitively) imported vocabularies */ - public Collection getTransitivelyImportedVocabularies(Vocabulary entity) { - Objects.requireNonNull(entity); + public Collection getTransitivelyImportedVocabularies(URI vocabularyIri) { + Objects.requireNonNull(vocabularyIri); try { return em.createNativeQuery("SELECT DISTINCT ?imported WHERE {" + "?x ?imports+ ?imported ." + "}", URI.class) .setParameter("imports", URI.create(cz.cvut.kbss.termit.util.Vocabulary.s_p_importuje_slovnik)) - .setParameter("x", entity.getUri()).getResultList(); + .setParameter("x", vocabularyIri).getResultList(); } catch (RuntimeException e) { throw new PersistenceException(e); } @@ -357,11 +358,11 @@ public void refreshLastModified(RefreshLastModifiedEvent event) { } @Transactional - public List validateContents(Vocabulary voc) { + public CacheableFuture> validateContents(URI vocabulary) { final VocabularyContentValidator validator = context.getBean(VocabularyContentValidator.class); - final Collection importClosure = getTransitivelyImportedVocabularies(voc); - importClosure.add(voc.getUri()); - return validator.validate(importClosure); + final Collection importClosure = getTransitivelyImportedVocabularies(vocabulary); + importClosure.add(vocabulary); + return validator.validate(vocabulary, importClosure); } /** diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/dao/util/ScheduledContextRemover.java b/src/main/java/cz/cvut/kbss/termit/persistence/dao/util/ScheduledContextRemover.java deleted file mode 100644 index 326fe0ab0..000000000 --- a/src/main/java/cz/cvut/kbss/termit/persistence/dao/util/ScheduledContextRemover.java +++ /dev/null @@ -1,65 +0,0 @@ -package cz.cvut.kbss.termit.persistence.dao.util; - -import cz.cvut.kbss.jopa.model.EntityManager; -import cz.cvut.kbss.termit.util.Utils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.lang.NonNull; -import org.springframework.scheduling.annotation.Scheduled; -import org.springframework.stereotype.Component; -import org.springframework.transaction.annotation.Transactional; - -import java.net.URI; -import java.util.HashSet; -import java.util.Objects; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -/** - * Drops registered repository contexts at scheduled moments. - *

- * This allows to move time-consuming removal of repository contexts containing a lot of data to times of low system - * activity. - */ -@Component -public class ScheduledContextRemover { - - private static final Logger LOG = LoggerFactory.getLogger(ScheduledContextRemover.class); - - private final EntityManager em; - - private final Set contextsToRemove = new HashSet<>(); - - public ScheduledContextRemover(EntityManager em) { - this.em = em; - } - - /** - * Schedules the specified context identifier for removal at the next execution of the context cleanup. - * - * @param contextUri Identifier of the context to remove - * @see #runContextRemoval() - */ - public synchronized void scheduleForRemoval(@NonNull URI contextUri) { - LOG.debug("Scheduling context {} for removal.", Utils.uriToString(contextUri)); - contextsToRemove.add(Objects.requireNonNull(contextUri)); - } - - /** - * Runs the removal of the registered repository contexts. - *

- * This method is scheduled and should not be invoked manually. - * - * @see #scheduleForRemoval(URI) - */ - @Transactional - @Scheduled(fixedRate = 1, timeUnit = TimeUnit.MINUTES) - public synchronized void runContextRemoval() { - LOG.trace("Running scheduled repository context removal."); - contextsToRemove.forEach(g -> { - LOG.trace("Dropping repository context {}.", Utils.uriToString(g)); - em.createNativeQuery("DROP GRAPH ?g").setParameter("g", g).executeUpdate(); - }); - contextsToRemove.clear(); - } -} diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidator.java b/src/main/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidator.java index eb757ccfd..1d6cfe406 100644 --- a/src/main/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidator.java +++ b/src/main/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidator.java @@ -17,24 +17,34 @@ */ package cz.cvut.kbss.termit.persistence.validation; -import cz.cvut.kbss.termit.event.VocabularyContentModified; +import cz.cvut.kbss.termit.event.EvictCacheEvent; +import cz.cvut.kbss.termit.event.VocabularyContentModifiedEvent; +import cz.cvut.kbss.termit.event.VocabularyCreatedEvent; +import cz.cvut.kbss.termit.event.VocabularyEvent; +import cz.cvut.kbss.termit.exception.TermItException; import cz.cvut.kbss.termit.model.validation.ValidationResult; +import cz.cvut.kbss.termit.util.throttle.Throttle; +import cz.cvut.kbss.termit.util.throttle.ThrottledFuture; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Lookup; import org.springframework.context.annotation.Primary; import org.springframework.context.annotation.Profile; import org.springframework.context.event.EventListener; +import org.springframework.lang.NonNull; import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; import java.net.URI; -import java.util.ArrayList; import java.util.Collection; -import java.util.HashSet; +import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; @Component("cachingValidator") @Primary @@ -43,12 +53,72 @@ public class ResultCachingValidator implements VocabularyContentValidator { private static final Logger LOG = LoggerFactory.getLogger(ResultCachingValidator.class); - private final Map, List> validationCache = new ConcurrentHashMap<>(); + /** + * Map of origin vocabulary IRI to vocabulary iri closure of imported vocabularies. + * When the record is missing, the cache is considered as dirty. + */ + private final Map> vocabularyClosure = new ConcurrentHashMap<>(); + private final Map> validationCache = new HashMap<>(); + + /** + * @return true when the cache contents are dirty and should be refreshed; false otherwise. + */ + public boolean isNotDirty(@NonNull URI originVocabularyIri) { + return vocabularyClosure.containsKey(originVocabularyIri); + } + + private Optional> getCached(@NonNull URI originVocabularyIri) { + synchronized (validationCache) { + return Optional.ofNullable(validationCache.get(originVocabularyIri)); + } + } + + @Throttle(value = "{#originVocabularyIri}", name="vocabularyValidation") + @Transactional @Override - public List validate(Collection vocabularyIris) { - final Set copy = new HashSet<>(vocabularyIris); // Defensive copy - return new ArrayList<>(validationCache.computeIfAbsent(copy, uris -> getValidator().validate(vocabularyIris))); + @NonNull + public ThrottledFuture> validate(@NonNull URI originVocabularyIri, @NonNull Collection vocabularyIris) { + final Set iris = Set.copyOf(vocabularyIris); + + if (iris.isEmpty()) { + LOG.warn("Validation of empty IRI list was requested for {}", originVocabularyIri); + return ThrottledFuture.done(List.of()); + } + + Optional> cached = getCached(originVocabularyIri); + if (isNotDirty(originVocabularyIri) && cached.isPresent()) { + return ThrottledFuture.done(cached.get()); + } + + return ThrottledFuture.of(() -> runValidation(originVocabularyIri, iris)).setCachedResult(cached.orElse(null)); + } + + @NonNull + private Collection runValidation(@NonNull URI originVocabularyIri, @NonNull final Set iris) { + Optional> cached = getCached(originVocabularyIri); + if (isNotDirty(originVocabularyIri) && cached.isPresent()) { + return cached.get(); + } + + final Collection results; + try { + // executes real validation + // get is safe here as long as we are on throttled thread from #validate method + results = getValidator().validate(originVocabularyIri, iris).get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new TermItException(e); + } catch (ExecutionException e) { + throw new TermItException(e.getCause()); + } + + synchronized (validationCache) { + vocabularyClosure.put(originVocabularyIri, Collections.unmodifiableCollection(iris)); + validationCache.put(originVocabularyIri, Collections.unmodifiableCollection(results)); + } + + return results; } @Lookup @@ -56,9 +126,34 @@ Validator getValidator() { return null; // Will be replaced by Spring } - @EventListener - public void evictCache(VocabularyContentModified event) { - LOG.debug("Vocabulary content modified, evicting validation result cache."); - validationCache.clear(); + /** + * Marks cache related to the vocabulary from the event as dirty + */ + @EventListener({VocabularyContentModifiedEvent.class, VocabularyCreatedEvent.class}) + public void markCacheDirty(VocabularyEvent event) { + LOG.debug("Vocabulary content modified, marking cache as dirty for {}.", event.getVocabularyIri()); + // marked as dirty for specified vocabulary + vocabularyClosure.remove(event.getVocabularyIri()); + // now mark all vocabularies importing modified vocabulary as dirty too + synchronized (validationCache) { + vocabularyClosure.keySet().forEach(originVocabularyIri -> { + final Collection closure = vocabularyClosure.get(originVocabularyIri); + if (closure != null && closure.contains(event.getVocabularyIri())) { + vocabularyClosure.remove(originVocabularyIri); + } + }); + if (event instanceof VocabularyCreatedEvent) { + validationCache.remove(event.getVocabularyIri()); + } + } + } + + @EventListener(EvictCacheEvent.class) + public void evictCache() { + LOG.debug("Validation cache cleared"); + synchronized (validationCache) { + vocabularyClosure.clear(); + validationCache.clear(); + } } } diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/validation/Validator.java b/src/main/java/cz/cvut/kbss/termit/persistence/validation/Validator.java index 80775c3b9..b01ac7dbf 100644 --- a/src/main/java/cz/cvut/kbss/termit/persistence/validation/Validator.java +++ b/src/main/java/cz/cvut/kbss/termit/persistence/validation/Validator.java @@ -21,11 +21,14 @@ import cz.cvut.kbss.jopa.model.EntityManager; import cz.cvut.kbss.jopa.model.MultilingualString; import cz.cvut.kbss.jsonld.JsonLd; +import cz.cvut.kbss.termit.event.VocabularyValidationFinishedEvent; import cz.cvut.kbss.termit.exception.TermItException; import cz.cvut.kbss.termit.model.validation.ValidationResult; import cz.cvut.kbss.termit.persistence.context.VocabularyContextMapper; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.Utils; +import cz.cvut.kbss.termit.util.throttle.Throttle; +import cz.cvut.kbss.termit.util.throttle.ThrottledFuture; import org.apache.jena.rdf.model.Literal; import org.apache.jena.rdf.model.Model; import org.apache.jena.rdf.model.ModelFactory; @@ -39,8 +42,10 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.lang.NonNull; import org.springframework.stereotype.Component; import org.springframework.transaction.annotation.Transactional; @@ -79,16 +84,17 @@ public class Validator implements VocabularyContentValidator { private final EntityManager em; private final VocabularyContextMapper vocabularyContextMapper; + private final ApplicationEventPublisher eventPublisher; - private com.github.sgov.server.Validator validator; private Model validationModel; @Autowired public Validator(EntityManager em, VocabularyContextMapper vocabularyContextMapper, - Configuration config) { + Configuration config, ApplicationEventPublisher eventPublisher) { this.em = em; this.vocabularyContextMapper = vocabularyContextMapper; + this.eventPublisher = eventPublisher; initValidator(config.getPersistence().getLanguage()); } @@ -102,8 +108,7 @@ public Validator(EntityManager em, */ private void initValidator(String language) { try { - this.validator = new com.github.sgov.server.Validator(); - this.validationModel = initValidationModel(validator, language); + this.validationModel = initValidationModel(new com.github.sgov.server.Validator(), language); } catch (IOException e) { throw new TermItException("Unable to initialize validator.", e); } @@ -138,27 +143,44 @@ private void loadOverrideRules(Model validationModel, String language) throws IO } } + @Throttle(value = "{#originVocabularyIri}", name = "vocabularyValidation") @Transactional(readOnly = true) @Override - public List validate(final Collection vocabularyIris) { + @NonNull + public ThrottledFuture> validate(final @NonNull URI originVocabularyIri, final @NonNull Collection vocabularyIris) { + if (vocabularyIris.isEmpty()) { + return ThrottledFuture.done(List.of()); + } + + return ThrottledFuture.of(() -> { + final List results = runValidation(vocabularyIris); + eventPublisher.publishEvent(new VocabularyValidationFinishedEvent(this, originVocabularyIri, vocabularyIris, results)); + return results; + }); + } + + protected synchronized List runValidation(@NonNull Collection vocabularyIris) { LOG.debug("Validating {}", vocabularyIris); try { + LOG.trace("Constructing model from RDF4J repository..."); final Model dataModel = getModelFromRdf4jRepository(vocabularyIris); - org.topbraid.shacl.validation.ValidationReport report = validator.validate(dataModel, validationModel); + LOG.trace("Model constructed, running validation..."); + org.topbraid.shacl.validation.ValidationReport report = new com.github.sgov.server.Validator() + .validate(dataModel, validationModel); LOG.debug("Done."); return report.results().stream() .sorted(new ValidationResultSeverityComparator()).map(result -> { final URI termUri = URI.create(result.getFocusNode().toString()); final URI severity = URI.create(result.getSeverity().getURI()); final URI errorUri = result.getSourceShape().isURIResource() ? - URI.create(result.getSourceShape().getURI()) : null; + URI.create(result.getSourceShape().getURI()) : null; final URI resultPath = result.getPath() != null && result.getPath().isURIResource() ? - URI.create(result.getPath().getURI()) : null; + URI.create(result.getPath().getURI()) : null; final MultilingualString messages = new MultilingualString(result.getMessages().stream() .map(RDFNode::asLiteral) .collect(Collectors.toMap( lit -> lit.getLanguage().isBlank() ? - JsonLd.NONE : lit.getLanguage(), + JsonLd.NONE : lit.getLanguage(), Literal::getLexicalForm))); return new ValidationResult() diff --git a/src/main/java/cz/cvut/kbss/termit/persistence/validation/VocabularyContentValidator.java b/src/main/java/cz/cvut/kbss/termit/persistence/validation/VocabularyContentValidator.java index 85ce431f0..54ffa94ae 100644 --- a/src/main/java/cz/cvut/kbss/termit/persistence/validation/VocabularyContentValidator.java +++ b/src/main/java/cz/cvut/kbss/termit/persistence/validation/VocabularyContentValidator.java @@ -18,10 +18,11 @@ package cz.cvut.kbss.termit.persistence.validation; import cz.cvut.kbss.termit.model.validation.ValidationResult; +import cz.cvut.kbss.termit.util.throttle.ThrottledFuture; +import org.springframework.lang.NonNull; import java.net.URI; import java.util.Collection; -import java.util.List; /** * Allows validating the content of vocabularies based on preconfigured rules. @@ -33,8 +34,10 @@ public interface VocabularyContentValidator { *

* The vocabularies are validated together, as a single unit. * - * @param vocabularyIris Vocabulary identifiers + * @param originVocabularyIri the origin vocabulary IRI + * @param vocabularyIris Vocabulary identifiers (including {@code originVocabularyIri} * @return List of violations of validation rules. Empty list if there are not violations */ - List validate(final Collection vocabularyIris); + @NonNull + ThrottledFuture> validate(@NonNull URI originVocabularyIri, @NonNull Collection vocabularyIris); } diff --git a/src/main/java/cz/cvut/kbss/termit/rest/ResourceController.java b/src/main/java/cz/cvut/kbss/termit/rest/ResourceController.java index f389a328a..11bb65415 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/ResourceController.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/ResourceController.java @@ -148,8 +148,8 @@ public ResponseEntity getContent( try { final Optional timestamp = at.map(RestUtils::parseTimestamp); final TypeAwareResource content = resourceService.getContent(resource, - new ResourceRetrievalSpecification(timestamp, - withoutUnconfirmedOccurrences)); + new ResourceRetrievalSpecification(timestamp, + withoutUnconfirmedOccurrences)); final ResponseEntity.BodyBuilder builder = ResponseEntity.ok() .contentLength(content.contentLength()) .contentType(MediaType.parseMediaType( @@ -172,23 +172,24 @@ public ResponseEntity getContent( }) @PutMapping(value = "/{localName}/content") @ResponseStatus(HttpStatus.NO_CONTENT) - public void saveContent(@Parameter(description = ResourceControllerDoc.ID_LOCAL_NAME_DESCRIPTION, - example = ResourceControllerDoc.ID_LOCAL_NAME_EXAMPLE) - @PathVariable String localName, - @Parameter(description = ResourceControllerDoc.ID_NAMESPACE_DESCRIPTION, - example = ResourceControllerDoc.ID_NAMESPACE_EXAMPLE) - @RequestParam(name = QueryParams.NAMESPACE, required = false) Optional namespace, - @Parameter(description = "File with the new content.") - @RequestParam(name = "file") MultipartFile attachment) { + public Void saveContent(@Parameter(description = ResourceControllerDoc.ID_LOCAL_NAME_DESCRIPTION, + example = ResourceControllerDoc.ID_LOCAL_NAME_EXAMPLE) + @PathVariable String localName, + @Parameter(description = ResourceControllerDoc.ID_NAMESPACE_DESCRIPTION, + example = ResourceControllerDoc.ID_NAMESPACE_EXAMPLE) + @RequestParam(name = QueryParams.NAMESPACE, + required = false) Optional namespace, + @Parameter(description = "File with the new content.") + @RequestParam(name = "file") MultipartFile attachment) { + final Resource resource = getResource(localName, namespace); try { resourceService.saveContent(resource, attachment.getInputStream()); } catch (IOException e) { - throw new TermItException( - "Unable to read file (fileName=\"" + attachment.getOriginalFilename() + "\") content from request.", - e); + throw new TermItException("Unable to read file (fileName=\"" + attachment.getOriginalFilename() + "\") content from request.", e); } LOG.debug("Content saved for resource {}.", resource); + return null; } @Operation(security = {@SecurityRequirement(name = "bearer-key")}, @@ -212,8 +213,8 @@ public ResponseEntity hasContent(@Parameter(description = ResourceControll return ResponseEntity.notFound().build(); } else { final String contentType = resourceService.getContent(r, - new ResourceRetrievalSpecification(Optional.empty(), - false)) + new ResourceRetrievalSpecification(Optional.empty(), + false)) .getMediaType().orElse(null); return ResponseEntity.noContent().header(HttpHeaders.CONTENT_TYPE, contentType).build(); } @@ -297,7 +298,7 @@ public void removeFileFromDocument(@Parameter(description = ResourceControllerDo } @Operation(security = {@SecurityRequirement(name = "bearer-key")}, - description = "Runs text analysis on the content of the resource with the specified identifier.") + description = "Runs text analysis on the content of the resource with the specified identifier. Analysis will be performed asynchronously sometime in the future.") @ApiResponses({ @ApiResponse(responseCode = "204", description = "Text analysis executed."), @ApiResponse(responseCode = "404", description = ResourceControllerDoc.ID_NOT_FOUND_DESCRIPTION), @@ -306,19 +307,18 @@ public void removeFileFromDocument(@Parameter(description = ResourceControllerDo @PutMapping(value = "/{localName}/text-analysis") @ResponseStatus(HttpStatus.NO_CONTENT) public void runTextAnalysis(@Parameter(description = ResourceControllerDoc.ID_LOCAL_NAME_DESCRIPTION, - example = ResourceControllerDoc.ID_LOCAL_NAME_EXAMPLE) - @PathVariable String localName, - @Parameter(description = ResourceControllerDoc.ID_NAMESPACE_DESCRIPTION, - example = ResourceControllerDoc.ID_NAMESPACE_EXAMPLE) - @RequestParam(name = QueryParams.NAMESPACE, - required = false) Optional namespace, - @Parameter( - description = "Identifiers of vocabularies whose terms are used to seed text analysis.") - @RequestParam(name = "vocabulary", required = false, - defaultValue = "") Set vocabularies) { + example = ResourceControllerDoc.ID_LOCAL_NAME_EXAMPLE) + @PathVariable String localName, + @Parameter(description = ResourceControllerDoc.ID_NAMESPACE_DESCRIPTION, + example = ResourceControllerDoc.ID_NAMESPACE_EXAMPLE) + @RequestParam(name = QueryParams.NAMESPACE, + required = false) Optional namespace, + @Parameter( + description = "Identifiers of vocabularies whose terms are used to seed text analysis.") + @RequestParam(name = "vocabulary", required = false, + defaultValue = "") Set vocabularies) { final Resource resource = getResource(localName, namespace); resourceService.runTextAnalysis(resource, vocabularies); - LOG.debug("Text analysis finished for resource {}.", resource); } @Operation(security = {@SecurityRequirement(name = "bearer-key")}, @@ -367,10 +367,15 @@ public List getHistory( * A couple of constants for the {@link ResourceController} API documentation. */ private static final class ResourceControllerDoc { + private static final String ID_LOCAL_NAME_DESCRIPTION = "Locally (in the context of the specified namespace/default resource namespace) unique part of the resource identifier."; + private static final String ID_LOCAL_NAME_EXAMPLE = "mpp-draft.html"; + private static final String ID_NAMESPACE_DESCRIPTION = "Identifier namespace. Allows to override the default resource identifier namespace."; + private static final String ID_NAMESPACE_EXAMPLE = "http://onto.fel.cvut.cz/ontologies/zdroj/"; + private static final String ID_NOT_FOUND_DESCRIPTION = "Resource with the specified identifier not found."; } } diff --git a/src/main/java/cz/cvut/kbss/termit/rest/TermController.java b/src/main/java/cz/cvut/kbss/termit/rest/TermController.java index beec58274..9fc059aa9 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/TermController.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/TermController.java @@ -220,14 +220,16 @@ public ResponseEntity checkTerms( @Parameter(description = "Language of the label.") @RequestParam(name = "language", required = false) String language) { final URI vocabularyUri = getVocabularyUri(namespace, localName); - final Vocabulary vocabulary = termService.getVocabularyReference(vocabularyUri); - if (prefLabel != null) { - final boolean exists = termService.existsInVocabulary(prefLabel, vocabulary, language); - return new ResponseEntity<>(exists ? HttpStatus.OK : HttpStatus.NOT_FOUND); - } else { - final Integer count = termService.getTermCount(vocabulary); - return ResponseEntity.ok().header(Constants.X_TOTAL_COUNT_HEADER, count.toString()).build(); - } + + final Vocabulary vocabulary = termService.getVocabularyReference(vocabularyUri); + if (prefLabel != null) { + final boolean exists = termService.existsInVocabulary(prefLabel, vocabulary, language); + return new ResponseEntity<>(exists ? HttpStatus.OK : HttpStatus.NOT_FOUND); + } else { + final Integer count = termService.getTermCount(vocabulary); + return ResponseEntity.ok().header(Constants.X_TOTAL_COUNT_HEADER, count.toString()).build(); + } + } private Vocabulary getVocabulary(URI vocabularyUri) { @@ -270,11 +272,13 @@ public List getAllRoots( @Parameter( description = "Identifiers of terms that should be included in the response (regardless of whether they are root terms or not).") @RequestParam(name = "includeTerms", required = false, defaultValue = "") List includeTerms) { + final Vocabulary vocabulary = getVocabulary(getVocabularyUri(namespace, localName)); return includeImported ? - termService - .findAllRootsIncludingImported(vocabulary, createPageRequest(pageSize, pageNo), includeTerms) : - termService.findAllRoots(vocabulary, createPageRequest(pageSize, pageNo), includeTerms); + termService + .findAllRootsIncludingImported(vocabulary, createPageRequest(pageSize, pageNo), includeTerms) : + termService.findAllRoots(vocabulary, createPageRequest(pageSize, pageNo), includeTerms); + } @Operation(security = {@SecurityRequirement(name = "bearer-key")}, @@ -608,8 +612,8 @@ public void runTextAnalysisOnTerm( @PathVariable String termLocalName, @Parameter(description = ApiDoc.ID_NAMESPACE_DESCRIPTION, example = ApiDoc.ID_NAMESPACE_EXAMPLE) @RequestParam(name = QueryParams.NAMESPACE, required = false) Optional namespace) { - termService.analyzeTermDefinition(getById(localName, termLocalName, namespace), - getVocabularyUri(namespace, localName)); + LOG.warn("Called legacy endpoint intended for internal use or testing only! (/vocabularies/{}/terms/{}/text-analysis)", localName, termLocalName); + termService.analyzeTermDefinition(getById(localName, termLocalName, namespace), getVocabularyUri(namespace, localName)); } @Operation(security = {@SecurityRequirement(name = "bearer-key")}, diff --git a/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java b/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java index 881e6b71b..c03272516 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/VocabularyController.java @@ -381,11 +381,11 @@ public void removeVocabulary(@Parameter(description = ApiDoc.ID_LOCAL_NAME_DESCR @GetMapping(value = "/{localName}/relations") public List relations(@Parameter(description = ApiDoc.ID_LOCAL_NAME_DESCRIPTION, example = ApiDoc.ID_LOCAL_NAME_EXAMPLE) - @PathVariable String localName, + @PathVariable String localName, @Parameter(description = ApiDoc.ID_NAMESPACE_DESCRIPTION, - example = ApiDoc.ID_NAMESPACE_EXAMPLE) - @RequestParam(name = QueryParams.NAMESPACE, - required = false) Optional namespace) { + example = ApiDoc.ID_NAMESPACE_EXAMPLE) + @RequestParam(name = QueryParams.NAMESPACE, + required = false) Optional namespace) { final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName); final Vocabulary vocabulary = vocabularyService.findRequired(identifier); @@ -401,11 +401,11 @@ public List relations(@Parameter(description = ApiDoc.ID_LOCAL_NA @GetMapping(value = "/{localName}/terms/relations") public List termsRelations(@Parameter(description = ApiDoc.ID_LOCAL_NAME_DESCRIPTION, example = ApiDoc.ID_LOCAL_NAME_EXAMPLE) - @PathVariable String localName, + @PathVariable String localName, @Parameter(description = ApiDoc.ID_NAMESPACE_DESCRIPTION, - example = ApiDoc.ID_NAMESPACE_EXAMPLE) - @RequestParam(name = QueryParams.NAMESPACE, - required = false) Optional namespace) { + example = ApiDoc.ID_NAMESPACE_EXAMPLE) + @RequestParam(name = QueryParams.NAMESPACE, + required = false) Optional namespace) { final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName); final Vocabulary vocabulary = vocabularyService.findRequired(identifier); diff --git a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java index 03d50a199..1a304d8bf 100644 --- a/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/rest/handler/RestExceptionHandler.java @@ -44,13 +44,15 @@ import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; -import org.springframework.security.access.AccessDeniedException; import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; +import org.springframework.web.context.request.async.AsyncRequestNotUsableException; import org.springframework.web.multipart.MaxUploadSizeExceededException; +import static cz.cvut.kbss.termit.util.ExceptionUtils.isCausedBy; + /** * Exception handlers for REST controllers. *

@@ -80,7 +82,10 @@ private static void logException(Throwable ex, HttpServletRequest request) { } private static void logException(String message, Throwable ex) { - LOG.error(message, ex); + // Prevents exceptions caused by broken connection with a client from logging + if (!isCausedBy(ex, AsyncRequestNotUsableException.class)) { + LOG.error(message, ex); + } } private static ErrorInfo errorInfo(HttpServletRequest request, Throwable e) { @@ -132,21 +137,7 @@ public ResponseEntity authorizationException(HttpServletRequest reque public ResponseEntity authenticationException(HttpServletRequest request, AuthenticationException e) { LOG.warn("Authentication failure during HTTP request to {}: {}", request.getRequestURI(), e.getMessage()); LOG.atDebug().setCause(e).log(e.getMessage()); - return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN); - } - - /** - * Fired, for example, on method security violation - */ - @ExceptionHandler(AccessDeniedException.class) - public ResponseEntity accessDeniedException(HttpServletRequest request, AccessDeniedException e) { - LOG.atWarn().setMessage("[{}] Unauthorized access: {}").addArgument(() -> { - if (request.getUserPrincipal() != null) { - return request.getUserPrincipal().getName(); - } - return "(unknown user)"; - }).addArgument(e.getMessage()).log(); - return new ResponseEntity<>(errorInfo(request, e), HttpStatus.FORBIDDEN); + return new ResponseEntity<>(errorInfo(request, e), HttpStatus.UNAUTHORIZED); } @ExceptionHandler(ValidationException.class) diff --git a/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java b/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java index 3d8b52c4c..e10d16462 100644 --- a/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java +++ b/src/main/java/cz/cvut/kbss/termit/security/JwtUtils.java @@ -64,15 +64,16 @@ public class JwtUtils { private final ObjectMapper objectMapper; - private final Key key; - private final JwtParser jwtParser; + private final Key key; + @Autowired public JwtUtils(@Qualifier("objectMapper") ObjectMapper objectMapper, Configuration config) { this.objectMapper = objectMapper; this.key = Utils.isBlank(config.getJwt().getSecretKey()) ? Keys.secretKeyFor(SIGNATURE_ALGORITHM) : Keys.hmacShaKeyFor(config.getJwt().getSecretKey().getBytes(StandardCharsets.UTF_8)); + this.jwtParser = Jwts.parserBuilder().setSigningKey(key) .deserializeJsonWith(new JacksonDeserializer<>(objectMapper)) .build(); diff --git a/src/main/java/cz/cvut/kbss/termit/service/business/ResourceService.java b/src/main/java/cz/cvut/kbss/termit/service/business/ResourceService.java index faa633d75..f3d7a9cc7 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/business/ResourceService.java +++ b/src/main/java/cz/cvut/kbss/termit/service/business/ResourceService.java @@ -98,7 +98,7 @@ public ResourceService(ResourceRepositoryService repositoryService, DocumentMana */ @EventListener public void onVocabularyRemoval(VocabularyWillBeRemovedEvent event) { - vocabularyService.find(event.getVocabulary()).ifPresent(vocabulary -> { + vocabularyService.find(event.getVocabularyIri()).ifPresent(vocabulary -> { if(vocabulary.getDocument() != null) { remove(vocabulary.getDocument()); } diff --git a/src/main/java/cz/cvut/kbss/termit/service/business/TermService.java b/src/main/java/cz/cvut/kbss/termit/service/business/TermService.java index 983f00eb1..db3bb6564 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/business/TermService.java +++ b/src/main/java/cz/cvut/kbss/termit/service/business/TermService.java @@ -42,6 +42,7 @@ import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.TypeAwareResource; import cz.cvut.kbss.termit.util.Utils; +import cz.cvut.kbss.termit.util.throttle.Throttle; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -50,6 +51,7 @@ import org.springframework.security.access.prepost.PostAuthorize; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Propagation; import org.springframework.transaction.annotation.Transactional; import java.net.URI; @@ -374,10 +376,7 @@ public void persistRoot(Term term, Vocabulary owner) { Objects.requireNonNull(owner); languageService.getInitialTermState().ifPresent(is -> term.setState(is.getUri())); repositoryService.addRootTermToVocabulary(term, owner); - if (!config.getTextAnalysis().isDisableVocabularyAnalysisOnTermEdit()) { - analyzeTermDefinition(term, owner.getUri()); - vocabularyService.runTextAnalysisOnAllTerms(owner); - } + vocabularyService.runTextAnalysisOnAllTerms(owner); } /** @@ -392,10 +391,7 @@ public void persistChild(Term child, Term parent) { Objects.requireNonNull(parent); languageService.getInitialTermState().ifPresent(is -> child.setState(is.getUri())); repositoryService.addChildTerm(child, parent); - if (!config.getTextAnalysis().isDisableVocabularyAnalysisOnTermEdit()) { - analyzeTermDefinition(child, parent.getVocabulary()); - vocabularyService.runTextAnalysisOnAllTerms(findVocabularyRequired(parent.getVocabulary())); - } + vocabularyService.runTextAnalysisOnAllTerms(findVocabularyRequired(parent.getVocabulary())); } /** @@ -412,11 +408,14 @@ public Term update(Term term) { checkForInvalidTerminalStateAssignment(original, term.getState()); // Ensure the change is merged into the repo before analyzing other terms final Term result = repositoryService.update(term); - if (!Objects.equals(original.getDefinition(), term.getDefinition()) && !config.getTextAnalysis().isDisableVocabularyAnalysisOnTermEdit()) { - analyzeTermDefinition(term, original.getVocabulary()); - } - if (!Objects.equals(original.getLabel(), term.getLabel()) && !config.getTextAnalysis().isDisableVocabularyAnalysisOnTermEdit()) { - vocabularyService.runTextAnalysisOnAllTerms(getVocabularyReference(original.getVocabulary())); + + // if the label changed, run analysis on all terms in the vocabulary + if (!Objects.equals(original.getLabel(), result.getLabel())) { + vocabularyService.runTextAnalysisOnAllTerms(getVocabularyReference(result.getVocabulary())); + // if all terms have not been analyzed, check if the definition has changed, + // and if so, perform an analysis for the term definition + } else if (!Objects.equals(original.getDefinition(), result.getDefinition())) { + analyzeTermDefinition(result, result.getVocabulary()); } return result; } @@ -441,8 +440,13 @@ public void remove(@NonNull Term term) { * @param term Term to analyze * @param vocabularyIri Identifier of the vocabulary used for analysis */ + @Throttle(value = "{#vocabularyIri, #term.getUri()}", + group = "T(ThrottleGroupProvider).getTextAnalysisVocabularyTerm(#vocabulary.getUri(), #term.getUri())", + name="termDefinitionAnalysis") + @Transactional(propagation = Propagation.REQUIRES_NEW) @PreAuthorize("@termAuthorizationService.canModify(#term)") public void analyzeTermDefinition(AbstractTerm term, URI vocabularyIri) { + term = findRequired(term.getUri()); // required when throttling for persistent context Objects.requireNonNull(term); if (term.getDefinition() == null || term.getDefinition().isEmpty()) { return; diff --git a/src/main/java/cz/cvut/kbss/termit/service/business/VocabularyService.java b/src/main/java/cz/cvut/kbss/termit/service/business/VocabularyService.java index 69a2dfc22..1d20cf5b2 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/business/VocabularyService.java +++ b/src/main/java/cz/cvut/kbss/termit/service/business/VocabularyService.java @@ -24,7 +24,9 @@ import cz.cvut.kbss.termit.dto.acl.AccessControlListDto; import cz.cvut.kbss.termit.dto.listing.TermDto; import cz.cvut.kbss.termit.dto.listing.VocabularyDto; +import cz.cvut.kbss.termit.event.VocabularyContentModifiedEvent; import cz.cvut.kbss.termit.event.VocabularyCreatedEvent; +import cz.cvut.kbss.termit.event.VocabularyEvent; import cz.cvut.kbss.termit.exception.NotFoundException; import cz.cvut.kbss.termit.model.Vocabulary; import cz.cvut.kbss.termit.model.acl.AccessControlList; @@ -34,7 +36,6 @@ import cz.cvut.kbss.termit.model.validation.ValidationResult; import cz.cvut.kbss.termit.persistence.context.VocabularyContextMapper; import cz.cvut.kbss.termit.persistence.snapshot.SnapshotCreator; -import cz.cvut.kbss.termit.service.business.async.AsyncTermService; import cz.cvut.kbss.termit.service.changetracking.ChangeRecordProvider; import cz.cvut.kbss.termit.service.export.ExportFormat; import cz.cvut.kbss.termit.service.repository.ChangeRecordService; @@ -45,6 +46,8 @@ import cz.cvut.kbss.termit.util.TypeAwareClasspathResource; import cz.cvut.kbss.termit.util.TypeAwareFileSystemResource; import cz.cvut.kbss.termit.util.TypeAwareResource; +import cz.cvut.kbss.termit.util.throttle.CacheableFuture; +import cz.cvut.kbss.termit.util.throttle.Throttle; import org.jetbrains.annotations.NotNull; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,6 +55,7 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.annotation.Lazy; +import org.springframework.context.event.EventListener; import org.springframework.security.access.prepost.PostAuthorize; import org.springframework.security.access.prepost.PostFilter; import org.springframework.security.access.prepost.PreAuthorize; @@ -90,7 +94,7 @@ public class VocabularyService private final ChangeRecordService changeRecordService; - private final AsyncTermService termService; + private final TermService termService; private final VocabularyContextMapper contextMapper; @@ -104,7 +108,7 @@ public class VocabularyService public VocabularyService(VocabularyRepositoryService repositoryService, ChangeRecordService changeRecordService, - @Lazy AsyncTermService termService, + @Lazy TermService termService, VocabularyContextMapper contextMapper, AccessControlListService aclService, VocabularyAuthorizationService authorizationService, @@ -118,6 +122,16 @@ public VocabularyService(VocabularyRepositoryService repositoryService, this.context = context; } + /** + * Receives {@link VocabularyContentModifiedEvent} and triggers validation. + * The goal for this is to get the results cached and do not force users to wait for validation + * when they request it. + */ + @EventListener({VocabularyContentModifiedEvent.class, VocabularyCreatedEvent.class}) + public void onVocabularyContentModified(VocabularyEvent event) { + repositoryService.validateContents(event.getVocabularyIri()); + } + @Override @PostFilter("@vocabularyAuthorizationService.canRead(filterObject)") public List findAll() { @@ -168,7 +182,7 @@ public void persist(Vocabulary instance) { repositoryService.persist(instance); final AccessControlList acl = aclService.createFor(instance); instance.setAcl(acl.getUri()); - eventPublisher.publishEvent(new VocabularyCreatedEvent(instance)); + eventPublisher.publishEvent(new VocabularyCreatedEvent(this, instance.getUri())); } @Override @@ -231,7 +245,7 @@ public Vocabulary importVocabulary(boolean rename, MultipartFile file) { final Vocabulary imported = repositoryService.importVocabulary(rename, file); final AccessControlList acl = aclService.createFor(imported); imported.setAcl(acl.getUri()); - eventPublisher.publishEvent(new VocabularyCreatedEvent(imported)); + eventPublisher.publishEvent(new VocabularyCreatedEvent(this, imported.getUri())); return imported; } @@ -290,8 +304,12 @@ public List getChangesOfContent(Vocabulary vocabulary) { * @param vocabulary Vocabulary to be analyzed */ @Transactional + @Throttle(value = "{#vocabulary.getUri()}", + group = "T(ThrottleGroupProvider).getTextAnalysisVocabularyAllTerms(#vocabulary.getUri())", + name = "allTermsVocabularyAnalysis") @PreAuthorize("@vocabularyAuthorizationService.canModify(#vocabulary)") public void runTextAnalysisOnAllTerms(Vocabulary vocabulary) { + vocabulary = findRequired(vocabulary.getUri()); // required when throttling for persistent context LOG.debug("Analyzing definitions of all terms in vocabulary {} and vocabularies it imports.", vocabulary); SnapshotProvider.verifySnapshotNotModified(vocabulary); final List allTerms = termService.findAll(vocabulary); @@ -299,12 +317,13 @@ public void runTextAnalysisOnAllTerms(Vocabulary vocabulary) { importedVocabulary -> allTerms.addAll(termService.findAll(getReference(importedVocabulary)))); final Map termsToContexts = new HashMap<>(allTerms.size()); allTerms.forEach(t -> termsToContexts.put(t, contextMapper.getVocabularyContext(t.getVocabulary()))); - termService.asyncAnalyzeTermDefinitions(termsToContexts); + termsToContexts.forEach(termService::analyzeTermDefinition); } /** * Runs text analysis on definitions of all terms in all vocabularies. */ + @Throttle(group = "T(ThrottleGroupProvider).getTextAnalysisVocabulariesAll()", name = "allVocabulariesAnalysis") @Transactional public void runTextAnalysisOnAllVocabularies() { LOG.debug("Analyzing definitions of all terms in all vocabularies."); @@ -312,7 +331,7 @@ public void runTextAnalysisOnAllVocabularies() { repositoryService.findAll().forEach(v -> { List terms = termService.findAll(new Vocabulary(v.getUri())); terms.forEach(t -> termsToContexts.put(t, contextMapper.getVocabularyContext(t.getVocabulary()))); - termService.asyncAnalyzeTermDefinitions(termsToContexts); + termsToContexts.forEach(termService::analyzeTermDefinition); }); } @@ -337,10 +356,10 @@ public void remove(Vocabulary asset) { /** * Validates a vocabulary: - it checks glossary rules, - it checks OntoUml constraints. * - * @param validate Vocabulary to validate + * @param vocabulary Vocabulary to validate */ - public List validateContents(Vocabulary validate) { - return repositoryService.validateContents(validate); + public CacheableFuture> validateContents(URI vocabulary) { + return repositoryService.validateContents(vocabulary); } /** @@ -367,7 +386,7 @@ public Integer getTermCount(Vocabulary vocabulary) { @PreAuthorize("@vocabularyAuthorizationService.canCreateSnapshot(#vocabulary)") public Snapshot createSnapshot(Vocabulary vocabulary) { final Snapshot s = getSnapshotCreator().createSnapshot(vocabulary); - eventPublisher.publishEvent(new VocabularyCreatedEvent(s)); + eventPublisher.publishEvent(new VocabularyCreatedEvent(this, s.getUri())); cloneAccessControlList(s, vocabulary); return s; } diff --git a/src/main/java/cz/cvut/kbss/termit/service/business/async/AsyncTermService.java b/src/main/java/cz/cvut/kbss/termit/service/business/async/AsyncTermService.java deleted file mode 100644 index fc807b733..000000000 --- a/src/main/java/cz/cvut/kbss/termit/service/business/async/AsyncTermService.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * TermIt - * Copyright (C) 2023 Czech Technical University in Prague - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ -package cz.cvut.kbss.termit.service.business.async; - -import cz.cvut.kbss.termit.dto.listing.TermDto; -import cz.cvut.kbss.termit.model.AbstractTerm; -import cz.cvut.kbss.termit.model.Vocabulary; -import cz.cvut.kbss.termit.service.business.TermService; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.scheduling.annotation.Async; -import org.springframework.stereotype.Service; - -import java.net.URI; -import java.util.List; -import java.util.Map; - -/** - * Provides asynchronous processing of term-related tasks. - */ -@Service -public class AsyncTermService { - - private static final Logger LOG = LoggerFactory.getLogger(AsyncTermService.class); - - private final TermService termService; - - public AsyncTermService(TermService termService) { - this.termService = termService; - } - - /** - * Gets a list of all terms in the specified vocabulary. - * - * @param vocabulary Vocabulary whose terms to retrieve. A reference is sufficient - * @return List of vocabulary term DTOs - */ - public List findAll(Vocabulary vocabulary) { - return termService.findAll(vocabulary); - } - - /** - * Asynchronously runs text analysis on the definitions of all the specified terms. - *

- * The analysis calls are executed in a sequence, but this method itself is executed asynchronously. - * - * @param termsWithContexts Map of terms to vocabulary context identifiers they belong to - */ - @Async - public void asyncAnalyzeTermDefinitions(Map termsWithContexts) { - LOG.trace("Asynchronously analyzing definitions of {} terms.", termsWithContexts.size()); - termsWithContexts.forEach(termService::analyzeTermDefinition); - } -} diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/AnnotationGenerator.java b/src/main/java/cz/cvut/kbss/termit/service/document/AnnotationGenerator.java index 4333be04a..494263979 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/AnnotationGenerator.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/AnnotationGenerator.java @@ -18,9 +18,12 @@ package cz.cvut.kbss.termit.service.document; import cz.cvut.kbss.termit.exception.AnnotationGenerationException; +import cz.cvut.kbss.termit.exception.TermItException; import cz.cvut.kbss.termit.model.AbstractTerm; +import cz.cvut.kbss.termit.model.Asset; import cz.cvut.kbss.termit.model.assignment.TermOccurrence; import cz.cvut.kbss.termit.model.resource.File; +import cz.cvut.kbss.termit.util.throttle.Throttle; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; @@ -28,7 +31,10 @@ import org.springframework.transaction.annotation.Transactional; import java.io.InputStream; -import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.FutureTask; +import java.util.concurrent.atomic.AtomicBoolean; /** * Creates annotations (term occurrences) for vocabulary terms. @@ -39,6 +45,8 @@ @Service public class AnnotationGenerator { + private static final long THREAD_JOIN_TIMEOUT = 1000L * 60; // 1 minute + private static final Logger LOG = LoggerFactory.getLogger(AnnotationGenerator.class); private final DocumentManager documentManager; @@ -48,8 +56,7 @@ public class AnnotationGenerator { private final TermOccurrenceSaver occurrenceSaver; @Autowired - public AnnotationGenerator(DocumentManager documentManager, - TermOccurrenceResolvers resolvers, + public AnnotationGenerator(DocumentManager documentManager, TermOccurrenceResolvers resolvers, TermOccurrenceSaver occurrenceSaver) { this.documentManager = documentManager; this.resolvers = resolvers; @@ -63,17 +70,63 @@ public AnnotationGenerator(DocumentManager documentManager, * @param source Source file of the annotated document */ @Transactional + @Throttle(value = "{source.getUri()}", name = "documentAnnotationGeneration") public void generateAnnotations(InputStream content, File source) { final TermOccurrenceResolver occurrenceResolver = findResolverFor(source); LOG.debug("Resolving annotations of file {}.", source); occurrenceResolver.parseContent(content, source); occurrenceResolver.setExistingOccurrences(occurrenceSaver.getExistingOccurrences(source)); - final List occurrences = occurrenceResolver.findTermOccurrences(); - saveAnnotatedContent(source, occurrenceResolver.getContent()); - occurrenceSaver.saveOccurrences(occurrences, source); + findAndSaveTermOccurrences(source, occurrenceResolver); LOG.trace("Finished generating annotations for file {}.", source); } + /** + * Calls {@link TermOccurrenceResolver#findTermOccurrences(TermOccurrenceResolver.OccurrenceConsumer)} on {@code #occurrenceResolver} + * creating new thread that will save any found occurrence in parallel. + * Saves annotated content ({@link #saveAnnotatedContent(File, InputStream)} when the source is a {@link File}. + */ + private void findAndSaveTermOccurrences(Asset source, TermOccurrenceResolver occurrenceResolver) { + AtomicBoolean finished = new AtomicBoolean(false); + // alternatively, SynchronousQueue could be used, but this allows to have some space as buffer + final ArrayBlockingQueue toSave = new ArrayBlockingQueue<>(10); + // not limiting the queue size would result in OutOfMemoryError + + FutureTask findTask = new FutureTask<>(() -> { + try { + LOG.trace("Resolving term occurrences for {}.", source); + occurrenceResolver.findTermOccurrences(toSave::put); + LOG.trace("Finished resolving term occurrences for {}.", source); + LOG.trace("Saving term occurrences for {}.", source); + if (source instanceof File sourceFile) { + saveAnnotatedContent(sourceFile, occurrenceResolver.getContent()); + } + LOG.trace("Term occurrences saved for {}.", source); + } finally { + finished.set(true); + } + return null; + }); + Thread finder = new Thread(findTask); + finder.setName("AnnotationGenerator-TermOccurrenceResolver"); + finder.start(); + + occurrenceSaver.saveFromQueue(source, finished, toSave); + + try { + findTask.get(); // propagates exceptions + finder.join(THREAD_JOIN_TIMEOUT); + } catch (InterruptedException e) { + LOG.error("Thread interrupted while saving annotations of file {}.", source); + Thread.currentThread().interrupt(); + throw new TermItException(e); + } catch (ExecutionException e) { + if (e.getCause() instanceof RuntimeException re) { + throw re; + } + throw new TermItException(e); + } + } + private TermOccurrenceResolver findResolverFor(File file) { // This will allow us to potentially support different types of files final TermOccurrenceResolver htmlResolver = resolvers.htmlTermOccurrenceResolver(); @@ -100,8 +153,7 @@ public void generateAnnotations(InputStream content, AbstractTerm annotatedTerm) final TermOccurrenceResolver occurrenceResolver = resolvers.htmlTermOccurrenceResolver(); LOG.debug("Resolving annotations of the definition of {}.", annotatedTerm); occurrenceResolver.parseContent(content, annotatedTerm); - final List occurrences = occurrenceResolver.findTermOccurrences(); - occurrenceSaver.saveOccurrences(occurrences, annotatedTerm); + occurrenceResolver.findTermOccurrences(o -> occurrenceSaver.saveOccurrence(o, annotatedTerm)); LOG.trace("Finished generating annotations for the definition of {}.", annotatedTerm); } } diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/AsynchronousTermOccurrenceSaver.java b/src/main/java/cz/cvut/kbss/termit/service/document/AsynchronousTermOccurrenceSaver.java deleted file mode 100644 index a12186af0..000000000 --- a/src/main/java/cz/cvut/kbss/termit/service/document/AsynchronousTermOccurrenceSaver.java +++ /dev/null @@ -1,42 +0,0 @@ -package cz.cvut.kbss.termit.service.document; - -import cz.cvut.kbss.termit.model.Asset; -import cz.cvut.kbss.termit.model.assignment.TermOccurrence; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.context.annotation.Primary; -import org.springframework.context.annotation.Profile; -import org.springframework.scheduling.annotation.Async; -import org.springframework.stereotype.Service; - -import java.util.List; - -/** - * Saves term occurrences asynchronously. - */ -@Primary -@Service -@Profile("!test") -public class AsynchronousTermOccurrenceSaver implements TermOccurrenceSaver { - - private static final Logger LOG = LoggerFactory.getLogger(AsynchronousTermOccurrenceSaver.class); - - private final SynchronousTermOccurrenceSaver synchronousSaver; - - public AsynchronousTermOccurrenceSaver(SynchronousTermOccurrenceSaver synchronousSaver) { - this.synchronousSaver = synchronousSaver; - } - - @Async - @Override - public void saveOccurrences(List occurrences, Asset source) { - LOG.debug("Asynchronously saving term occurrences for asset {}.", source); - synchronousSaver.saveOccurrences(occurrences, source); - LOG.trace("Finished saving term occurrences for asset {}.", source); - } - - @Override - public List getExistingOccurrences(Asset source) { - return synchronousSaver.getExistingOccurrences(source); - } -} diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/SynchronousTermOccurrenceSaver.java b/src/main/java/cz/cvut/kbss/termit/service/document/SynchronousTermOccurrenceSaver.java deleted file mode 100644 index e8ec00613..000000000 --- a/src/main/java/cz/cvut/kbss/termit/service/document/SynchronousTermOccurrenceSaver.java +++ /dev/null @@ -1,43 +0,0 @@ -package cz.cvut.kbss.termit.service.document; - -import cz.cvut.kbss.termit.model.Asset; -import cz.cvut.kbss.termit.model.assignment.TermOccurrence; -import cz.cvut.kbss.termit.persistence.dao.TermOccurrenceDao; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.springframework.stereotype.Service; -import org.springframework.transaction.annotation.Transactional; - -import java.util.List; - -/** - * Saves occurrences synchronously. - *

- * Existing occurrences are reused if they match. - */ -@Service -public class SynchronousTermOccurrenceSaver implements TermOccurrenceSaver { - - private static final Logger LOG = LoggerFactory.getLogger(SynchronousTermOccurrenceSaver.class); - - private final TermOccurrenceDao termOccurrenceDao; - - public SynchronousTermOccurrenceSaver(TermOccurrenceDao termOccurrenceDao) { - this.termOccurrenceDao = termOccurrenceDao; - } - - @Transactional - @Override - public void saveOccurrences(List occurrences, Asset source) { - LOG.debug("Saving term occurrences for asset {}.", source); - LOG.trace("Removing all existing occurrences in asset {}.", source); - termOccurrenceDao.removeAll(source); - LOG.trace("Persisting new occurrences in {}.", source); - occurrences.stream().filter(o -> !o.getTerm().equals(source.getUri())).forEach(termOccurrenceDao::persist); - } - - @Override - public List getExistingOccurrences(Asset source) { - return termOccurrenceDao.findAllTargeting(source); - } -} diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceResolver.java b/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceResolver.java index 55964f5b3..616c0707d 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceResolver.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceResolver.java @@ -31,6 +31,7 @@ import java.net.URI; import java.util.Collections; import java.util.List; +import java.util.function.Consumer; /** * Base class for resolving term occurrences in an annotated document. @@ -49,7 +50,7 @@ protected TermOccurrenceResolver(TermRepositoryService termService) { * Parses the specified input into some abstract representation from which new terms and term occurrences can be * extracted. *

- * Note that this method has to be called before calling {@link #findTermOccurrences()}. + * Note that this method has to be called before calling {@link #findTermOccurrences(Consumer)}. * * @param input The input to parse * @param source Original source of the input. Used for term occurrence generation @@ -80,10 +81,10 @@ public void setExistingOccurrences(List existingOccurrences) { *

* {@link #parseContent(InputStream, Asset)} has to be called prior to this method. * - * @return List of term occurrences identified in the input + * @param resultConsumer the consumer that will be called for each result * @see #parseContent(InputStream, Asset) */ - public abstract List findTermOccurrences(); + public abstract void findTermOccurrences(OccurrenceConsumer resultConsumer); /** * Checks whether this resolver supports the specified source file type. @@ -102,11 +103,11 @@ public void setExistingOccurrences(List existingOccurrences) { */ protected TermOccurrence createOccurrence(URI termUri, Asset source) { final TermOccurrence occurrence; - if (source instanceof File) { - final FileOccurrenceTarget target = new FileOccurrenceTarget((File) source); + if (source instanceof File file) { + final FileOccurrenceTarget target = new FileOccurrenceTarget(file); occurrence = new TermFileOccurrence(termUri, target); - } else if (source instanceof AbstractTerm) { - final DefinitionalOccurrenceTarget target = new DefinitionalOccurrenceTarget((AbstractTerm) source); + } else if (source instanceof AbstractTerm abstractTerm) { + final DefinitionalOccurrenceTarget target = new DefinitionalOccurrenceTarget(abstractTerm); occurrence = new TermDefinitionalOccurrence(termUri, target); } else { throw new IllegalArgumentException("Unsupported term occurrence source " + source); @@ -114,4 +115,9 @@ protected TermOccurrence createOccurrence(URI termUri, Asset source) { occurrence.markSuggested(); return occurrence; } + + @FunctionalInterface + public interface OccurrenceConsumer { + void accept(TermOccurrence termOccurrence) throws InterruptedException; + } } diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaver.java b/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaver.java index 85286d4bb..9843a9864 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaver.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaver.java @@ -1,24 +1,99 @@ package cz.cvut.kbss.termit.service.document; +import cz.cvut.kbss.termit.exception.TermItException; import cz.cvut.kbss.termit.model.Asset; import cz.cvut.kbss.termit.model.assignment.TermOccurrence; +import cz.cvut.kbss.termit.persistence.dao.TermOccurrenceDao; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; /** - * Saves occurrences of terms. + * Saves occurrences synchronously. + *

+ * Existing occurrences are reused if they match. */ -public interface TermOccurrenceSaver { +@Service +public class TermOccurrenceSaver { + + private static final Logger LOG = LoggerFactory.getLogger(TermOccurrenceSaver.class); + + private final TermOccurrenceDao termOccurrenceDao; + + public TermOccurrenceSaver(TermOccurrenceDao termOccurrenceDao) { + this.termOccurrenceDao = termOccurrenceDao; + } /** * Saves the specified occurrences of terms in the specified asset. *

+ * Removes all existing occurrences. + *

* Implementations may reuse existing occurrences if they match the provided ones. * * @param occurrences Occurrences to save * @param source Asset in which the terms occur */ - void saveOccurrences(List occurrences, Asset source); + @Transactional + public void saveOccurrences(List occurrences, Asset source) { + LOG.debug("Saving term occurrences for asset {}.", source); + removeAll(source); + LOG.trace("Persisting new occurrences in {}.", source); + occurrences.stream().filter(o -> !o.getTerm().equals(source.getUri())).forEach(termOccurrenceDao::persist); + } + + public void saveOccurrence(TermOccurrence occurrence, Asset source) { + if (occurrence.getTerm().equals(source.getUri())) { + return; + } + if(!termOccurrenceDao.exists(occurrence.getUri())) { + termOccurrenceDao.persist(occurrence); + } else { + LOG.debug("Occurrence already exists, skipping: {}", occurrence); + } + } + + /** + * Continously saves occurrences from the queue while blocking current thread until + * {@code #finished} is set to {@code true}. + *

+ * Removes all existing occurrences before processing. + * + * @param source Asset in which the terms occur + * @param finished Whether all occurrences were added to the queue + * @param toSave the queue with occurrences to save + */ + @Transactional + public void saveFromQueue(final Asset source, final AtomicBoolean finished, + final BlockingQueue toSave) { + LOG.debug("Saving term occurrences for asset {}.", source); + removeAll(source); + TermOccurrence occurrence; + long count = 0; + try { + while (!finished.get() || !toSave.isEmpty()) { + if (toSave.isEmpty()) { + Thread.yield(); + } + occurrence = toSave.poll(1, TimeUnit.SECONDS); + if (occurrence != null) { + saveOccurrence(occurrence, source); + count++; + } + } + LOG.debug("Saved {} term occurrences for assert {}.", count, source); + } catch (InterruptedException e) { + LOG.error("Thread interrupted while waiting for occurrences to save."); + Thread.currentThread().interrupt(); + throw new TermItException(e); + } + } /** * Gets a list of existing term occurrences in the specified asset. @@ -26,5 +101,12 @@ public interface TermOccurrenceSaver { * @param source Asset in which the terms occur * @return List of existing term occurrences */ - List getExistingOccurrences(Asset source); + public List getExistingOccurrences(Asset source) { + return termOccurrenceDao.findAllTargeting(source); + } + + private void removeAll(Asset source) { + LOG.trace("Removing all existing occurrences in asset {}.", source); + termOccurrenceDao.removeAll(source); + } } diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/TextAnalysisService.java b/src/main/java/cz/cvut/kbss/termit/service/document/TextAnalysisService.java index dbc94dfaf..adc9dfdae 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/TextAnalysisService.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/TextAnalysisService.java @@ -18,6 +18,8 @@ package cz.cvut.kbss.termit.service.document; import cz.cvut.kbss.termit.dto.TextAnalysisInput; +import cz.cvut.kbss.termit.event.FileTextAnalysisFinishedEvent; +import cz.cvut.kbss.termit.event.TermDefinitionTextAnalysisFinishedEvent; import cz.cvut.kbss.termit.exception.WebServiceIntegrationException; import cz.cvut.kbss.termit.model.AbstractTerm; import cz.cvut.kbss.termit.model.TextAnalysisRecord; @@ -25,9 +27,11 @@ import cz.cvut.kbss.termit.persistence.dao.TextAnalysisRecordDao; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.Utils; +import cz.cvut.kbss.termit.util.throttle.Throttle; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.core.io.Resource; import org.springframework.http.HttpEntity; import org.springframework.http.HttpHeaders; @@ -61,14 +65,18 @@ public class TextAnalysisService { private final TextAnalysisRecordDao recordDao; + private final ApplicationEventPublisher eventPublisher; + @Autowired public TextAnalysisService(RestTemplate restClient, Configuration config, DocumentManager documentManager, - AnnotationGenerator annotationGenerator, TextAnalysisRecordDao recordDao) { + AnnotationGenerator annotationGenerator, TextAnalysisRecordDao recordDao, + ApplicationEventPublisher eventPublisher) { this.restClient = restClient; this.config = config; this.documentManager = documentManager; this.annotationGenerator = annotationGenerator; this.recordDao = recordDao; + this.eventPublisher = eventPublisher; } /** @@ -80,12 +88,15 @@ public TextAnalysisService(RestTemplate restClient, Configuration config, Docume * @param file File whose content shall be analyzed * @param vocabularyContexts Identifiers of repository contexts containing vocabularies intended for text analysis */ + @Throttle(value = "{#file.getUri()}", name = "fileAnalysis") @Transactional public void analyzeFile(File file, Set vocabularyContexts) { Objects.requireNonNull(file); final TextAnalysisInput input = createAnalysisInput(file); input.setVocabularyContexts(vocabularyContexts); invokeTextAnalysisOnFile(file, input); + LOG.debug("Text analysis finished for resource {}.", file.getUri()); + eventPublisher.publishEvent(new FileTextAnalysisFinishedEvent(this, file)); } private TextAnalysisInput createAnalysisInput(File file) { @@ -179,6 +190,7 @@ public void analyzeTermDefinition(AbstractTerm term, URI vocabularyContext) { input.setVocabularyRepositoryPassword(config.getRepository().getPassword()); invokeTextAnalysisOnTerm(term, input); + eventPublisher.publishEvent(new TermDefinitionTextAnalysisFinishedEvent(this, term)); } } diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolver.java b/src/main/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolver.java index c67c466ca..2983c3c51 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolver.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolver.java @@ -18,6 +18,7 @@ package cz.cvut.kbss.termit.service.document.html; import cz.cvut.kbss.termit.exception.AnnotationGenerationException; +import cz.cvut.kbss.termit.exception.TermItException; import cz.cvut.kbss.termit.model.Asset; import cz.cvut.kbss.termit.model.Term; import cz.cvut.kbss.termit.model.assignment.OccurrenceTarget; @@ -47,10 +48,8 @@ import java.io.InputStream; import java.net.URI; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; @@ -65,15 +64,19 @@ public class HtmlTermOccurrenceResolver extends TermOccurrenceResolver { private static final String BNODE_PREFIX = "_:"; + private static final String SCORE_ATTRIBUTE = "score"; private static final Logger LOG = LoggerFactory.getLogger(HtmlTermOccurrenceResolver.class); private final HtmlSelectorGenerators selectorGenerators; + private final DocumentManager documentManager; + private final Configuration config; private Document document; + private Asset source; private Map prefixes; @@ -152,11 +155,10 @@ private String fullIri(String possiblyPrefixed) { } @Override - public List findTermOccurrences() { + public void findTermOccurrences(OccurrenceConsumer resultConsumer) { assert document != null; final Set visited = new HashSet<>(); final Elements elements = document.getElementsByAttribute(Constants.RDFa.ABOUT); - final List result = new ArrayList<>(elements.size()); final Double scoreThreshold = Double.parseDouble(config.getTextAnalysis().getTermOccurrenceMinScore()); for (Element element : elements) { if (isNotTermOccurrence(element)) { @@ -171,27 +173,31 @@ public List findTermOccurrences() { LOG.trace("Processing RDFa annotated element {}.", element); final Optional occurrence = resolveAnnotation(element, source); occurrence.ifPresent(to -> { - if (!to.isSuggested()) { - // Occurrence already approved in content (from previous manual approval) - result.add(to); - } else if (existsApproved(to)) { - LOG.trace("Found term occurrence {} with matching existing approved occurrence.", to); - to.markApproved(); - // Annotation without score is considered approved by the frontend - element.removeAttr(SCORE_ATTRIBUTE); - result.add(to); - } else { - if (to.getScore() > scoreThreshold) { - LOG.trace("Found term occurrence {}.", to); - result.add(to); + try { + if (!to.isSuggested()) { + // Occurrence already approved in content (from previous manual approval) + resultConsumer.accept(to); + } else if (existsApproved(to)) { + LOG.trace("Found term occurrence {} with matching existing approved occurrence.", to); + to.markApproved(); + // Annotation without score is considered approved by the frontend + element.removeAttr(SCORE_ATTRIBUTE); + resultConsumer.accept(to); } else { - LOG.trace("The confidence score of occurrence {} is lower than the configured threshold {}.", - to, scoreThreshold); + if (to.getScore() > scoreThreshold) { + LOG.trace("Found term occurrence {}.", to); + resultConsumer.accept(to); + } else { + LOG.trace("The confidence score of occurrence {} is lower than the configured threshold {}.", to, scoreThreshold); + } } + } catch (InterruptedException e) { + LOG.error("Thread interrupted while resolving term occurrences."); + Thread.currentThread().interrupt(); + throw new TermItException(e); } }); } - return result; } private Optional resolveAnnotation(Element rdfaElem, Asset source) { @@ -226,9 +232,7 @@ private void verifyTermExists(Element rdfaElem, URI termUri, String termId) { return; } if (!termService.exists(termUri)) { - throw new AnnotationGenerationException( - "Term with id " + Utils.uriToString( - termUri) + " denoted by RDFa element '" + rdfaElem + "' not found."); + throw new AnnotationGenerationException("Term with id " + Utils.uriToString(termUri) + " denoted by RDFa element '" + rdfaElem + "' not found."); } existingTermIds.add(termId); } @@ -273,8 +277,8 @@ public boolean supports(Asset source) { return true; } final Optional probedContentType = documentManager.getContentType(sourceFile); - return probedContentType.isPresent() - && (probedContentType.get().equals(MediaType.TEXT_HTML_VALUE) - || probedContentType.get().equals(MediaType.APPLICATION_XHTML_XML_VALUE)); + return probedContentType.isPresent() && (probedContentType.get() + .equals(MediaType.TEXT_HTML_VALUE) || probedContentType.get() + .equals(MediaType.APPLICATION_XHTML_XML_VALUE)); } } diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/html/SelectorGenerator.java b/src/main/java/cz/cvut/kbss/termit/service/document/html/SelectorGenerator.java index b13049475..a55c7a022 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/html/SelectorGenerator.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/html/SelectorGenerator.java @@ -57,10 +57,11 @@ default String extractExactText(Element[] elements) { default StringBuilder extractNodeText(Iterable nodes) { final StringBuilder sb = new StringBuilder(); for (Node node : nodes) { - if (!(node instanceof TextNode) && !(node instanceof Element)) { - continue; + if (node instanceof TextNode textNode) { + sb.append(textNode.getWholeText()); + } else if (node instanceof Element elementNode) { + sb.append(elementNode.wholeText()); } - sb.append(node instanceof TextNode ? ((TextNode) node).getWholeText() : ((Element) node).wholeText()); } return sb; } diff --git a/src/main/java/cz/cvut/kbss/termit/service/document/html/TextPositionSelectorGenerator.java b/src/main/java/cz/cvut/kbss/termit/service/document/html/TextPositionSelectorGenerator.java index b2fb792a8..767e06676 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/document/html/TextPositionSelectorGenerator.java +++ b/src/main/java/cz/cvut/kbss/termit/service/document/html/TextPositionSelectorGenerator.java @@ -20,9 +20,13 @@ import cz.cvut.kbss.termit.model.selector.TextPositionSelector; import org.jsoup.nodes.Element; import org.jsoup.nodes.Node; +import org.jsoup.nodes.TextNode; import org.jsoup.select.Elements; +import org.jsoup.select.NodeTraversor; +import org.jsoup.select.NodeVisitor; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; /** * Generates a {@link TextPositionSelector} for the specified elements. @@ -47,6 +51,14 @@ public TextPositionSelector generateSelector(Element... elements) { return selector; } + /** + * This code was extracted from {@link #extractNodeText} and related functions + * to prevent constructing whole string contents for only getting its length. + * Now only length is counted from the contents of text nodes. + * @see SelectorGenerator#extractNodeText(Iterable) + * @see Element#wholeText() + * @see TextNode#getWholeText() + */ private int resolveStartPosition(Element element) { final Elements ancestors = element.parents(); Element previous = element; diff --git a/src/main/java/cz/cvut/kbss/termit/service/jmx/AppAdminBean.java b/src/main/java/cz/cvut/kbss/termit/service/jmx/AppAdminBean.java index ae8019e24..c6095f424 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/jmx/AppAdminBean.java +++ b/src/main/java/cz/cvut/kbss/termit/service/jmx/AppAdminBean.java @@ -19,7 +19,6 @@ import cz.cvut.kbss.termit.event.EvictCacheEvent; import cz.cvut.kbss.termit.event.RefreshLastModifiedEvent; -import cz.cvut.kbss.termit.event.VocabularyContentModified; import cz.cvut.kbss.termit.rest.dto.HealthInfo; import cz.cvut.kbss.termit.service.mail.Message; import cz.cvut.kbss.termit.service.mail.Postman; @@ -66,7 +65,6 @@ public void invalidateCaches() { eventPublisher.publishEvent(new EvictCacheEvent(this)); LOG.info("Refreshing last modified timestamps..."); eventPublisher.publishEvent(new RefreshLastModifiedEvent(this)); - eventPublisher.publishEvent(new VocabularyContentModified(this, null)); } @ManagedOperation(description = "Sends test email to the specified address.") diff --git a/src/main/java/cz/cvut/kbss/termit/service/repository/VocabularyRepositoryService.java b/src/main/java/cz/cvut/kbss/termit/service/repository/VocabularyRepositoryService.java index a9730c702..0f8fede41 100644 --- a/src/main/java/cz/cvut/kbss/termit/service/repository/VocabularyRepositoryService.java +++ b/src/main/java/cz/cvut/kbss/termit/service/repository/VocabularyRepositoryService.java @@ -41,6 +41,7 @@ import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.Constants; import cz.cvut.kbss.termit.util.Utils; +import cz.cvut.kbss.termit.util.throttle.CacheableFuture; import cz.cvut.kbss.termit.workspace.EditableVocabularies; import jakarta.validation.Validator; import org.apache.tika.Tika; @@ -205,7 +206,7 @@ public Vocabulary update(Vocabulary instance) { } public Collection getTransitivelyImportedVocabularies(Vocabulary entity) { - return vocabularyDao.getTransitivelyImportedVocabularies(entity); + return vocabularyDao.getTransitivelyImportedVocabularies(entity.getUri()); } public Set getRelatedVocabularies(Vocabulary entity) { @@ -319,8 +320,8 @@ private void ensureNoTermRelationsExists(Vocabulary vocabulary) throws AssetRemo } } - public List validateContents(Vocabulary instance) { - return vocabularyDao.validateContents(instance); + public CacheableFuture> validateContents(URI vocabulary) { + return vocabularyDao.validateContents(vocabulary); } public Integer getTermCount(Vocabulary vocabulary) { diff --git a/src/main/java/cz/cvut/kbss/termit/util/Configuration.java b/src/main/java/cz/cvut/kbss/termit/util/Configuration.java index 1b7bcaf11..cf609cab8 100644 --- a/src/main/java/cz/cvut/kbss/termit/util/Configuration.java +++ b/src/main/java/cz/cvut/kbss/termit/util/Configuration.java @@ -18,7 +18,9 @@ package cz.cvut.kbss.termit.util; import cz.cvut.kbss.termit.model.acl.AccessLevel; +import cz.cvut.kbss.termit.util.throttle.ThrottleAspect; import jakarta.validation.Valid; +import jakarta.validation.constraints.Future; import jakarta.validation.constraints.Min; import jakarta.validation.constraints.NotNull; import org.springframework.boot.context.properties.ConfigurationProperties; @@ -56,6 +58,32 @@ public class Configuration { * server. */ private String jmxBeanName = "TermItAdminBean"; + + /** + * The number of threads for thread pool executing asynchronous and long-running tasks. + * @configurationdoc.default The number of processors available to the Java virtual machine. + */ + @Min(1) + private Integer asyncThreadCount = Runtime.getRuntime().availableProcessors(); + + /** + * The amount of time in which calls of throttled methods + * should be merged. + * The value must be positive ({@code > 0}). + * @configurationdoc.default 10 seconds + * @see cz.cvut.kbss.termit.util.throttle.Throttle + * @see cz.cvut.kbss.termit.util.throttle.ThrottleAspect + */ + private Duration throttleThreshold = Duration.ofSeconds(10); + + /** + * After how much time, should objects with completed futures be discarded. + * The value must be positive ({@code > 0}). + * @configurationdoc.default 1 minute + * @see ThrottleAspect#clearOldFutures() + */ + private Duration throttleDiscardThreshold = Duration.ofMinutes(1); + @Valid private Persistence persistence = new Persistence(); @Valid @@ -111,6 +139,14 @@ public void setJmxBeanName(String jmxBeanName) { this.jmxBeanName = jmxBeanName; } + public Integer getAsyncThreadCount() { + return asyncThreadCount; + } + + public void setAsyncThreadCount(@Min(1) Integer asyncThreadCount) { + this.asyncThreadCount = asyncThreadCount; + } + public Persistence getPersistence() { return persistence; } @@ -263,6 +299,22 @@ public void setTemplate(Template template) { this.template = template; } + public Duration getThrottleThreshold() { + return throttleThreshold; + } + + public void setThrottleThreshold(Duration throttleThreshold) { + this.throttleThreshold = throttleThreshold; + } + + public Duration getThrottleDiscardThreshold() { + return throttleDiscardThreshold; + } + + public void setThrottleDiscardThreshold(Duration throttleDiscardThreshold) { + this.throttleDiscardThreshold = throttleDiscardThreshold; + } + @Validated public static class Persistence { /** @@ -600,8 +652,6 @@ public static class TextAnalysis { @Min(8) private int textQuoteSelectorContextLength = 32; - private boolean disableVocabularyAnalysisOnTermEdit = false; - public String getUrl() { return url; } @@ -625,14 +675,6 @@ public int getTextQuoteSelectorContextLength() { public void setTextQuoteSelectorContextLength(int textQuoteSelectorContextLength) { this.textQuoteSelectorContextLength = textQuoteSelectorContextLength; } - - public boolean isDisableVocabularyAnalysisOnTermEdit() { - return disableVocabularyAnalysisOnTermEdit; - } - - public void setDisableVocabularyAnalysisOnTermEdit(boolean disableVocabularyAnalysisOnTermEdit) { - this.disableVocabularyAnalysisOnTermEdit = disableVocabularyAnalysisOnTermEdit; - } } @Validated diff --git a/src/main/java/cz/cvut/kbss/termit/util/Constants.java b/src/main/java/cz/cvut/kbss/termit/util/Constants.java index fb0959d8f..601c4703f 100644 --- a/src/main/java/cz/cvut/kbss/termit/util/Constants.java +++ b/src/main/java/cz/cvut/kbss/termit/util/Constants.java @@ -18,10 +18,12 @@ package cz.cvut.kbss.termit.util; import cz.cvut.kbss.jopa.vocabulary.SKOS; +import cz.cvut.kbss.termit.util.throttle.ThrottleAspect; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import java.net.URI; +import java.time.Duration; import java.time.Instant; import java.time.ZoneId; import java.time.format.DateTimeFormatter; @@ -207,6 +209,10 @@ public static final class MediaType { public static final String EXCEL = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"; public static final String TURTLE = "text/turtle"; public static final String RDF_XML = "application/rdf+xml"; + + private MediaType() { + throw new AssertionError(); + } } /** @@ -244,6 +250,23 @@ private QueryParams() { } } + public static final class DebouncingGroups { + + /** + * Text analysis of all terms in specific vocabulary + */ + public static final String TEXT_ANALYSIS_VOCABULARY_TERMS_ALL_DEFINITIONS = "TEXT_ANALYSIS_VOCABULARY_TERMS_ALL_DEFINITIONS"; + + /** + * Text analysis of all vocabularies + */ + public static final String TEXT_ANALYSIS_VOCABULARY = "TEXT_ANALYSIS_VOCABULARY"; + + private DebouncingGroups() { + throw new AssertionError(); + } + } + /** * the maximum amount of data to buffer when sending messages to a WebSocket session */ diff --git a/src/main/java/cz/cvut/kbss/termit/util/ExceptionUtils.java b/src/main/java/cz/cvut/kbss/termit/util/ExceptionUtils.java new file mode 100644 index 000000000..e31b081c7 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/ExceptionUtils.java @@ -0,0 +1,31 @@ +package cz.cvut.kbss.termit.util; + +import org.springframework.lang.NonNull; + +import java.util.HashSet; +import java.util.Set; + +public class ExceptionUtils { + private ExceptionUtils() { + throw new AssertionError(); + } + + /** + * Resolves all nested causes of the {@code throwable} and returns true if any is matching the {@code cause} + */ + public static boolean isCausedBy(final Throwable throwable, @NonNull final Class cause) { + Throwable t = throwable; + final Set visited = new HashSet<>(); + while (t != null) { + if(visited.add(t)) { + if (cause.isInstance(t)){ + return true; + } + t = t.getCause(); + continue; + } + break; + } + return false; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/Pair.java b/src/main/java/cz/cvut/kbss/termit/util/Pair.java new file mode 100644 index 000000000..ad0f36a34 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/Pair.java @@ -0,0 +1,61 @@ +package cz.cvut.kbss.termit.util; + + +import org.springframework.lang.NonNull; + +import java.util.Objects; + +public class Pair { + + private final T first; + + private final V second; + + public Pair(T first, V second) { + this.first = first; + this.second = second; + } + + public T getFirst() { + return first; + } + + public V getSecond() { + return second; + } + + + /** + * First compares the first value, if they are equal, compares the second value. + */ + public static class ComparablePair, V extends java.lang.Comparable> + extends Pair implements java.lang.Comparable> { + + public ComparablePair(T first, V second) { + super(first, second); + } + + @Override + public int compareTo(@NonNull Pair.ComparablePair o) { + final int firstComparison = this.getFirst().compareTo(o.getFirst()); + if (firstComparison != 0) { + return firstComparison; + } + return this.getSecond().compareTo(o.getSecond()); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ComparablePair that = (ComparablePair) o; + return Objects.equals(getFirst(), that.getFirst()) && Objects.equals(getSecond(), that.getSecond()); + } + + @Override + public int hashCode() { + return Objects.hash(getFirst(), getSecond()); + } + } +} + diff --git a/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTask.java b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTask.java new file mode 100644 index 000000000..d59913ec2 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTask.java @@ -0,0 +1,43 @@ +package cz.cvut.kbss.termit.util.longrunning; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import java.time.Instant; +import java.util.Optional; +import java.util.UUID; + +/** + * An asynchronously running task that is expected to run for some time. + */ +public interface LongRunningTask { + + @Nullable + String getName(); + + /** + * @return true when the task is being actively executed, false otherwise. + */ + boolean isRunning(); + + /** + * Returns {@code true} if this task completed. + *

+ * Completion may be due to normal termination, an exception, or + * cancellation -- in all of these cases, this method will return + * {@code true}. + * + * @return {@code true} if this task completed + */ + boolean isDone(); + + /** + * @return a timestamp of the task execution start, + * or empty if the task execution has not yet started. + */ + @NonNull + Optional startedAt(); + + @NonNull + UUID getUuid(); +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTaskScheduler.java b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTaskScheduler.java new file mode 100644 index 000000000..d4c396f7c --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTaskScheduler.java @@ -0,0 +1,22 @@ +package cz.cvut.kbss.termit.util.longrunning; + +import org.springframework.lang.NonNull; + +/** + * An object that will schedule a long-running tasks + * @see LongRunningTask + */ +public abstract class LongRunningTaskScheduler { + private final LongRunningTasksRegistry registry; + + protected LongRunningTaskScheduler(LongRunningTasksRegistry registry) { + this.registry = registry; + } + + protected final void notifyTaskChanged(final @NonNull LongRunningTask task) { + final String name = task.getName(); + if (name != null && !name.isBlank()) { + registry.onTaskChanged(task); + } + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTaskStatus.java b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTaskStatus.java new file mode 100644 index 000000000..aa4859c61 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTaskStatus.java @@ -0,0 +1,64 @@ +package cz.cvut.kbss.termit.util.longrunning; + +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import java.io.Serializable; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Objects; +import java.util.UUID; + +public class LongRunningTaskStatus implements Serializable { + + private final String name; + + private final UUID uuid; + + private final State state; + + private final Instant startedAt; + + public LongRunningTaskStatus(@NonNull LongRunningTask task) { + Objects.requireNonNull(task.getName()); + this.name = task.getName(); + this.startedAt = task.startedAt().map(time -> time.truncatedTo(ChronoUnit.SECONDS)).orElse(null); + this.state = State.of(task); + this.uuid = task.getUuid(); + } + + public @NonNull String getName() { + return name; + } + + public State getState() { + return state; + } + + public @Nullable Instant getStartedAt() { + return startedAt; + } + + public @NonNull UUID getUuid() { + return uuid; + } + + @Override + public String toString() { + return "{" + state.name() + (startedAt == null ? "" : ", startedAt=" + startedAt) + ", " + uuid + "}"; + } + + public enum State { + PENDING, RUNNING, DONE; + + public static State of(@NonNull LongRunningTask task) { + if (task.isRunning()) { + return RUNNING; + } else if (task.isDone()) { + return DONE; + } else { + return PENDING; + } + } + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTasksRegistry.java b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTasksRegistry.java new file mode 100644 index 000000000..a73435f4b --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/longrunning/LongRunningTasksRegistry.java @@ -0,0 +1,64 @@ +package cz.cvut.kbss.termit.util.longrunning; + +import cz.cvut.kbss.termit.event.LongRunningTaskChangedEvent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationEventPublisher; +import org.springframework.lang.NonNull; +import org.springframework.stereotype.Component; + +import java.util.List; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; + +@Component +public class LongRunningTasksRegistry { + + private static final Logger LOG = LoggerFactory.getLogger(LongRunningTasksRegistry.class); + + private final ConcurrentHashMap registry = new ConcurrentHashMap<>(); + + private final ApplicationEventPublisher eventPublisher; + + @Autowired + public LongRunningTasksRegistry(ApplicationEventPublisher eventPublisher) { + this.eventPublisher = eventPublisher; + } + + public void onTaskChanged(@NonNull final LongRunningTask task) { + final LongRunningTaskStatus status = new LongRunningTaskStatus(task); + + if (LOG.isTraceEnabled()) { + LOG.atTrace().setMessage("Long running task changed state: {}{}").addArgument(status::getName) + .addArgument(status).log(); + } + + handleTaskChanged(task); + eventPublisher.publishEvent(new LongRunningTaskChangedEvent(this, status)); + } + + private void handleTaskChanged(@NonNull final LongRunningTask task) { + if(task.isDone()) { + registry.remove(task.getUuid()); + } else { + registry.put(task.getUuid(), task); + } + + // perform cleanup + registry.forEach((key, value) -> { + if (value.isDone()) { + registry.remove(key); + } + }); + + if (LOG.isTraceEnabled() && registry.isEmpty()) { + LOG.trace("All long running tasks completed"); + } + } + + @NonNull + public List getTasks() { + return registry.values().stream().map(LongRunningTaskStatus::new).toList(); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/CacheableFuture.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/CacheableFuture.java new file mode 100644 index 000000000..6af5651d5 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/CacheableFuture.java @@ -0,0 +1,46 @@ +package cz.cvut.kbss.termit.util.throttle; + +import cz.cvut.kbss.termit.exception.TermItException; +import org.springframework.lang.Nullable; + +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; + +/** + * A future which can provide a cached result before its completion. + * @see Future + */ +public interface CacheableFuture extends ChainableFuture { + + /** + * @return the cached result when available + */ + Optional getCachedResult(); + + /** + * Sets possible cached result + * + * @param cachedResult the result to set, or null to clear the cache + * @return self + */ + CacheableFuture setCachedResult(@Nullable final T cachedResult); + + /** + * @return the future result if it is available, cached result otherwise. + */ + default Optional getNow() { + try { + if (isDone() && !isCancelled()) { + return Optional.of(get()); + } + } catch (ExecutionException e) { + throw new TermItException(e); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new TermItException(e); + } + + return getCachedResult(); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/ChainableFuture.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/ChainableFuture.java new file mode 100644 index 000000000..0d8b63d6c --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/ChainableFuture.java @@ -0,0 +1,16 @@ +package cz.cvut.kbss.termit.util.throttle; + +import java.util.concurrent.Future; +import java.util.function.Consumer; + +public interface ChainableFuture extends Future { + + /** + * Executes this action once the future is completed normally. + * Action is not executed on exceptional completion. + *

+ * If the future is already completed, action is executed synchronously. + * @param action action to be executed + */ + ChainableFuture then(Consumer action); +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/SynchronousTransactionExecutor.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/SynchronousTransactionExecutor.java new file mode 100644 index 000000000..74b31b905 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/SynchronousTransactionExecutor.java @@ -0,0 +1,22 @@ +package cz.cvut.kbss.termit.util.throttle; + +import org.springframework.lang.NonNull; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; + +import java.util.concurrent.Executor; + +/** + * Executes the runnable in a transaction synchronously. + * + * @see Transactional + */ +@Component +public class SynchronousTransactionExecutor implements Executor { + + @Transactional + @Override + public void execute(@NonNull Runnable command) { + command.run(); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/Throttle.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/Throttle.java new file mode 100644 index 000000000..cc9c9080b --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/Throttle.java @@ -0,0 +1,105 @@ +package cz.cvut.kbss.termit.util.throttle; + +import cz.cvut.kbss.termit.util.Constants; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.util.concurrent.Future; + +/** + * Indicates that calls to this method will be throttled & debounced. + *

+ * The task created from the method will be executed on the first call of the method, + * then every next call which comes earlier than {@link Constants#THROTTLE_THRESHOLD} + * will return a pending future which might be resolved by a newer call. + * Future will be resolved once per {@link Constants#THROTTLE_THRESHOLD} (+ duration to execute the future). + *

+ * + *

+ * Call to this method cannot be part of an existing transaction. + * If {@link org.springframework.transaction.annotation.Transactional @Transactional} is present with this annotation, + * new transaction is created for the task execution. + *

+ * Available only for methods returning {@code void}, {@link Void} and {@link ThrottledFuture}, + * method signature may be {@link Future}, + * or another type assignable from {@link ThrottledFuture}, + * but the returned concrete object has to be {@link ThrottledFuture}, method call will throw otherwise! + *

+ * Whole body of method with {@code void} or {@link Void} return types will be considered as task which will be executed later. + * In case of {@link Future} return type, only task in returned {@link ThrottledFuture} is throttled, + * meaning that actual body of the method will be executed every call. + *

+ * Note that returned future can be canceled + *

+ * Method may also return already canceled or fulfilled future; in that case, the result is returned immediately. + *

+ * Example implementation: + *


+ *  {@code @}Throttle(value = "{#paramObj, #anotherParam}")
+ *  public Future<String> myFunction(Object paramObj, Object anotherParam) {
+ *      // this will execute on every call as the return type is future
+ *      LOG.trace("my function called");
+ *      return ThrottledFuture.of(() -> doStuff()); // doStuff() will be throttled
+ *  }
+ * 
+ *

+ *  {@code @}Throttle(value = "{#paramObj, #anotherParam}")
+ *  public void myFunction(Object paramObj, Object anotherParam) {
+ *      // whole method body will be throttled, as return type is not future
+ *      LOG.trace("my function called");
+ *  }
+ * 
+ * + * @implNote Methods will be called from a separated thread. + * @see Debouncing and Throttling + * @see Throttling + debouncing image + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +@Documented +public @interface Throttle { + + /** + * The Spring-EL expression + * returning a List of Objects or a String which will be used to construct the unique identifier + * for this throttled instance. + */ + @NonNull String value() default ""; + + /** + * The Spring-EL expression + * returning group identifier a List of Objects or a String to which this throttle belongs. + *

+ * When there is a pending task P with a group + * that is also a prefix for a group of a new task N, + * the new task N will be canceled immediately. + * The group of the task P is lower than the group of the task N. + *

+ * When a task with lower group is scheduled, all scheduled tasks with higher groups are canceled. + *

+ * Example: + *

+     *     new task A with group "my.group.task1" is scheduled
+     *     new task B with group "my.group.task1.subtask" wants to be scheduled
+     *        -> task B is canceled immediately (task A with lower group is already pending)
+     *     new task C with group "my.group" is scheduled
+     *        -> task A is canceled as the task C has lower group than A
+     * 
+ * Blank string disables any group processing. + * @see String#compareTo(String) + */ + @NonNull String group() default ""; + + /** + * @return a key name of the task which is displayed on the frontend. + * Example: {@code name = "validation"} on frontend a translatable name with a key + * {@code "longrunningtasks.name.validation"} is displayed. + * Leave blank to hide the task on the frontend. + */ + @Nullable String name() default ""; +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspect.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspect.java new file mode 100644 index 000000000..fcf8ea14b --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspect.java @@ -0,0 +1,611 @@ +package cz.cvut.kbss.termit.util.throttle; + +import cz.cvut.kbss.termit.TermItApplication; +import cz.cvut.kbss.termit.exception.TermItException; +import cz.cvut.kbss.termit.exception.ThrottleAspectException; +import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.Pair; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTaskScheduler; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTasksRegistry; +import org.aspectj.lang.JoinPoint; +import org.aspectj.lang.ProceedingJoinPoint; +import org.aspectj.lang.reflect.MethodSignature; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.context.annotation.Profile; +import org.springframework.context.annotation.Scope; +import org.springframework.core.annotation.Order; +import org.springframework.expression.EvaluationContext; +import org.springframework.expression.EvaluationException; +import org.springframework.expression.Expression; +import org.springframework.expression.ExpressionParser; +import org.springframework.expression.spel.standard.SpelExpressionParser; +import org.springframework.expression.spel.support.DataBindingPropertyAccessor; +import org.springframework.expression.spel.support.StandardEvaluationContext; +import org.springframework.expression.spel.support.StandardTypeLocator; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.stereotype.Component; +import org.springframework.transaction.annotation.Transactional; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.springframework.beans.factory.config.ConfigurableBeanFactory.SCOPE_SINGLETON; + +/** + * @see Throttle + * @implNote The aspect is configured in {@code spring-aop.xml}, this uses Spring AOP instead of AspectJ. + */ +@Order +@Scope(SCOPE_SINGLETON) +@Component("throttleAspect") +@Profile("!test") +public class ThrottleAspect extends LongRunningTaskScheduler { + + private static final Logger LOG = LoggerFactory.getLogger(ThrottleAspect.class); + + /** + *

Throttled futures are returned as results of method calls.

+ *

Tasks inside them can be replaced by a newer ones allowing + * to merge multiple (throttled) method calls into a single one while always executing the newest one possible.

+ *

A task inside a throttled future represents + * a heavy/long-running task acquired from the body of an throttled method

+ * + * @implSpec Synchronize in the field declaration order before modification + */ + private final Map> throttledFutures; + + /** + * The last run is updated every time a task is finished. + * @implSpec Synchronize in the field declaration order before modification + */ + private final Map lastRun; + + /** + * Scheduled futures are returned from {@link #taskScheduler}. + * Futures are completed by execution of tasks created in {@link #createRunnableToSchedule}. + * Records about them are used for their cancellation in case of debouncing. + * + * @implSpec Synchronize in the field declaration order before modification + */ + private final NavigableMap> scheduledFutures; + + /** + * Thread safe set holding identifiers of threads that are + * currently executing a throttled task. + */ + private final Set throttledThreads = ConcurrentHashMap.newKeySet(); + + /** + * Parser for Spring Expression Language + */ + private final ExpressionParser parser = new SpelExpressionParser(); + + private final TaskScheduler taskScheduler; + + /** + * A base context for evaluation of SpEL expressions + */ + private final StandardEvaluationContext standardEvaluationContext; + + /** + * Used for acquiring {@link #lastRun} timestamps. + * @implNote for testing purposes + */ + private final Clock clock; + + /** + * Wrapper for executions in a transaction context + */ + private final SynchronousTransactionExecutor transactionExecutor; + + /** + * A timestamp of the last time maps were cleaned. + * The reference might be null. + * @see #clearOldFutures() + */ + private final AtomicReference lastClear; + + private final Configuration configuration; + + @Autowired + public ThrottleAspect(@Qualifier("longRunningTaskScheduler") TaskScheduler taskScheduler, + SynchronousTransactionExecutor transactionExecutor, + LongRunningTasksRegistry longRunningTasksRegistry, Configuration configuration) { + super(longRunningTasksRegistry); + this.taskScheduler = taskScheduler; + this.transactionExecutor = transactionExecutor; + this.configuration = configuration; + throttledFutures = new HashMap<>(); + lastRun = new HashMap<>(); + scheduledFutures = new TreeMap<>(); + clock = Clock.systemUTC(); // used by Instant.now() by default + standardEvaluationContext = makeDefaultContext(); + lastClear = new AtomicReference<>(Instant.now(clock)); + } + + /** + * Constructor for testing environment + */ + protected ThrottleAspect(Map> throttledFutures, + Map lastRun, + NavigableMap> scheduledFutures, TaskScheduler taskScheduler, + Clock clock, SynchronousTransactionExecutor transactionExecutor, + LongRunningTasksRegistry longRunningTasksRegistry, Configuration configuration) { + super(longRunningTasksRegistry); + this.throttledFutures = throttledFutures; + this.lastRun = lastRun; + this.scheduledFutures = scheduledFutures; + this.taskScheduler = taskScheduler; + this.clock = clock; + this.transactionExecutor = transactionExecutor; + this.configuration = configuration; + standardEvaluationContext = makeDefaultContext(); + lastClear = new AtomicReference<>(Instant.now(clock)); + } + + private static StandardEvaluationContext makeDefaultContext() { + StandardEvaluationContext standardEvaluationContext = new StandardEvaluationContext(); + standardEvaluationContext.addPropertyAccessor(DataBindingPropertyAccessor.forReadOnlyAccess()); + + final ClassLoader loader = ThrottleAspect.class.getClassLoader(); + final StandardTypeLocator typeLocator = new StandardTypeLocator(loader); + + final String basePackage = TermItApplication.class.getPackageName(); + Arrays.stream(loader.getDefinedPackages()).map(Package::getName).filter(s -> s.indexOf(basePackage) == 0) + .forEach(typeLocator::registerImport); + + standardEvaluationContext.setTypeLocator(typeLocator); + return standardEvaluationContext; + } + + /** + * @return future or null + * @throws TermItException when the target method throws + * @throws IllegalCallerException when the annotated method returns another type than {@code void}, {@link Void} or {@link Future} + * @implNote Around advice configured in {@code spring-aop.xml} + */ + public @Nullable Object throttleMethodCall(@NonNull ProceedingJoinPoint joinPoint, + @NonNull Throttle throttleAnnotation) throws Throwable { + + // if the current thread is already executing a throttled code, we want to skip further throttling + if (throttledThreads.contains(Thread.currentThread().getId())) { + // proceed with method execution + final Object result = joinPoint.proceed(); + if (result instanceof ThrottledFuture throttledFuture) { + // directly run throttled future + throttledFuture.run(null); + return throttledFuture; + } + return result; + } + + return doThrottle(joinPoint, throttleAnnotation); + } + + private synchronized @Nullable Object doThrottle(@NonNull ProceedingJoinPoint joinPoint, + @NonNull Throttle throttleAnnotation) throws Throwable { + + final MethodSignature signature = (MethodSignature) joinPoint.getSignature(); + + // construct the throttle instance key + final Identifier identifier = makeIdentifier(joinPoint, throttleAnnotation); + LOG.trace("Throttling task with key '{}'", identifier); + + synchronized (scheduledFutures) { + if (!identifier.getGroup().isBlank()) { + // check if there is a task with lower group + // and if so, cancel this task in favor of the lower group + final Map.Entry> lowerEntry = scheduledFutures.lowerEntry(identifier); + if (lowerEntry != null) { + final Future lowerFuture = lowerEntry.getValue(); + boolean hasGroupPrefix = identifier.hasGroupPrefix(lowerEntry.getKey().getGroup()); + if (hasGroupPrefix && !lowerFuture.isDone()) { + LOG.trace("Throttling canceled due to scheduled lower task '{}'", lowerEntry.getKey()); + return ThrottledFuture.canceled(); + } + } + + cancelWithHigherGroup(identifier); + } + } + + // if there is a scheduled task and this throttled instance was executed in the last configuration.getThrottleThreshold() + // cancel the scheduled task + // -> the execution is further delayed + Future oldScheduledFuture = scheduledFutures.get(identifier); + boolean throttleExpired = isThresholdExpired(identifier); + if (oldScheduledFuture != null && !throttleExpired) { + oldScheduledFuture.cancel(false); + synchronized (scheduledFutures) { + scheduledFutures.remove(identifier); + } + } + + // acquire a throttled future from a map, or make a new one + ThrottledFuture oldThrottledFuture = throttledFutures.getOrDefault(identifier, new ThrottledFuture<>()); + + final Pair> pair = getFutureTask(joinPoint, identifier, oldThrottledFuture); + ThrottledFuture future = pair.getSecond(); + future.setName(throttleAnnotation.name()); + // update the throttled future in the map, it might be just the same future, but it might be a new one + synchronized (throttledFutures) { + throttledFutures.put(identifier, future); + } + + Object result = resultVoidOrFuture(signature, future); + + if (future.isDone() || future.isRunning()) { + return result; + } + + if (oldScheduledFuture == null || oldThrottledFuture != future || oldScheduledFuture.isDone()) { + boolean oldFutureIsDone = oldScheduledFuture == null || oldScheduledFuture.isDone(); + if (oldThrottledFuture != future) { + oldThrottledFuture.then(ignored -> + schedule(identifier, pair.getFirst(), throttleExpired && oldFutureIsDone) + ); + } else { + schedule(identifier, pair.getFirst(), throttleExpired && oldFutureIsDone); + } + notifyTaskChanged(future); + } + + return result; + } + + /** + * Maps parameter names from the method signature to their values from {@link JoinPoint#getArgs()} + * + * @param map to fill + * @param signature the method signature + * @param joinPoint the join point + */ + private static void resolveParameters(Map map, MethodSignature signature, JoinPoint joinPoint) { + final String[] paramNames = signature.getParameterNames(); + final Object[] params = joinPoint.getArgs(); + + if (paramNames == null || params == null || params.length != paramNames.length) { + return; + } + + for (int i = 0; i < params.length; i++) { + map.putIfAbsent(paramNames[i], params[i]); + } + } + + private EvaluationContext makeContext(JoinPoint joinPoint, Map parameters) { + StandardEvaluationContext context = new StandardEvaluationContext(); + standardEvaluationContext.applyDelegatesTo(context); + context.setRootObject(joinPoint.getTarget()); + context.setVariables(parameters); + return context; + } + + private Pair> getFutureTask(@NonNull ProceedingJoinPoint joinPoint, + @NonNull Identifier identifier, + @NonNull ThrottledFuture future) + throws Throwable { + + final MethodSignature methodSignature = (MethodSignature) joinPoint.getSignature(); + final Class returnType = methodSignature.getReturnType(); + final boolean isFuture = returnType.isAssignableFrom(ThrottledFuture.class); + final boolean isVoid = returnType.equals(Void.class) || returnType.equals(Void.TYPE); + + ThrottledFuture throttledFuture = future; + + // Sets the task to the future. + // If the annotated method returns throttled future, transfer the new task into the future + // replacing the old one. + // If the method does not return a throttled future, + // fill the future with a task which calls the annotated method returning the result + + // the future must contain the same type - ensured by accessing with the unique key + if (isFuture) { + Object result = joinPoint.proceed(); + if (result instanceof ThrottledFuture throttledMethodFuture) { + // future acquired by key or a new future supplied, ensuring the same type + // ThrottledFuture#updateOther will create a new future when required + if (throttledMethodFuture.isDone()) { + throttledFuture = (ThrottledFuture) throttledMethodFuture; + } else { + // transfer the newer task from methodFuture -> to the (old) throttled future + throttledFuture = ((ThrottledFuture) throttledMethodFuture).transfer(throttledFuture); + } + } else { + throw new ThrottleAspectException("Returned value is not a ThrottledFuture"); + } + } else if (isVoid) { + throttledFuture = throttledFuture.update(() -> { + try { + return joinPoint.proceed(); + } catch (Throwable e) { + // exception happened inside throttled method, + // and the method returns null, if we rethrow the exception + // it will be stored inside the future + // and never retrieved (as the method returned null) + LOG.error("Exception thrown during task execution", e); + throw new TermItException(e); + } + }, List.of()); + } else { + throw new ThrottleAspectException("Invalid return type for " + joinPoint.getSignature() + " annotated with @Debounce, only Future or void allowed!"); + } + + final boolean withTransaction = methodSignature.getMethod() != null && methodSignature.getMethod() + .isAnnotationPresent(Transactional.class); + + // create a task which will be scheduled with executor + final Runnable toSchedule = createRunnableToSchedule(throttledFuture, identifier, withTransaction); + + return new Pair<>(toSchedule, throttledFuture); + } + + /** + * @return the number of throttled futures that are neither done nor running. + */ + private long countRemaining() { + synchronized (throttledFutures) { + return throttledFutures.values().stream().filter(f -> !f.isDone() && !f.isRunning()).count(); + } + } + + /** + * @return count of throttled threads + */ + private long countRunning() { + return throttledThreads.size(); + } + + private Runnable createRunnableToSchedule(ThrottledFuture throttledFuture, Identifier identifier, + boolean withTransaction) { + final Supplier securityContext = SecurityContextHolder.getDeferredContext(); + return () -> { + if (throttledFuture.isDone()) { + return; + } + // mark the thread as throttled + final Long threadId = Thread.currentThread().getId(); + throttledThreads.add(threadId); + + LOG.trace("Running throttled task [{} left] [{} running] '{}'", countRemaining() - 1, countRunning(), identifier); + + // restore the security context + SecurityContextHolder.setContext(securityContext.get()); + try { + // fulfill the future + if (withTransaction) { + transactionExecutor.execute(()->throttledFuture.run(this::notifyTaskChanged)); + } else { + throttledFuture.run(this::notifyTaskChanged); + } + // update last run timestamp + synchronized (lastRun) { + lastRun.put(identifier, Instant.now(clock)); + } + } finally { + if (!throttledFuture.isDone()) { + throttledFuture.cancel(false); + } + notifyTaskChanged(throttledFuture); // task done + // clear the security context + SecurityContextHolder.clearContext(); + LOG.trace("Finished throttled task [{} left] [{} running] '{}'", countRemaining(), countRunning() - 1, identifier); + + clearOldFutures(); + + // remove throttled mark + throttledThreads.remove(threadId); + } + }; + } + + /** + * Discards futures from {@link #throttledFutures}, {@link #lastRun} and {@link #scheduledFutures} maps. + *

Every completed future for which a {@link Configuration#throttleDiscardThreshold throttleDiscardThreshold} expired is discarded.

+ * @see #isThresholdExpired(Identifier) + */ + private void clearOldFutures() { + // if the last clear was performed less than a threshold ago, skip it for now + Instant last = lastClear.get(); + if (last.isAfter(Instant.now(clock).minus(configuration.getThrottleThreshold()).minus(configuration.getThrottleDiscardThreshold()))) { + return; + } + if (!lastClear.compareAndSet(last, Instant.now(clock))) { + return; + } + synchronized (throttledFutures) { // synchronize in the filed declaration order + synchronized (lastRun) { + synchronized (scheduledFutures) { + Stream.of(throttledFutures.keySet().stream(), scheduledFutures.keySet().stream(), lastRun.keySet() + .stream()) + .flatMap(s -> s).distinct().toList() // ensures safe modification of maps + .forEach(identifier -> { + if (isThresholdExpiredByMoreThan(identifier, configuration.getThrottleDiscardThreshold())) { + Optional.ofNullable(throttledFutures.get(identifier)).ifPresent(throttled -> { + if (throttled.isDone()) { + throttledFutures.remove(identifier); + } + }); + Optional.ofNullable(scheduledFutures.get(identifier)).ifPresent(scheduled -> { + if (scheduled.isDone()) { + scheduledFutures.remove(identifier); + } + }); + lastRun.remove(identifier); + } + }); + } + } + } + } + + /** + * @param identifier of the task + * @param duration to add to the throttle threshold + * @return Whether the last time when a task with specified {@code identifier} run + * is older than ({@link Configuration#throttleThreshold throttleThreshold} + {@code duration}) + */ + private boolean isThresholdExpiredByMoreThan(Identifier identifier, Duration duration) { + return lastRun.getOrDefault(identifier, Instant.MAX).isBefore(Instant.now(clock).minus(configuration.getThrottleThreshold()).minus(duration)); + } + + /** + * @return Whether the time when the identifier last run is older than the threshold, + * true when the task had never run + */ + private boolean isThresholdExpired(Identifier identifier) { + return lastRun.getOrDefault(identifier, Instant.EPOCH).isBefore(Instant.now(clock).minus(configuration.getThrottleThreshold())); + } + + @SuppressWarnings("unchecked") + private void schedule(Identifier identifier, Runnable task, boolean immediately) { + Instant startTime = Instant.now(clock).plus(configuration.getThrottleThreshold()); + if (immediately) { + startTime = Instant.now(clock); + } + synchronized (scheduledFutures) { + Future scheduled = taskScheduler.schedule(task, startTime); + // casting the type parameter to Object + scheduledFutures.put(identifier, (Future) scheduled); + } + } + + private void cancelWithHigherGroup(Identifier throttleAnnotation) { + if (throttleAnnotation.getGroup().isBlank()) { + return; + } + synchronized (throttledFutures) { // synchronize in the filed declaration order + synchronized (scheduledFutures) { + // look for any futures with higher group + // cancel them and remove from maps + Future higherFuture; + Identifier higherKey = scheduledFutures.higherKey(new Identifier(throttleAnnotation.getGroup(), "")); + while (higherKey != null) { + if (!higherKey.hasGroupPrefix(throttleAnnotation.getGroup()) || higherKey.getGroup() + .equals(throttleAnnotation.getGroup())) { + break; + } + + higherFuture = scheduledFutures.get(higherKey); + higherFuture.cancel(false); + final ThrottledFuture throttledFuture = throttledFutures.get(higherKey); + + // cancels future if it's not null (should not be) and removes it from map if it was canceled + if (throttledFuture != null && throttledFuture.cancel(false)) { + throttledFutures.remove(higherKey); + notifyTaskChanged(throttledFuture); + } + + scheduledFutures.remove(higherKey); + + higherKey = scheduledFutures.higherKey(higherKey); + } + } + } + } + + private Identifier makeIdentifier(JoinPoint joinPoint, Throttle throttleAnnotation) throws IllegalCallerException { + final String identifier = constructIdentifier(joinPoint, throttleAnnotation.value()); + final String groupIdentifier = constructIdentifier(joinPoint, throttleAnnotation.group()); + + return new Identifier(groupIdentifier, joinPoint.getSignature().toShortString() + "-" + identifier); + } + + private @Nullable Object resultVoidOrFuture(@NonNull MethodSignature signature, ThrottledFuture future) + throws IllegalCallerException { + Class returnType = signature.getReturnType(); + if (returnType.isAssignableFrom(ThrottledFuture.class)) { + return future; + } + if (Void.TYPE.equals(returnType) || Void.class.equals(returnType)) { + return null; + } + throw new ThrottleAspectException("Invalid return type for " + signature + " annotated with @Debounce, only Future or void allowed!"); + } + + + @SuppressWarnings({"unchecked"}) + private @NonNull String constructIdentifier(JoinPoint joinPoint, String expression) throws ThrottleAspectException { + if (expression == null || expression.isBlank()) { + return ""; + } + + final Map parameters = new HashMap<>(); + resolveParameters(parameters, (MethodSignature) joinPoint.getSignature(), joinPoint); + + final EvaluationContext context = makeContext(joinPoint, parameters); + + final Expression identifierExp = parser.parseExpression(expression); + try { + Object result = identifierExp.getValue(context); + + if (result instanceof String stringResult) { + return stringResult; + } + + // casting the expression result to the list of objects + // exception handled and rethrown by try-catch + Collection identifierList = (Collection) identifierExp.getValue(context); + Objects.requireNonNull(identifierList); + return identifierList.stream().map(Object::toString).collect(Collectors.joining("-")); + } catch (EvaluationException | ClassCastException | NullPointerException e) { + throw new ThrottleAspectException("The expression: '" + expression + "' has not been resolved to a Collection or String", e); + } + } + + /** + * A composed identifier of a throttled instance. + *

+     *     String group
+     *     String identifier
+     * 
+ * Implements comparable, first comparing group, then identifier. + */ + protected static class Identifier extends Pair.ComparablePair { + + public Identifier(String group, String identifier) { + super(group, identifier); + } + + public String getGroup() { + return this.getFirst(); + } + + public String getIdentifier() { + return this.getSecond(); + } + + public boolean hasGroupPrefix(@NonNull String group) { + return this.getGroup().indexOf(group) == 0 && !this.getGroup().isBlank() && !group.isBlank(); + } + + @Override + public String toString() { + return "ThrottleAspect.Identifier{group='" + getGroup() + "',identifier='" + getIdentifier() + "'}"; + } + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottleGroupProvider.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottleGroupProvider.java new file mode 100644 index 000000000..c832ef463 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottleGroupProvider.java @@ -0,0 +1,29 @@ +package cz.cvut.kbss.termit.util.throttle; + +import java.net.URI; + +/** + * Provides static methods allowing construction of dynamic group identifiers + * used in {@link Throttle @Throttle} annotations. + */ +@SuppressWarnings("unused") // it is used from SpEL expressions +public class ThrottleGroupProvider { + + private ThrottleGroupProvider() { + throw new AssertionError(); + } + + private static final String TEXT_ANALYSIS_VOCABULARIES = "TEXT_ANALYSIS_VOCABULARIES"; + + public static String getTextAnalysisVocabulariesAll() { + return TEXT_ANALYSIS_VOCABULARIES; + } + + public static String getTextAnalysisVocabularyAllTerms(URI vocabulary) { + return TEXT_ANALYSIS_VOCABULARIES + "_" + vocabulary; + } + + public static String getTextAnalysisVocabularyTerm(URI vocabulary, URI term) { + return TEXT_ANALYSIS_VOCABULARIES + "_" + vocabulary + "_" + term; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottledFuture.java b/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottledFuture.java new file mode 100644 index 000000000..35947b403 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/util/throttle/ThrottledFuture.java @@ -0,0 +1,265 @@ +package cz.cvut.kbss.termit.util.throttle; + +import cz.cvut.kbss.termit.exception.TermItException; +import cz.cvut.kbss.termit.util.Utils; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTask; +import org.springframework.lang.NonNull; +import org.springframework.lang.Nullable; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import java.util.function.Supplier; + +public class ThrottledFuture implements CacheableFuture, LongRunningTask { + + private final ReentrantLock lock = new ReentrantLock(); + private final ReentrantLock callbackLock = new ReentrantLock(); + + private final UUID uuid = UUID.randomUUID(); + + private @Nullable T cachedResult = null; + + private final CompletableFuture future; + + private @Nullable Supplier task; + + private final List> onCompletion = new ArrayList<>(); + + private final AtomicReference startedAt = new AtomicReference<>(null); + + private @Nullable String name = null; + + private ThrottledFuture(@NonNull final Supplier task) { + this.task = task; + future = new CompletableFuture<>(); + } + + protected ThrottledFuture() { + future = new CompletableFuture<>(); + } + + public static ThrottledFuture of(@NonNull final Supplier supplier) { + return new ThrottledFuture<>(supplier); + } + + public static ThrottledFuture of(@NonNull final Runnable runnable) { + return new ThrottledFuture<>(() -> { + runnable.run(); + return null; + }); + } + + /** + * @return already canceled future + */ + public static ThrottledFuture canceled() { + ThrottledFuture f = new ThrottledFuture<>(); + f.cancel(true); + return f; + } + + /** + * @return already done future + */ + public static ThrottledFuture done(T result) { + ThrottledFuture f = ThrottledFuture.of(() -> result); + f.run(null); + return f; + } + + @Override + public Optional getCachedResult() { + return Optional.ofNullable(cachedResult); + } + + @Override + public ThrottledFuture setCachedResult(@Nullable final T cachedResult) { + this.cachedResult = cachedResult; + return this; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return future.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return future.isCancelled(); + } + + @Override + public boolean isDone() { + return future.isDone(); + } + + /** + * Does not execute the task, blocks the current thread until the result is available. + */ + @Override + public T get() throws InterruptedException, ExecutionException { + return future.get(); + } + + /** + * Does not execute the task, blocks the current thread until the result is available. + */ + @Override + public T get(long timeout, @NonNull TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return future.get(timeout, unit); + } + /** + * @param task the new task + * @return If the current task is already running, was canceled or already completed, returns a new future for the given task. + * Otherwise, replaces the current task and returns self. + */ + protected ThrottledFuture update(Supplier task, @NonNull List> onCompletion) { + boolean locked = false; + try { + locked = lock.tryLock(); + this.callbackLock.lock(); + ThrottledFuture updatedFuture = this; + if (!locked || isRunning() || isDone()) { + updatedFuture = ThrottledFuture.of(task); + } + updatedFuture.task = task; + updatedFuture.onCompletion.addAll(onCompletion); + return updatedFuture; + } finally { + if (locked) { + lock.unlock(); + } + this.callbackLock.unlock(); + } + } + + + /** + * Returns future with the task from the specified {@code throttledFuture}. + * If possible, transfers the task from this object to the specified {@code throttledFuture}. + * If the task was successfully transferred, this future is canceled. + * + * @param target the future to update + * @return target when current future is already being executed, was canceled or completed. + * New future when the target is being executed, was canceled or completed. + */ + protected ThrottledFuture transfer(ThrottledFuture target) { + boolean locked = false; + try { + locked = lock.tryLock(); + this.callbackLock.lock(); + if (!locked || isRunning() || isDone()) { + return target; + } + + ThrottledFuture result = target.update(this.task, this.onCompletion); + this.task = null; + this.onCompletion.clear(); + this.cancel(false); + return result; + } finally { + if (locked) { + lock.unlock(); + } + this.callbackLock.unlock(); + } + } + + /** + * Executes the task associated with this future + * @param startedCallback called once {@link #startedAt} is set and so execution is considered as running. + */ + protected void run(@Nullable Consumer> startedCallback) { + boolean locked = false; + try { + do { + locked = lock.tryLock(); + if (isRunning() || isDone()) { + return; + } else if (!locked) { + Thread.yield(); + } + } while (!locked); + + startedAt.set(Utils.timestamp()); + if (startedCallback != null) { + startedCallback.accept(this); + } + + try { + T result = null; + if (task != null) { + result = task.get(); + final T finalResult = result; + callbackLock.lock(); + onCompletion.forEach(c -> c.accept(finalResult)); + callbackLock.unlock(); + } + future.complete(result); + } catch (Exception e) { + future.completeExceptionally(e); + } + } finally { + if (locked) { + lock.unlock(); + } + } + } + + @Override + public @Nullable String getName() { + return this.name; + } + + protected void setName(@Nullable String name) { + this.name = name; + } + + @Override + public boolean isRunning() { + return startedAt.get() != null && !isDone(); + } + + @Override + public @NonNull Optional startedAt() { + return Optional.ofNullable(startedAt.get()); + } + + @Override + public @NonNull UUID getUuid() { + return uuid; + } + + @Override + public ThrottledFuture then(Consumer action) { + try { + callbackLock.lock(); + if (future.isDone() && !future.isCancelled()) { + try { + action.accept(future.get()); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new TermItException(e); + } catch (ExecutionException e) { + throw new TermItException(e); + } + } else { + onCompletion.add(action); + } + } finally { + callbackLock.unlock(); + } + return this; + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/BaseWebSocketController.java b/src/main/java/cz/cvut/kbss/termit/websocket/BaseWebSocketController.java new file mode 100644 index 000000000..55f152033 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/BaseWebSocketController.java @@ -0,0 +1,85 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.rest.BaseController; +import cz.cvut.kbss.termit.service.IdentifierResolver; +import cz.cvut.kbss.termit.util.Configuration; +import org.springframework.lang.NonNull; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.simp.SimpMessageHeaderAccessor; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.simp.user.DestinationUserNameProvider; +import org.springframework.util.LinkedMultiValueMap; + +import java.security.Principal; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static org.springframework.messaging.support.NativeMessageHeaderAccessor.NATIVE_HEADERS; + +public class BaseWebSocketController extends BaseController { + + protected final SimpMessagingTemplate messagingTemplate; + + protected BaseWebSocketController(IdentifierResolver idResolver, Configuration config, + SimpMessagingTemplate messagingTemplate) { + super(idResolver, config); + this.messagingTemplate = messagingTemplate; + } + + /** + * Resolves session id, when present, and sends to the specific session. + * When session id is not present, sends it to all sessions of specific user. + * + * @param destination the destination (without user prefix) + * @param payload payload to send + * @param replyHeaders native headers for the reply + * @param sourceHeaders original headers containing session id or name of the user + */ + protected void sendToSession(@NonNull String destination, @NonNull Object payload, + @NonNull Map replyHeaders, @NonNull MessageHeaders sourceHeaders) { + getSessionId(sourceHeaders) + .ifPresentOrElse(sessionId -> { // session id present + StompHeaderAccessor headerAccessor = StompHeaderAccessor.create(StompCommand.MESSAGE); + // add reply headers as native headers + headerAccessor.setHeader(NATIVE_HEADERS, new LinkedMultiValueMap<>(replyHeaders.size())); + replyHeaders.forEach((name, value) -> headerAccessor.addNativeHeader(name, Objects.toString(value))); + headerAccessor.setSessionId(sessionId); // pass session id to new headers + // send to user session + messagingTemplate.convertAndSendToUser(sessionId, destination, payload, headerAccessor.toMessageHeaders()); + }, + // session id not present, send to all user sessions + () -> getUser(sourceHeaders).ifPresent(user -> messagingTemplate.convertAndSendToUser(user, destination, payload, replyHeaders)) + ); + } + + /** + * Resolves name which can be used to send a message to the user with {@link SimpMessagingTemplate#convertAndSendToUser}. + * + * @return name or session id, or empty when information is not available. + */ + protected @NonNull Optional getUser(@NonNull MessageHeaders messageHeaders) { + return getUserName(messageHeaders).or(() -> getSessionId(messageHeaders)); + } + + private @NonNull Optional getSessionId(@NonNull MessageHeaders messageHeaders) { + return Optional.ofNullable(SimpMessageHeaderAccessor.getSessionId(messageHeaders)); + } + + /** + * Resolves the name of the user + * + * @return the name or null + */ + private @NonNull Optional getUserName(MessageHeaders headers) { + Principal principal = SimpMessageHeaderAccessor.getUser(headers); + if (principal != null) { + final String name = (principal instanceof DestinationUserNameProvider provider ? + provider.getDestinationUserName() : principal.getName()); + return Optional.ofNullable(name); + } + return Optional.empty(); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/LongRunningTasksWebSocketController.java b/src/main/java/cz/cvut/kbss/termit/websocket/LongRunningTasksWebSocketController.java new file mode 100644 index 000000000..f3d3ac18c --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/LongRunningTasksWebSocketController.java @@ -0,0 +1,42 @@ +package cz.cvut.kbss.termit.websocket; + +import cz.cvut.kbss.termit.event.LongRunningTaskChangedEvent; +import cz.cvut.kbss.termit.security.SecurityConstants; +import cz.cvut.kbss.termit.service.IdentifierResolver; +import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTasksRegistry; +import org.springframework.context.event.EventListener; +import org.springframework.lang.NonNull; +import org.springframework.messaging.MessageHeaders; +import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.simp.SimpMessagingTemplate; +import org.springframework.messaging.simp.annotation.SubscribeMapping; +import org.springframework.security.access.prepost.PreAuthorize; +import org.springframework.stereotype.Controller; + +import java.util.Map; + +@Controller +@MessageMapping("/long-running-tasks") +@PreAuthorize("hasRole('" + SecurityConstants.ROLE_RESTRICTED_USER + "')") +public class LongRunningTasksWebSocketController extends BaseWebSocketController { + + private final LongRunningTasksRegistry registry; + + protected LongRunningTasksWebSocketController(IdentifierResolver idResolver, Configuration config, + SimpMessagingTemplate messagingTemplate, + LongRunningTasksRegistry registry) { + super(idResolver, config, messagingTemplate); + this.registry = registry; + } + + @SubscribeMapping("/update") + public void tasksRequest(@NonNull MessageHeaders messageHeaders) { + sendToSession(WebSocketDestinations.LONG_RUNNING_TASKS_UPDATE, registry.getTasks(), Map.of(), messageHeaders); + } + + @EventListener(LongRunningTaskChangedEvent.class) + public void onTaskChanged() { + messagingTemplate.convertAndSend(WebSocketDestinations.LONG_RUNNING_TASKS_UPDATE, registry.getTasks()); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java b/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java deleted file mode 100644 index 8d74eeb72..000000000 --- a/src/main/java/cz/cvut/kbss/termit/websocket/ResultWithHeaders.java +++ /dev/null @@ -1,68 +0,0 @@ -package cz.cvut.kbss.termit.websocket; - -import cz.cvut.kbss.termit.websocket.handler.WebSocketMessageWithHeadersValueHandler; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; -import org.springframework.messaging.handler.annotation.SendTo; -import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; -import org.springframework.messaging.simp.annotation.SendToUser; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -/** - * Wrapper carrying a result from WebSocket controller - * including the {@link #payload}, {@link #destination} and {@link #headers} for the resulting message. - *

- * Do not combine with other method-return-value handlers (like {@link SendTo @SendTo}) - *

- * The {@code ResultWithHeaders} is then handled by {@link WebSocketMessageWithHeadersValueHandler}. - * Every value returned from a controller method - * can be handled only by a single {@link HandlerMethodReturnValueHandler}. - * Annotations like {@link SendTo @SendTo}/{@link SendToUser @SendToUser} - * are handled by separate return value handlers, so only one can be used simultaneously. - * - * @param payload The actual result of the method - * @param destination The destination channel where the message will be sent - * @param headers Headers that will overwrite headers in the message. - * @param The type of the payload - * @see WebSocketMessageWithHeadersValueHandler - * @see HandlerMethodReturnValueHandler - */ -public record ResultWithHeaders(T payload, @NotNull String destination, @NotNull Map headers, - boolean toUser) { - - public static ResultWithHeadersBuilder result(T payload) { - return new ResultWithHeadersBuilder<>(payload); - } - - public static class ResultWithHeadersBuilder { - - private final T payload; - - private @Nullable Map headers = null; - - private ResultWithHeadersBuilder(T payload) { - this.payload = payload; - } - - /** - * All values will be mapped to strings with {@link Object#toString()} - */ - public ResultWithHeadersBuilder withHeaders(@NotNull Map headers) { - this.headers = new HashMap<>(); - headers.forEach((key, value) -> this.headers.put(key, value.toString())); - this.headers = Collections.unmodifiableMap(this.headers); - return this; - } - - public ResultWithHeaders sendTo(String destination) { - return new ResultWithHeaders<>(payload, destination, headers == null ? Map.of() : headers, false); - } - - public ResultWithHeaders sendToUser(String userDestination) { - return new ResultWithHeaders<>(payload, userDestination, headers == null ? Map.of() : headers, true); - } - } -} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java b/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java index 7e49b35e0..57578f45b 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/VocabularySocketController.java @@ -1,49 +1,124 @@ package cz.cvut.kbss.termit.websocket; +import cz.cvut.kbss.termit.event.FileTextAnalysisFinishedEvent; +import cz.cvut.kbss.termit.event.TermDefinitionTextAnalysisFinishedEvent; +import cz.cvut.kbss.termit.event.VocabularyEvent; +import cz.cvut.kbss.termit.event.VocabularyValidationFinishedEvent; import cz.cvut.kbss.termit.model.Vocabulary; import cz.cvut.kbss.termit.model.validation.ValidationResult; -import cz.cvut.kbss.termit.rest.BaseController; import cz.cvut.kbss.termit.security.SecurityConstants; import cz.cvut.kbss.termit.service.IdentifierResolver; import cz.cvut.kbss.termit.service.business.VocabularyService; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.Constants; +import cz.cvut.kbss.termit.util.throttle.CacheableFuture; +import org.springframework.context.event.EventListener; +import org.springframework.lang.NonNull; +import org.springframework.messaging.MessageHeaders; import org.springframework.messaging.handler.annotation.DestinationVariable; import org.springframework.messaging.handler.annotation.Header; import org.springframework.messaging.handler.annotation.MessageMapping; +import org.springframework.messaging.simp.SimpMessagingTemplate; import org.springframework.security.access.prepost.PreAuthorize; import org.springframework.stereotype.Controller; import java.net.URI; -import java.util.List; +import java.util.Collection; +import java.util.HashMap; import java.util.Map; import java.util.Optional; -import static cz.cvut.kbss.termit.websocket.ResultWithHeaders.result; - @Controller @MessageMapping("/vocabularies") @PreAuthorize("hasRole('" + SecurityConstants.ROLE_RESTRICTED_USER + "')") -public class VocabularySocketController extends BaseController { +public class VocabularySocketController extends BaseWebSocketController { private final VocabularyService vocabularyService; protected VocabularySocketController(IdentifierResolver idResolver, Configuration config, - VocabularyService vocabularyService) { - super(idResolver, config); + SimpMessagingTemplate messagingTemplate, VocabularyService vocabularyService) { + super(idResolver, config, messagingTemplate); this.vocabularyService = vocabularyService; } /** * Validates the terms in a vocabulary with the specified identifier. + * Immediately responds with a result from the cache, if available. */ @MessageMapping("/{localName}/validate") - public ResultWithHeaders> validateVocabulary(@DestinationVariable String localName, - @Header(name = Constants.QueryParams.NAMESPACE, - required = false) Optional namespace) { + public void validateVocabulary(@DestinationVariable String localName, + @Header(name = Constants.QueryParams.NAMESPACE, + required = false) Optional namespace, + @NonNull MessageHeaders messageHeaders) { final URI identifier = resolveIdentifier(namespace.orElse(config.getNamespace().getVocabulary()), localName); final Vocabulary vocabulary = vocabularyService.getReference(identifier); - return result(vocabularyService.validateContents(vocabulary)).withHeaders(Map.of("vocabulary", identifier)) - .sendToUser("/vocabularies/validation"); + + final CacheableFuture> future = vocabularyService.validateContents(vocabulary.getUri()); + + future.getNow().ifPresentOrElse(validationResults -> + // if there is a result present (returned from cache), send it + sendToSession( + WebSocketDestinations.VOCABULARIES_VALIDATION, + validationResults, + getHeaders(identifier, + // results are cached if we received a future result, but the future is not done yet + Map.of("cached", !future.isDone())), + messageHeaders + ), () -> + // otherwise reply will be sent once the future is resolved + future.then(results -> + sendToSession( + WebSocketDestinations.VOCABULARIES_VALIDATION, + results, + getHeaders(identifier, + Map.of("cached", false)), + messageHeaders + )) + ); + + } + + /** + * Publishes results of validation to users. + */ + @EventListener + public void onVocabularyValidationFinished(VocabularyValidationFinishedEvent event) { + messagingTemplate.convertAndSend( + WebSocketDestinations.VOCABULARIES_VALIDATION, + event.getValidationResults(), + getHeaders(event.getVocabularyIri(), Map.of("cached", false)) + ); + } + + @EventListener + public void onFileTextAnalysisFinished(FileTextAnalysisFinishedEvent event) { + messagingTemplate.convertAndSend( + WebSocketDestinations.VOCABULARIES_TEXT_ANALYSIS_FINISHED_FILE, + event.getFileUri(), + getHeaders(event) + ); + } + + @EventListener + public void onTermDefinitionTextAnalysisFinished(TermDefinitionTextAnalysisFinishedEvent event) { + messagingTemplate.convertAndSend( + WebSocketDestinations.VOCABULARIES_TEXT_ANALYSIS_FINISHED_TERM_DEFINITION, + event.getTermUri(), + getHeaders(event) + ); + } + + protected @NonNull Map getHeaders(@NonNull VocabularyEvent event) { + return getHeaders(event.getVocabularyIri()); + } + + protected @NonNull Map getHeaders(@NonNull URI vocabularyUri) { + return getHeaders(vocabularyUri, Map.of()); + } + + protected @NonNull Map getHeaders(@NonNull URI vocabularyUri, Map headers) { + final Map headersMap = new HashMap<>(headers); + headersMap.put("vocabulary", vocabularyUri); + return headersMap; } } diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/WebSocketDestinations.java b/src/main/java/cz/cvut/kbss/termit/websocket/WebSocketDestinations.java new file mode 100644 index 000000000..c5a525347 --- /dev/null +++ b/src/main/java/cz/cvut/kbss/termit/websocket/WebSocketDestinations.java @@ -0,0 +1,30 @@ +package cz.cvut.kbss.termit.websocket; + +public final class WebSocketDestinations { + + /** + * Used for publishing results of validation from server to clients + */ + public static final String VOCABULARIES_VALIDATION = "/vocabularies/validation"; + + private static final String VOCABULARIES_TEXT_ANALYSIS_FINISHED = "/vocabularies/text_analysis/finished"; + + /** + * Used for notifying clients about a text analysis end + */ + public static final String VOCABULARIES_TEXT_ANALYSIS_FINISHED_FILE = VOCABULARIES_TEXT_ANALYSIS_FINISHED + "/file"; + + /** + * Used for notifying clients about a text analysis end + */ + public static final String VOCABULARIES_TEXT_ANALYSIS_FINISHED_TERM_DEFINITION = VOCABULARIES_TEXT_ANALYSIS_FINISHED + "/term-definition"; + + /** + * Used for pushing updates about long-running tasks to clients + */ + public static final String LONG_RUNNING_TASKS_UPDATE = "/long-running-tasks/update"; + + private WebSocketDestinations() { + throw new AssertionError(); + } +} diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java index e94b99450..caf20ffdd 100644 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java +++ b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketExceptionHandler.java @@ -34,8 +34,11 @@ import org.springframework.security.core.AuthenticationException; import org.springframework.security.core.userdetails.UsernameNotFoundException; import org.springframework.web.bind.annotation.ControllerAdvice; +import org.springframework.web.context.request.async.AsyncRequestNotUsableException; import org.springframework.web.multipart.MaxUploadSizeExceededException; +import static cz.cvut.kbss.termit.util.ExceptionUtils.isCausedBy; + /** * @implSpec Should reflect {@link cz.cvut.kbss.termit.rest.handler.RestExceptionHandler} */ @@ -70,7 +73,10 @@ private static void logException(Throwable ex, Message message) { } private static void logException(String message, Throwable ex) { - LOG.error(message, ex); + // prevents from logging exceptions caused be broken connection with a client + if (!isCausedBy(ex, AsyncRequestNotUsableException.class)) { + LOG.error(message, ex); + } } private static ErrorInfo errorInfo(Message message, Throwable e) { @@ -132,7 +138,8 @@ public ErrorInfo authorizationException(Message message, AuthorizationExcepti @MessageExceptionHandler(AuthenticationException.class) public ErrorInfo authenticationException(Message message, AuthenticationException e) { LOG.atDebug().setCause(e).log(e.getMessage()); - LOG.error("Authentication failure during message processing: {}\nMessage: {}", e.getMessage(), message.toString()); + LOG.atError().setMessage("Authentication failure during message processing: {}\nMessage: {}") + .addArgument(e.getMessage()).addArgument(message::toString).log(); return errorInfo(message, e); } @@ -141,13 +148,11 @@ public ErrorInfo authenticationException(Message message, AuthenticationExcep */ @MessageExceptionHandler(AccessDeniedException.class) public ErrorInfo accessDeniedException(Message message, AccessDeniedException e) { - LOG.atWarn().setMessage("[{}] Unauthorized access: {}").addArgument(() -> { - StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); - if (accessor.getUser() != null) { - return accessor.getUser().getName(); - } - return "(unknown user)"; - }).addArgument(e.getMessage()).log(); + StompHeaderAccessor accessor = StompHeaderAccessor.wrap(message); + if (accessor.getUser() != null) { + LOG.atWarn().setMessage("[{}] Unauthorized access: {}").addArgument(() -> accessor.getUser().getName()) + .addArgument(e.getMessage()).log(); + } return errorInfo(message, e); } diff --git a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java b/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java deleted file mode 100644 index 5494294f2..000000000 --- a/src/main/java/cz/cvut/kbss/termit/websocket/handler/WebSocketMessageWithHeadersValueHandler.java +++ /dev/null @@ -1,46 +0,0 @@ -package cz.cvut.kbss.termit.websocket.handler; - -import cz.cvut.kbss.termit.exception.UnsupportedOperationException; -import cz.cvut.kbss.termit.websocket.ResultWithHeaders; -import org.jetbrains.annotations.NotNull; -import org.springframework.core.MethodParameter; -import org.springframework.messaging.Message; -import org.springframework.messaging.handler.invocation.HandlerMethodReturnValueHandler; -import org.springframework.messaging.simp.SimpMessageHeaderAccessor; -import org.springframework.messaging.simp.SimpMessagingTemplate; -import org.springframework.messaging.simp.annotation.support.MissingSessionUserException; -import org.springframework.messaging.simp.stomp.StompHeaderAccessor; - -public class WebSocketMessageWithHeadersValueHandler implements HandlerMethodReturnValueHandler { - - private final SimpMessagingTemplate simpMessagingTemplate; - - public WebSocketMessageWithHeadersValueHandler(SimpMessagingTemplate simpMessagingTemplate) { - this.simpMessagingTemplate = simpMessagingTemplate; - } - - @Override - public boolean supportsReturnType(MethodParameter returnType) { - return ResultWithHeaders.class.isAssignableFrom(returnType.getParameterType()); - } - - @Override - public void handleReturnValue(Object returnValue, @NotNull MethodParameter returnType, @NotNull Message message) - throws Exception { - if (returnValue instanceof ResultWithHeaders resultWithHeaders) { - final StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message); - resultWithHeaders.headers().forEach(headerAccessor::setNativeHeader); - if (resultWithHeaders.toUser()) { - final String sessionId = SimpMessageHeaderAccessor.getSessionId(headerAccessor.toMessageHeaders()); - if (sessionId == null || sessionId.isBlank()) { - throw new MissingSessionUserException(message); - } - simpMessagingTemplate.convertAndSendToUser(sessionId, resultWithHeaders.destination(), resultWithHeaders.payload(), headerAccessor.toMessageHeaders()); - } else { - simpMessagingTemplate.convertAndSend(resultWithHeaders.destination(), resultWithHeaders.payload(), headerAccessor.toMessageHeaders()); - } - return; - } - throw new UnsupportedOperationException("Unable to process returned value: " + returnValue + " of type " + returnType.getParameterType() + " from " + returnType.getMethod()); - } -} diff --git a/src/main/resources/spring-aop.xml b/src/main/resources/spring-aop.xml new file mode 100644 index 000000000..a33bec64d --- /dev/null +++ b/src/main/resources/spring-aop.xml @@ -0,0 +1,22 @@ + + + + AOP related definitions + + + + + + + + + + diff --git a/src/test/java/cz/cvut/kbss/termit/environment/config/TestRestSecurityConfig.java b/src/test/java/cz/cvut/kbss/termit/environment/config/TestRestSecurityConfig.java index c5e6fc891..b54a9ed4c 100644 --- a/src/test/java/cz/cvut/kbss/termit/environment/config/TestRestSecurityConfig.java +++ b/src/test/java/cz/cvut/kbss/termit/environment/config/TestRestSecurityConfig.java @@ -30,6 +30,7 @@ import org.springframework.boot.test.context.TestConfiguration; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Import; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.security.authentication.AuthenticationProvider; import static org.mockito.Mockito.mock; @@ -73,4 +74,9 @@ public AuthenticationProvider authenticationProvider() { public TermItUserDetailsService userDetailsService() { return mock(TermItUserDetailsService.class); } + + @Bean + public ThreadPoolTaskScheduler longRunningTaskScheduler() { + return mock(ThreadPoolTaskScheduler.class); + } } diff --git a/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermDaoTest.java b/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermDaoTest.java index 506ad4d67..461f2576b 100644 --- a/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermDaoTest.java +++ b/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermDaoTest.java @@ -27,7 +27,7 @@ import cz.cvut.kbss.termit.environment.Generator; import cz.cvut.kbss.termit.event.AssetPersistEvent; import cz.cvut.kbss.termit.event.AssetUpdateEvent; -import cz.cvut.kbss.termit.event.VocabularyContentModified; +import cz.cvut.kbss.termit.event.VocabularyContentModifiedEvent; import cz.cvut.kbss.termit.model.Asset; import cz.cvut.kbss.termit.model.Term; import cz.cvut.kbss.termit.model.Term_; @@ -369,9 +369,9 @@ void persistPublishesVocabularyContentModifiedEvent() { final ArgumentCaptor captor = ArgumentCaptor.forClass(ApplicationEvent.class); verify(eventPublisher, atLeastOnce()).publishEvent(captor.capture()); - final Optional evt = captor.getAllValues().stream() - .filter(VocabularyContentModified.class::isInstance) - .map(VocabularyContentModified.class::cast).findFirst(); + final Optional evt = captor.getAllValues().stream() + .filter(VocabularyContentModifiedEvent.class::isInstance) + .map(VocabularyContentModifiedEvent.class::cast).findFirst(); assertTrue(evt.isPresent()); assertEquals(vocabulary.getUri(), evt.get().getVocabularyIri()); } @@ -430,9 +430,9 @@ void updatePublishesVocabularyContentModifiedEvent() { transactional(() -> sut.update(term)); final ArgumentCaptor captor = ArgumentCaptor.forClass(ApplicationEvent.class); verify(eventPublisher, atLeastOnce()).publishEvent(captor.capture()); - final Optional evt = captor.getAllValues().stream() - .filter(VocabularyContentModified.class::isInstance) - .map(VocabularyContentModified.class::cast).findFirst(); + final Optional evt = captor.getAllValues().stream() + .filter(VocabularyContentModifiedEvent.class::isInstance) + .map(VocabularyContentModifiedEvent.class::cast).findFirst(); assertTrue(evt.isPresent()); assertEquals(vocabulary.getUri(), evt.get().getVocabularyIri()); } @@ -1306,9 +1306,9 @@ void removePublishesVocabularyContentModifiedEvent() { transactional(() -> sut.remove(term)); final ArgumentCaptor captor = ArgumentCaptor.forClass(ApplicationEvent.class); verify(eventPublisher, atLeastOnce()).publishEvent(captor.capture()); - final Optional evt = captor.getAllValues().stream() - .filter(VocabularyContentModified.class::isInstance) - .map(VocabularyContentModified.class::cast).findFirst(); + final Optional evt = captor.getAllValues().stream() + .filter(VocabularyContentModifiedEvent.class::isInstance) + .map(VocabularyContentModifiedEvent.class::cast).findFirst(); assertTrue(evt.isPresent()); assertEquals(vocabulary.getUri(), evt.get().getVocabularyIri()); } diff --git a/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDaoTest.java b/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDaoTest.java index da7f09793..bb1887a51 100644 --- a/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDaoTest.java +++ b/src/test/java/cz/cvut/kbss/termit/persistence/dao/TermOccurrenceDaoTest.java @@ -33,7 +33,6 @@ import cz.cvut.kbss.termit.model.resource.Document; import cz.cvut.kbss.termit.model.resource.File; import cz.cvut.kbss.termit.model.selector.TextQuoteSelector; -import cz.cvut.kbss.termit.persistence.dao.util.ScheduledContextRemover; import cz.cvut.kbss.termit.util.Vocabulary; import org.eclipse.rdf4j.model.ValueFactory; import org.eclipse.rdf4j.model.vocabulary.RDFS; @@ -73,9 +72,6 @@ class TermOccurrenceDaoTest extends BaseDaoTestRunner { @Autowired private EntityManager em; - @Autowired - private ScheduledContextRemover contextRemover; - @Autowired private TermOccurrenceDao sut; @@ -270,7 +266,6 @@ void removeAllRemovesSuggestedAndConfirmedOccurrences() { }))); transactional(() -> { sut.removeAll(file); - contextRemover.runContextRemoval(); }); assertTrue(sut.findAllTargeting(file).isEmpty()); assertFalse(em.createNativeQuery("ASK { ?x a ?termOccurrence . }", Boolean.class).setParameter("termOccurrence", @@ -292,7 +287,6 @@ void removeAllRemovesAlsoOccurrenceTargets() { }))); transactional(() -> { sut.removeAll(file); - contextRemover.runContextRemoval(); }); assertFalse(em.createNativeQuery("ASK { ?x a ?target . }", Boolean.class).setParameter("target", diff --git a/src/test/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDaoTest.java b/src/test/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDaoTest.java index 5655a6011..23b72777c 100644 --- a/src/test/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDaoTest.java +++ b/src/test/java/cz/cvut/kbss/termit/persistence/dao/VocabularyDaoTest.java @@ -334,7 +334,7 @@ void getTransitivelyImportedVocabulariesReturnsAllImportedVocabulariesForVocabul em.persist(transitiveVocabulary, descriptorFactory.vocabularyDescriptor(transitiveVocabulary)); }); - final Collection result = sut.getTransitivelyImportedVocabularies(subjectVocabulary); + final Collection result = sut.getTransitivelyImportedVocabularies(subjectVocabulary.getUri()); assertEquals(3, result.size()); assertTrue(result.contains(importedVocabularyOne.getUri())); assertTrue(result.contains(importedVocabularyTwo.getUri())); @@ -372,7 +372,7 @@ void persistPublishesAssetPersistEvent() { transactional(() -> sut.persist(voc)); final ArgumentCaptor captor = ArgumentCaptor.forClass(ApplicationEvent.class); - verify(eventPublisher).publishEvent(captor.capture()); + verify(eventPublisher, atLeastOnce()).publishEvent(captor.capture()); final Optional evt = captor.getAllValues().stream() .filter(AssetPersistEvent.class::isInstance) .map(AssetPersistEvent.class::cast).findFirst(); @@ -767,7 +767,7 @@ void removePublishesEventAndDropsGraph() { VocabularyWillBeRemovedEvent event = eventCaptor.getValue(); assertNotNull(event); - assertEquals(event.getVocabulary(), vocabulary.getUri()); + assertEquals(event.getVocabularyIri(), vocabulary.getUri()); assertFalse(em.createNativeQuery("ASK WHERE{ GRAPH ?vocabulary { ?s ?p ?o }}", Boolean.class) .setParameter("vocabulary", vocabulary.getUri()) diff --git a/src/test/java/cz/cvut/kbss/termit/persistence/dao/util/ScheduledContextRemoverTest.java b/src/test/java/cz/cvut/kbss/termit/persistence/dao/util/ScheduledContextRemoverTest.java deleted file mode 100644 index cb2a78a92..000000000 --- a/src/test/java/cz/cvut/kbss/termit/persistence/dao/util/ScheduledContextRemoverTest.java +++ /dev/null @@ -1,48 +0,0 @@ -package cz.cvut.kbss.termit.persistence.dao.util; - -import cz.cvut.kbss.jopa.model.EntityManager; -import cz.cvut.kbss.jopa.vocabulary.RDFS; -import cz.cvut.kbss.termit.environment.Generator; -import cz.cvut.kbss.termit.persistence.dao.BaseDaoTestRunner; -import org.junit.jupiter.api.Test; -import org.springframework.beans.factory.annotation.Autowired; - -import java.net.URI; -import java.util.HashSet; -import java.util.Set; - -import static org.junit.jupiter.api.Assertions.assertFalse; - -class ScheduledContextRemoverTest extends BaseDaoTestRunner { - - @Autowired - private EntityManager em; - - @Autowired - private ScheduledContextRemover sut; - - @Test - void runContextRemovalDropsContextsRegisteredForRemoval() { - final Set graphs = generateGraphs(); - graphs.forEach(sut::scheduleForRemoval); - - sut.runContextRemoval(); - graphs.forEach(g -> assertFalse( - em.createNativeQuery("ASK { ?g ?y ?z . }", Boolean.class).setParameter("g", g).getSingleResult())); - } - - private Set generateGraphs() { - final Set result = new HashSet<>(); - transactional(() -> { - for (int i = 0; i < 5; i++) { - final URI graphUri = Generator.generateUri(); - em.createNativeQuery("INSERT DATA { GRAPH ?g { ?g a ?type } }", Void.class) - .setParameter("g", graphUri) - .setParameter("type", URI.create(RDFS.RESOURCE)) - .executeUpdate(); - result.add(graphUri); - } - }); - return result; - } -} diff --git a/src/test/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidatorTest.java b/src/test/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidatorTest.java index 8198751a2..b8cc2e42c 100644 --- a/src/test/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidatorTest.java +++ b/src/test/java/cz/cvut/kbss/termit/persistence/validation/ResultCachingValidatorTest.java @@ -18,8 +18,10 @@ package cz.cvut.kbss.termit.persistence.validation; import cz.cvut.kbss.termit.environment.Generator; -import cz.cvut.kbss.termit.event.VocabularyContentModified; +import cz.cvut.kbss.termit.event.VocabularyContentModifiedEvent; +import cz.cvut.kbss.termit.model.Term; import cz.cvut.kbss.termit.model.validation.ValidationResult; +import cz.cvut.kbss.termit.util.throttle.ThrottledFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -27,16 +29,21 @@ import org.mockito.junit.jupiter.MockitoExtension; import java.net.URI; +import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Set; +import static cz.cvut.kbss.termit.util.throttle.TestFutureRunner.runFuture; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.anyCollection; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; @ExtendWith(MockitoExtension.class) @@ -47,42 +54,54 @@ class ResultCachingValidatorTest { private ResultCachingValidator sut; + private URI vocabulary; + + private ValidationResult validationResult; + @BeforeEach void setUp() { this.sut = spy(new ResultCachingValidator()); when(sut.getValidator()).thenReturn(validator); + + vocabulary = Generator.generateUri(); + Term term = Generator.generateTermWithId(vocabulary); + validationResult = new ValidationResult().setTermUri(term.getUri()); } @Test - void invokesInternalValidatorWhenNoResultsAreCached() { - final List results = Collections.singletonList(new ValidationResult()); - when(validator.validate(anyCollection())).thenReturn(results); - final Set vocabularies = Collections.singleton(Generator.generateUri()); - final List result = sut.validate(vocabularies); + void invokesInternalValidatorWhenNoResultsAreCached() throws Exception { + final List results = Collections.singletonList(validationResult); + when(validator.validate(any(), anyCollection())).thenReturn(ThrottledFuture.done(results)); + final Set vocabularies = Collections.singleton(vocabulary); + final Collection result = runFuture(sut.validate(vocabulary, vocabularies)); assertEquals(results, result); - verify(validator).validate(vocabularies); + verify(validator).validate(vocabulary, vocabularies); } @Test - void returnsCachedResultsWhenArgumentsMatch() { - final List results = Collections.singletonList(new ValidationResult()); - when(validator.validate(anyCollection())).thenReturn(results); - final Set vocabularies = Collections.singleton(Generator.generateUri()); - final List resultOne = sut.validate(vocabularies); - final List resultTwo = sut.validate(vocabularies); - assertEquals(resultOne, resultTwo); - verify(validator).validate(vocabularies); + void returnsCachedResultsWhenArgumentsMatch() throws Exception { + final List results = Collections.singletonList(validationResult); + when(validator.validate(any(), anyCollection())).thenReturn(ThrottledFuture.done(results)); + final Set vocabularies = Collections.singleton(vocabulary); + final Collection resultOne = runFuture(sut.validate(vocabulary, vocabularies)); + verify(validator).validate(vocabulary, vocabularies); + final Collection resultTwo = runFuture(sut.validate(vocabulary, vocabularies)); + verifyNoMoreInteractions(validator); + assertIterableEquals(resultOne, resultTwo); + assertSame(results, resultOne); } @Test - void evictCacheClearsCachedValidationResults() { - final List results = Collections.singletonList(new ValidationResult()); - when(validator.validate(anyCollection())).thenReturn(results); - final Set vocabularies = Collections.singleton(Generator.generateUri()); - final List resultOne = sut.validate(vocabularies); - sut.evictCache(new VocabularyContentModified(this, null)); - final List resultTwo = sut.validate(vocabularies); - verify(validator, times(2)).validate(vocabularies); - assertNotSame(resultOne, resultTwo); + void evictCacheClearsCachedValidationResults() throws Exception { + final List results = Collections.singletonList(validationResult); + when(validator.validate(any(), anyCollection())).thenReturn(ThrottledFuture.done(results)); + final Set vocabularies = Collections.singleton(vocabulary); + final Collection resultOne = runFuture(sut.validate(vocabulary, vocabularies)); + verify(validator).validate(vocabulary, vocabularies); + sut.markCacheDirty(new VocabularyContentModifiedEvent(this, vocabulary)); + final Collection resultTwo = runFuture(sut.validate(vocabulary, vocabularies)); + verify(validator, times(2)).validate(vocabulary, vocabularies); + assertEquals(resultOne, resultTwo); + assertSame(results, resultOne); } } diff --git a/src/test/java/cz/cvut/kbss/termit/persistence/validation/ValidatorTest.java b/src/test/java/cz/cvut/kbss/termit/persistence/validation/ValidatorTest.java index 092e68556..0a44e04e7 100644 --- a/src/test/java/cz/cvut/kbss/termit/persistence/validation/ValidatorTest.java +++ b/src/test/java/cz/cvut/kbss/termit/persistence/validation/ValidatorTest.java @@ -20,6 +20,8 @@ import cz.cvut.kbss.jopa.model.EntityManager; import cz.cvut.kbss.termit.environment.Environment; import cz.cvut.kbss.termit.environment.Generator; +import cz.cvut.kbss.termit.event.VocabularyValidationFinishedEvent; +import cz.cvut.kbss.termit.exception.TermItException; import cz.cvut.kbss.termit.model.Term; import cz.cvut.kbss.termit.model.User; import cz.cvut.kbss.termit.model.Vocabulary; @@ -31,12 +33,21 @@ import cz.cvut.kbss.termit.util.Constants; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationEvent; +import org.springframework.context.ApplicationEventPublisher; +import java.net.URI; +import java.util.Collection; import java.util.Collections; -import java.util.List; +import static cz.cvut.kbss.termit.util.throttle.TestFutureRunner.runFuture; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertIterableEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.verify; class ValidatorTest extends BaseDaoTestRunner { @@ -52,6 +63,9 @@ class ValidatorTest extends BaseDaoTestRunner { @Autowired private Configuration config; + @Mock + private ApplicationEventPublisher eventPublisher; + @BeforeEach void setUp() { final User author = Generator.generateUserWithId(); @@ -63,8 +77,13 @@ void setUp() { void validateUsesOverrideRulesToAllowI18n() { final Vocabulary vocabulary = generateVocabulary(); transactional(() -> { - final Validator sut = new Validator(em, vocabularyContextMapper, config); - final List result = sut.validate(Collections.singleton(vocabulary.getUri())); + final Validator sut = new Validator(em, vocabularyContextMapper, config, eventPublisher); + final Collection result; + try { + result = runFuture(sut.validate(vocabulary.getUri(), Collections.singleton(vocabulary.getUri()))); + } catch (Exception e) { + throw new TermItException(e); + } assertTrue(result.stream().noneMatch( vr -> vr.getMessage().get("en").contains("The term does not have a preferred label in Czech"))); assertTrue(result.stream().noneMatch( @@ -74,6 +93,33 @@ void validateUsesOverrideRulesToAllowI18n() { }); } + /** + * Validation is a heavy and long-running task; validator must publish event signalizing validation end + * allowing other components to react on the result. + */ + @Test + void publishesVocabularyValidationFinishedEventAfterValidation() { + final Vocabulary vocabulary = generateVocabulary(); + transactional(() -> { + final Validator sut = new Validator(em, vocabularyContextMapper, config, eventPublisher); + final Collection iris = Collections.singleton(vocabulary.getUri()); + final Collection result; + try { + result = runFuture(sut.validate(vocabulary.getUri(), iris)); + } catch (Exception e) { + throw new TermItException(e); + } + + ArgumentCaptor eventCaptor = ArgumentCaptor.forClass(ApplicationEvent.class); + verify(eventPublisher).publishEvent(eventCaptor.capture()); + final ApplicationEvent event = eventCaptor.getValue(); + assertInstanceOf(VocabularyValidationFinishedEvent.class, event); + final VocabularyValidationFinishedEvent finished = (VocabularyValidationFinishedEvent) event; + assertIterableEquals(result, finished.getValidationResults()); + assertIterableEquals(iris, finished.getVocabularyIris()); + }); + } + private Vocabulary generateVocabulary() { final Vocabulary vocabulary = Generator.generateVocabularyWithId(); final Term term = Generator.generateTermWithId(vocabulary.getUri()); diff --git a/src/test/java/cz/cvut/kbss/termit/rest/ResourceControllerTest.java b/src/test/java/cz/cvut/kbss/termit/rest/ResourceControllerTest.java index 1b2021a72..bd50b7258 100644 --- a/src/test/java/cz/cvut/kbss/termit/rest/ResourceControllerTest.java +++ b/src/test/java/cz/cvut/kbss/termit/rest/ResourceControllerTest.java @@ -158,8 +158,7 @@ void getContentReturnsContentOfRequestedFile() throws Exception { final java.io.File content = createTemporaryHtmlFile(); when(resourceServiceMock.getContent(eq(file), any(ResourceRetrievalSpecification.class))) .thenReturn(new TypeAwareFileSystemResource(content, MediaType.TEXT_HTML_VALUE)); - final MvcResult mvcResult = mockMvc - .perform(get(PATH + "/" + FILE_NAME + "/content")) + final MvcResult mvcResult = mockMvc.perform(get(PATH + "/" + FILE_NAME + "/content")) .andExpect(status().isOk()).andReturn(); final String resultContent = mvcResult.getResponse().getContentAsString(); assertEquals(HTML_CONTENT, resultContent); @@ -379,8 +378,8 @@ void getContentSupportsReturningContentAsAttachment() throws Exception { final java.io.File content = createTemporaryHtmlFile(); when(resourceServiceMock.getContent(eq(file), any(ResourceRetrievalSpecification.class))) .thenReturn(new TypeAwareFileSystemResource(content, MediaType.TEXT_HTML_VALUE)); - final MvcResult mvcResult = mockMvc - .perform(get(PATH + "/" + FILE_NAME + "/content").param("attachment", Boolean.toString(true))) + final MvcResult mvcResult = mockMvc.perform( + get(PATH + "/" + FILE_NAME + "/content").param("attachment", Boolean.toString(true))) .andExpect(status().isOk()).andReturn(); assertThat(mvcResult.getResponse().getHeader(HttpHeaders.CONTENT_DISPOSITION), containsString("attachment")); assertThat(mvcResult.getResponse().getHeader(HttpHeaders.CONTENT_DISPOSITION), @@ -420,8 +419,7 @@ void getContentWithTimestampReturnsContentOfRequestedFileAtSpecifiedTimestamp() final Instant at = Utils.timestamp().truncatedTo(ChronoUnit.SECONDS); when(resourceServiceMock.getContent(eq(file), any(ResourceRetrievalSpecification.class))) .thenReturn(new TypeAwareFileSystemResource(content, MediaType.TEXT_HTML_VALUE)); - final MvcResult mvcResult = mockMvc - .perform(get(PATH + "/" + FILE_NAME + "/content") + final MvcResult mvcResult = mockMvc.perform(get(PATH + "/" + FILE_NAME + "/content") .queryParam("at", Constants.TIMESTAMP_FORMATTER.format(at))) .andExpect(status().isOk()).andReturn(); final String resultContent = mvcResult.getResponse().getContentAsString(); @@ -454,8 +452,7 @@ void getContentWithoutUnconfirmedOccurrencesReturnsContentOfRequestedFileAtWitho final java.io.File content = createTemporaryHtmlFile(); when(resourceServiceMock.getContent(eq(file), any(ResourceRetrievalSpecification.class))) .thenReturn(new TypeAwareFileSystemResource(content, MediaType.TEXT_HTML_VALUE)); - final MvcResult mvcResult = mockMvc - .perform(get(PATH + "/" + FILE_NAME + "/content") + final MvcResult mvcResult = mockMvc.perform(get(PATH + "/" + FILE_NAME + "/content") .queryParam("withoutUnconfirmedOccurrences", Boolean.toString(true))) .andExpect(status().isOk()).andReturn(); final String resultContent = mvcResult.getResponse().getContentAsString(); diff --git a/src/test/java/cz/cvut/kbss/termit/service/business/TermServiceTest.java b/src/test/java/cz/cvut/kbss/termit/service/business/TermServiceTest.java index 2a4781da5..5ea15780f 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/business/TermServiceTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/business/TermServiceTest.java @@ -214,6 +214,7 @@ void persistUsesRepositoryServiceToPersistTermAsChildOfSpecifiedParentTerm() { void updateUsesRepositoryServiceToUpdateTerm() { final Term term = generateTermWithId(); when(termRepositoryService.findRequired(term.getUri())).thenReturn(term); + when(termRepositoryService.update(term)).thenReturn(term); sut.update(term); verify(termRepositoryService).update(term); } @@ -311,26 +312,26 @@ void removeRemovesTermViaRepositoryService() { void runTextAnalysisInvokesTextAnalysisOnSpecifiedTerm() { when(vocabularyContextMapper.getVocabularyContext(vocabulary.getUri())).thenReturn(vocabulary.getUri()); final Term toAnalyze = generateTermWithId(); + when(termRepositoryService.findRequired(toAnalyze.getUri())).thenReturn(toAnalyze); sut.analyzeTermDefinition(toAnalyze, vocabulary.getUri()); verify(textAnalysisService).analyzeTermDefinition(toAnalyze, vocabulary.getUri()); } @Test - void persistChildInvokesTextAnalysisOnPersistedChildTerm() { - when(vocabularyContextMapper.getVocabularyContext(vocabulary.getUri())).thenReturn(vocabulary.getUri()); + void persistChildInvokesTextAnalysisOnAllTermsInVocabulary() { + when(vocabularyService.findRequired(vocabulary.getUri())).thenReturn(vocabulary); final Term parent = generateTermWithId(); parent.setVocabulary(vocabulary.getUri()); final Term childToPersist = generateTermWithId(); sut.persistChild(childToPersist, parent); - verify(textAnalysisService).analyzeTermDefinition(childToPersist, parent.getVocabulary()); + verify(vocabularyService).runTextAnalysisOnAllTerms(vocabulary); } @Test - void persistRootInvokesTextAnalysisOnPersistedRootTerm() { - when(vocabularyContextMapper.getVocabularyContext(vocabulary.getUri())).thenReturn(vocabulary.getUri()); + void persistRootInvokesTextAnalysisOnAllTermsInVocabulary() { final Term toPersist = generateTermWithId(); sut.persistRoot(toPersist, vocabulary); - verify(textAnalysisService).analyzeTermDefinition(toPersist, vocabulary.getUri()); + verify(vocabularyService).runTextAnalysisOnAllTerms(vocabulary); } @Test @@ -338,14 +339,31 @@ void updateInvokesTextAnalysisOnUpdatedTerm() { when(vocabularyContextMapper.getVocabularyContext(vocabulary.getUri())).thenReturn(vocabulary.getUri()); final Term original = generateTermWithId(vocabulary.getUri()); final Term toUpdate = new Term(original.getUri()); + toUpdate.setLabel(original.getLabel()); final String newDefinition = "This term has acquired a new definition"; toUpdate.setVocabulary(vocabulary.getUri()); when(termRepositoryService.findRequired(toUpdate.getUri())).thenReturn(original); toUpdate.setDefinition(MultilingualString.create(newDefinition, Environment.LANGUAGE)); + when(termRepositoryService.update(toUpdate)).thenReturn(toUpdate); sut.update(toUpdate); verify(textAnalysisService).analyzeTermDefinition(toUpdate, toUpdate.getVocabulary()); } + @Test + void updateOfTermLabelInvokesTextAnalysisOnAllTermsInVocabulary() { + final Term original = generateTermWithId(vocabulary.getUri()); + final Term toUpdate = new Term(original.getUri()); + toUpdate.setLabel(MultilingualString.create("new Label", Environment.LANGUAGE)); + final String newDefinition = "This term has acquired a new definition"; + toUpdate.setVocabulary(vocabulary.getUri()); + toUpdate.setDefinition(MultilingualString.create(newDefinition, Environment.LANGUAGE)); + when(termRepositoryService.findRequired(toUpdate.getUri())).thenReturn(original); + when(termRepositoryService.update(toUpdate)).thenReturn(toUpdate); + when(vocabularyService.getReference(vocabulary.getUri())).thenReturn(vocabulary); + sut.update(toUpdate); + verify(vocabularyService).runTextAnalysisOnAllTerms(vocabulary); + } + @Test void setTermDefinitionSourceSetsTermOnDefinitionAndPersistsIt() { final Term term = Generator.generateTermWithId(); @@ -456,7 +474,6 @@ void persistChildInvokesTextAnalysisOnAllTermsInParentTermVocabulary() { parent.setVocabulary(vocabulary.getUri()); final Term childToPersist = generateTermWithId(); when(vocabularyService.findRequired(vocabulary.getUri())).thenReturn(vocabulary); - when(vocabularyContextMapper.getVocabularyContext(vocabulary.getUri())).thenReturn(vocabulary.getUri()); sut.persistChild(childToPersist, parent); final InOrder inOrder = inOrder(termRepositoryService, vocabularyService); @@ -474,6 +491,7 @@ void updateInvokesTextAnalysisOnAllTermsInTermsVocabularyWhenLabelHasChanged() { update.setVocabulary(vocabulary.getUri()); when(termRepositoryService.findRequired(original.getUri())).thenReturn(original); when(vocabularyService.getReference(vocabulary.getUri())).thenReturn(vocabulary); + when(termRepositoryService.update(update)).thenReturn(update); update.getLabel().set(Environment.LANGUAGE, "updatedLabel"); sut.update(update); @@ -637,6 +655,7 @@ void updateVerifiesThatStateExistsTermState() { update.setDescription(new MultilingualString(original.getDescription().getValue())); update.setVocabulary(vocabulary.getUri()); update.setState(Generator.randomItem(Generator.TERM_STATES)); + when(termRepositoryService.update(update)).thenReturn(update); sut.update(update); final InOrder inOrder = inOrder(languageService, termRepositoryService); inOrder.verify(languageService).verifyStateExists(update.getState()); diff --git a/src/test/java/cz/cvut/kbss/termit/service/repository/VocabularyServiceTest.java b/src/test/java/cz/cvut/kbss/termit/service/business/VocabularyServiceTest.java similarity index 92% rename from src/test/java/cz/cvut/kbss/termit/service/repository/VocabularyServiceTest.java rename to src/test/java/cz/cvut/kbss/termit/service/business/VocabularyServiceTest.java index 151796815..277c26714 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/repository/VocabularyServiceTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/business/VocabularyServiceTest.java @@ -15,7 +15,7 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ -package cz.cvut.kbss.termit.service.repository; +package cz.cvut.kbss.termit.service.business; import cz.cvut.kbss.termit.dto.Snapshot; import cz.cvut.kbss.termit.dto.acl.AccessControlListDto; @@ -23,6 +23,7 @@ import cz.cvut.kbss.termit.dto.listing.VocabularyDto; import cz.cvut.kbss.termit.environment.Environment; import cz.cvut.kbss.termit.environment.Generator; +import cz.cvut.kbss.termit.event.VocabularyContentModifiedEvent; import cz.cvut.kbss.termit.event.VocabularyCreatedEvent; import cz.cvut.kbss.termit.exception.NotFoundException; import cz.cvut.kbss.termit.model.Term; @@ -34,10 +35,9 @@ import cz.cvut.kbss.termit.model.util.HasIdentifier; import cz.cvut.kbss.termit.persistence.context.VocabularyContextMapper; import cz.cvut.kbss.termit.persistence.snapshot.SnapshotCreator; -import cz.cvut.kbss.termit.service.business.AccessControlListService; -import cz.cvut.kbss.termit.service.business.VocabularyService; -import cz.cvut.kbss.termit.service.business.async.AsyncTermService; import cz.cvut.kbss.termit.service.export.ExportFormat; +import cz.cvut.kbss.termit.service.repository.ChangeRecordService; +import cz.cvut.kbss.termit.service.repository.VocabularyRepositoryService; import cz.cvut.kbss.termit.service.security.authorization.VocabularyAuthorizationService; import cz.cvut.kbss.termit.util.Configuration; import cz.cvut.kbss.termit.util.TypeAwareResource; @@ -49,6 +49,7 @@ import org.mockito.InOrder; import org.mockito.InjectMocks; import org.mockito.Mock; +import org.mockito.Spy; import org.mockito.junit.jupiter.MockitoExtension; import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationEvent; @@ -61,7 +62,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import static cz.cvut.kbss.termit.environment.Environment.termsToDtos; @@ -72,6 +72,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.any; import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -81,7 +82,7 @@ class VocabularyServiceTest { @Mock - private AsyncTermService termService; + TermService termService; @Mock private VocabularyRepositoryService repositoryService; @@ -104,6 +105,7 @@ class VocabularyServiceTest { @Mock private ApplicationContext appContext; + @Spy @InjectMocks private VocabularyService sut; @@ -121,9 +123,10 @@ void runTextAnalysisOnAllTermsInvokesTextAnalysisOnAllTermsInVocabulary() { when(termService.findAll(vocabulary)).thenReturn(terms); when(contextMapper.getVocabularyContext(vocabulary.getUri())).thenReturn(vocabulary.getUri()); when(repositoryService.getTransitivelyImportedVocabularies(vocabulary)).thenReturn(Collections.emptyList()); + when(repositoryService.findRequired(vocabulary.getUri())).thenReturn(vocabulary); sut.runTextAnalysisOnAllTerms(vocabulary); - verify(termService).asyncAnalyzeTermDefinitions(Map.of(termOne, vocabulary.getUri(), - termTwo, vocabulary.getUri())); + verify(termService).analyzeTermDefinition(termOne, vocabulary.getUri()); + verify(termService).analyzeTermDefinition(termTwo, vocabulary.getUri()); } @Test @@ -136,7 +139,8 @@ void runTextAnalysisOnAllTermsInvokesTextAnalysisOnAllVocabularies() { when(contextMapper.getVocabularyContext(v.getUri())).thenReturn(v.getUri()); when(termService.findAll(v)).thenReturn(Collections.singletonList(new TermDto(term))); sut.runTextAnalysisOnAllVocabularies(); - verify(termService).asyncAnalyzeTermDefinitions(Map.of(term, v.getUri())); + + verify(termService).analyzeTermDefinition(term, v.getUri()); } @Test @@ -266,7 +270,7 @@ void updateAccessControlLevelRetrievesACLForVocabularyAndUpdatesSpecifiedRecord( @Test void persistCreatesAccessControlListAndSetsItOnVocabularyInstance() { final AccessControlList acl = Generator.generateAccessControlList(true); - final Vocabulary toPersist = Generator.generateVocabulary(); + final Vocabulary toPersist = Generator.generateVocabularyWithId(); when(aclService.createFor(toPersist)).thenReturn(acl); sut.persist(toPersist); @@ -370,9 +374,10 @@ void importNewVocabularyPublishesVocabularyCreatedEvent() { sut.importVocabulary(false, fileToImport); final ArgumentCaptor captor = ArgumentCaptor.forClass(ApplicationEvent.class); - verify(eventPublisher).publishEvent(captor.capture()); - assertInstanceOf(VocabularyCreatedEvent.class, captor.getValue()); - assertEquals(persisted, captor.getValue().getSource()); + verify(eventPublisher, atLeastOnce()).publishEvent(captor.capture()); + Optional event = captor.getAllValues().stream().filter(e -> e instanceof VocabularyCreatedEvent).map(e->(VocabularyCreatedEvent)e).findAny(); + assertTrue(event.isPresent()); + assertEquals(persisted.getUri(), event.get().getVocabularyIri()); } @Test @@ -386,4 +391,15 @@ void getExcelTemplateFileReturnsResourceRepresentingExcelTemplateFile() throws E final File expectedFile = new File(getClass().getClassLoader().getResource("template/termit-import.xlsx").toURI()); assertEquals(expectedFile, result.getFile()); } + + /** + * The goal for this is to get the results cached and do not force users to wait for validation + * when they request it. + */ + @Test + void publishingVocabularyContentModifiedEventTriggersContentsValidation() { + final VocabularyContentModifiedEvent event = new VocabularyContentModifiedEvent(this, Generator.generateUri()); + sut.onVocabularyContentModified(event); + verify(repositoryService).validateContents(event.getVocabularyIri()); + } } diff --git a/src/test/java/cz/cvut/kbss/termit/service/document/AnnotationGeneratorTest.java b/src/test/java/cz/cvut/kbss/termit/service/document/AnnotationGeneratorTest.java index e94fca73f..6e49852a5 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/document/AnnotationGeneratorTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/document/AnnotationGeneratorTest.java @@ -193,7 +193,7 @@ void generateAnnotationsThrowsAnnotationGenerationExceptionForUnsupportedFileTyp try (final InputStream content = loadFile("application.yml")) { file.setLabel(generateIncompatibleFile()); final AnnotationGenerationException ex = assertThrows(AnnotationGenerationException.class, - () -> sut.generateAnnotations(content, file)); + () -> sut.generateAnnotations(content, file)); assertThat(ex.getMessage(), containsString("Unsupported type of file")); } } @@ -224,7 +224,7 @@ void generateAnnotationsResolvesOverlappingAnnotations() throws Exception { void generateAnnotationsThrowsAnnotationGenerationExceptionForUnknownTermIdentifier() throws Exception { final InputStream content = setUnknownTermIdentifier(loadFile("data/rdfa-simple.html")); final AnnotationGenerationException ex = assertThrows(AnnotationGenerationException.class, - () -> sut.generateAnnotations(content, file)); + () -> sut.generateAnnotations(content, file)); assertThat(ex.getMessage(), containsString("Term with id ")); assertThat(ex.getMessage(), containsString("not found")); } diff --git a/src/test/java/cz/cvut/kbss/termit/service/document/SynchronousTermOccurrenceSaverTest.java b/src/test/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaverTest.java similarity index 93% rename from src/test/java/cz/cvut/kbss/termit/service/document/SynchronousTermOccurrenceSaverTest.java rename to src/test/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaverTest.java index c335e27b5..9f02b3f8f 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/document/SynchronousTermOccurrenceSaverTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/document/TermOccurrenceSaverTest.java @@ -17,13 +17,13 @@ import static org.mockito.Mockito.inOrder; @ExtendWith(MockitoExtension.class) -class SynchronousTermOccurrenceSaverTest { +class TermOccurrenceSaverTest { @Mock private TermOccurrenceDao occurrenceDao; @InjectMocks - private SynchronousTermOccurrenceSaver sut; + private TermOccurrenceSaver sut; @Test void saveOccurrencesRemovesAllExistingOccurrencesAndPersistsSpecifiedOnes() { diff --git a/src/test/java/cz/cvut/kbss/termit/service/document/TextAnalysisServiceTest.java b/src/test/java/cz/cvut/kbss/termit/service/document/TextAnalysisServiceTest.java index 560d6ddd0..aa431671e 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/document/TextAnalysisServiceTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/document/TextAnalysisServiceTest.java @@ -17,12 +17,15 @@ */ package cz.cvut.kbss.termit.service.document; +import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import cz.cvut.kbss.jopa.model.MultilingualString; import cz.cvut.kbss.termit.dto.TextAnalysisInput; import cz.cvut.kbss.termit.environment.Environment; import cz.cvut.kbss.termit.environment.Generator; import cz.cvut.kbss.termit.environment.PropertyMockingApplicationContextInitializer; +import cz.cvut.kbss.termit.event.FileTextAnalysisFinishedEvent; +import cz.cvut.kbss.termit.event.TermDefinitionTextAnalysisFinishedEvent; import cz.cvut.kbss.termit.exception.NotFoundException; import cz.cvut.kbss.termit.exception.WebServiceIntegrationException; import cz.cvut.kbss.termit.model.Term; @@ -44,6 +47,7 @@ import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.ApplicationEventPublisher; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; @@ -67,6 +71,7 @@ import static org.hamcrest.Matchers.containsString; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.any; @@ -100,6 +105,9 @@ class TextAnalysisServiceTest extends BaseServiceTestRunner { @Autowired private Configuration config; + @Mock + private ApplicationEventPublisher eventPublisher; + @Autowired private DocumentManager documentManager; @@ -128,12 +136,14 @@ void setUp() throws Exception { this.file = new File(); file.setUri(Generator.generateUri()); file.setLabel(FILE_NAME); + file.setDocument(Generator.generateDocumentWithId()); + file.getDocument().setVocabulary(vocabulary.getUri()); generateFile(); this.documentManagerSpy = spy(documentManager); doCallRealMethod().when(documentManagerSpy).loadFileContent(any()); doNothing().when(documentManagerSpy).createBackup(any()); this.sut = new TextAnalysisService(restTemplate, config, documentManagerSpy, annotationGeneratorMock, - textAnalysisRecordDao); + textAnalysisRecordDao, eventPublisher); } @Test @@ -149,8 +159,7 @@ private void generateFile() throws IOException { final java.io.File dir = Files.createTempDirectory("termit").toFile(); dir.deleteOnExit(); config.getFile().setStorage(dir.getAbsolutePath()); - final java.io.File docDir = new java.io.File(dir.getAbsolutePath() + java.io.File.separator + - file.getDirectoryName()); + final java.io.File docDir = new java.io.File(dir.getAbsolutePath() + java.io.File.separator + file.getDirectoryName()); Files.createDirectory(docDir.toPath()); docDir.deleteOnExit(); final java.io.File content = new java.io.File( @@ -407,4 +416,38 @@ void analyzeTermDefinitionInvokesTextAnalysisServiceWithVocabularyRepositoryUser sut.analyzeTermDefinition(term, vocabulary.getUri()); mockServer.verify(); } + + @Test + void analyzeFilePublishesAnalysisFinishedEvent() { + mockServer.expect(requestTo(config.getTextAnalysis().getUrl())) + .andExpect(method(HttpMethod.POST)).andExpect(content().string(containsString(CONTENT))) + .andRespond(withSuccess(CONTENT, MediaType.APPLICATION_XML)); + sut.analyzeFile(file, Collections.singleton(vocabulary.getUri())); + + ArgumentCaptor eventCaptor = ArgumentCaptor.forClass(FileTextAnalysisFinishedEvent.class); + verify(eventPublisher).publishEvent(eventCaptor.capture()); + assertNotNull(eventCaptor.getValue()); + assertEquals(file.getUri(), eventCaptor.getValue().getFileUri()); + assertEquals(vocabulary.getUri(), eventCaptor.getValue().getVocabularyIri()); + } + + @Test + void analyzeTermDefinitionPublishesAnalysisFinishedEvent() throws JsonProcessingException { + final Term term = Generator.generateTermWithId(); + term.setVocabulary(vocabulary.getUri()); + final TextAnalysisInput input = textAnalysisInput(); + input.setContent(term.getDefinition().get(Environment.LANGUAGE)); + mockServer.expect(requestTo(config.getTextAnalysis().getUrl())) + .andExpect(method(HttpMethod.POST)) + .andExpect(content().string(objectMapper.writeValueAsString(input))) + .andRespond(withSuccess(CONTENT, MediaType.APPLICATION_XML)); + + sut.analyzeTermDefinition(term, vocabulary.getUri()); + + ArgumentCaptor eventCaptor = ArgumentCaptor.forClass(TermDefinitionTextAnalysisFinishedEvent.class); + verify(eventPublisher).publishEvent(eventCaptor.capture()); + assertNotNull(eventCaptor.getValue()); + assertEquals(term.getUri(), eventCaptor.getValue().getTermUri()); + assertEquals(vocabulary.getUri(), eventCaptor.getValue().getVocabularyIri()); + } } diff --git a/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java b/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java index bf85ea85e..fdcfd3ed8 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/document/html/HtmlTermOccurrenceResolverTest.java @@ -47,6 +47,7 @@ import java.util.List; import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.endsWith; @@ -116,8 +117,7 @@ void findTermOccurrencesExtractsAlsoScoreFromRdfa() { final File file = initFile(); final InputStream is = cz.cvut.kbss.termit.environment.Environment.loadFile("data/rdfa-simple.html"); sut.parseContent(is, file); - final List result = sut.findTermOccurrences(); - result.forEach(to -> { + sut.findTermOccurrences(to -> { assertNotNull(to.getScore()); assertThat(to.getScore(), greaterThan(0.0)); }); @@ -136,8 +136,7 @@ void findTermOccurrencesHandlesRdfaWithoutScore() { final File file = initFile(); final InputStream is = cz.cvut.kbss.termit.environment.Environment.loadFile("data/rdfa-simple-no-score.html"); sut.parseContent(is, file); - final List result = sut.findTermOccurrences(); - result.forEach(to -> assertNull(to.getScore())); + sut.findTermOccurrences(to -> assertNull(to.getScore())); } @Test @@ -147,8 +146,7 @@ void findTermOccurrencesHandlesInvalidScoreInRdfa() { final InputStream is = cz.cvut.kbss.termit.environment.Environment .loadFile("data/rdfa-simple-invalid-score.html"); sut.parseContent(is, file); - final List result = sut.findTermOccurrences(); - result.forEach(to -> assertNull(to.getScore())); + sut.findTermOccurrences(to -> assertNull(to.getScore())); } @Test @@ -162,10 +160,14 @@ void findTermOccurrencesGeneratesOccurrenceUriBasedOnAnnotationAbout() { final File file = initFile(); final InputStream is = cz.cvut.kbss.termit.environment.Environment.loadFile("data/rdfa-simple.html"); sut.parseContent(is, file); - final List result = sut.findTermOccurrences(); - assertEquals(1, result.size()); - assertThat(result.get(0).getUri().toString(), startsWith(file.getUri() + "/" + TermOccurrence.CONTEXT_SUFFIX)); - assertThat(result.get(0).getUri().toString(), endsWith("1")); + AtomicInteger resultSize = new AtomicInteger(0); + sut.findTermOccurrences(to -> { + resultSize.incrementAndGet(); + assertThat(to.getUri().toString(), startsWith(file.getUri() + "/" + TermOccurrence.CONTEXT_SUFFIX)); + assertThat(to.getUri().toString(), endsWith("1")); + }); + assertEquals(1,resultSize.get()); + } @Test @@ -174,8 +176,7 @@ void findTermOccurrencesMarksOccurrencesAsSuggested() { final File file = initFile(); final InputStream is = cz.cvut.kbss.termit.environment.Environment.loadFile("data/rdfa-simple.html"); sut.parseContent(is, file); - final List result = sut.findTermOccurrences(); - result.forEach(to -> assertThat(to.getTypes(), hasItem(Vocabulary.s_c_navrzeny_vyskyt_termu))); + sut.findTermOccurrences(to -> assertThat(to.getTypes(), hasItem(Vocabulary.s_c_navrzeny_vyskyt_termu))); } @Test @@ -191,9 +192,12 @@ void findTermOccurrencesSetsFoundOccurrencesAsApprovedWhenCorrespondingExistingO sut.parseContent(is, file); sut.setExistingOccurrences(List.of(existing)); - final List result = sut.findTermOccurrences(); - assertEquals(1, result.size()); - assertThat(result.get(0).getTypes(), not(hasItem(Vocabulary.s_c_navrzeny_vyskyt_termu))); + AtomicInteger resultSize = new AtomicInteger(0); + sut.findTermOccurrences(to -> { + resultSize.incrementAndGet(); + assertThat(to.getTypes(), not(hasItem(Vocabulary.s_c_navrzeny_vyskyt_termu))); + }); + assertEquals(1, resultSize.get()); final org.jsoup.nodes.Document document = Jsoup.parse(sut.getContent(), StandardCharsets.UTF_8.name(), ""); final Elements annotations = document.select("span[about]"); assertEquals(1, annotations.size()); diff --git a/src/test/java/cz/cvut/kbss/termit/service/jmx/AppAdminBeanTest.java b/src/test/java/cz/cvut/kbss/termit/service/jmx/AppAdminBeanTest.java index e10e552e3..479d67a6b 100644 --- a/src/test/java/cz/cvut/kbss/termit/service/jmx/AppAdminBeanTest.java +++ b/src/test/java/cz/cvut/kbss/termit/service/jmx/AppAdminBeanTest.java @@ -19,7 +19,6 @@ import cz.cvut.kbss.termit.event.EvictCacheEvent; import cz.cvut.kbss.termit.event.RefreshLastModifiedEvent; -import cz.cvut.kbss.termit.event.VocabularyContentModified; import cz.cvut.kbss.termit.util.Configuration; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -62,12 +61,4 @@ void invalidateCachesPublishesRefreshLastModifiedEvent() { verify(eventPublisherMock, atLeastOnce()).publishEvent(captor.capture()); assertTrue(captor.getAllValues().stream().anyMatch(RefreshLastModifiedEvent.class::isInstance)); } - - @Test - void invalidateCachesPublishesVocabularyContentModifiedEventToForceEvictionOfVocabularyContentBasedCaches() { - sut.invalidateCaches(); - final ArgumentCaptor captor = ArgumentCaptor.forClass(ApplicationEvent.class); - verify(eventPublisherMock, atLeastOnce()).publishEvent(captor.capture()); - assertTrue(captor.getAllValues().stream().anyMatch(VocabularyContentModified.class::isInstance)); - } } diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/MockedMethodSignature.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/MockedMethodSignature.java new file mode 100644 index 000000000..310d99398 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/MockedMethodSignature.java @@ -0,0 +1,84 @@ +package cz.cvut.kbss.termit.util.throttle; + +import org.aspectj.lang.reflect.MethodSignature; + +import java.lang.reflect.Method; + +/** + * Allows construction of {@link MethodSignature} for testing purposes + */ +public class MockedMethodSignature implements MethodSignature { + + private final String methodName; + private Class returnType; + + private Class[] parameterTypes; + + private String[] parameterNames; + + public MockedMethodSignature(String methodName, Class returnType, Class[] parameterTypes, String[] parameterNames) { + this.methodName = methodName; + this.returnType = returnType; + this.parameterTypes = parameterTypes; + this.parameterNames = parameterNames; + } + + @Override + public Class getReturnType() { + return returnType; + } + + public void setReturnType(Class returnType) { + this.returnType = returnType; + } + + @Override + public Method getMethod() { + return null; + } + + @Override + public Class[] getParameterTypes() { + return parameterTypes; + } + + @Override + public String[] getParameterNames() { + return parameterNames; + } + + @Override + public Class[] getExceptionTypes() { + return new Class[0]; + } + + @Override + public String toShortString() { + return "shortMethodSignatureString" + methodName; + } + + @Override + public String toLongString() { + return "longMethodSignatureString" + methodName; + } + + @Override + public String getName() { + return methodName; + } + + @Override + public int getModifiers() { + return 0; + } + + @Override + public Class getDeclaringType() { + return null; + } + + @Override + public String getDeclaringTypeName() { + return ""; + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/MockedThrottle.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/MockedThrottle.java new file mode 100644 index 000000000..398b435ca --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/MockedThrottle.java @@ -0,0 +1,48 @@ +package cz.cvut.kbss.termit.util.throttle; + +import org.springframework.lang.NonNull; + +import java.lang.annotation.Annotation; + +/** + * Implementation of annotation interface allowing instancing for testing purposes + */ +public class MockedThrottle implements Throttle { + + private String value; + + private String group; + + public MockedThrottle(@NonNull String value, @NonNull String group) { + this.value = value; + this.group = group; + } + + @Override + public @NonNull String value() { + return value; + } + + @Override + public @NonNull String group() { + return group; + } + + @Override + public String name() { + return "NameOfMockedThrottle"+group+value; + } + + @Override + public Class annotationType() { + return Throttle.class; + } + + public void setValue(@NonNull String value) { + this.value = value; + } + + public void setGroup(@NonNull String group) { + this.group = group; + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/ScheduledFutureTask.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/ScheduledFutureTask.java new file mode 100644 index 000000000..f48c43dbd --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/ScheduledFutureTask.java @@ -0,0 +1,30 @@ +package cz.cvut.kbss.termit.util.throttle; + +import org.springframework.lang.NonNull; + +import java.util.concurrent.Callable; +import java.util.concurrent.Delayed; +import java.util.concurrent.FutureTask; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; + +public class ScheduledFutureTask extends FutureTask implements ScheduledFuture { + + public ScheduledFutureTask(@NonNull Callable callable) { + super(callable); + } + + public ScheduledFutureTask(@NonNull Runnable runnable, T result) { + super(runnable, result); + } + + @Override + public long getDelay(@NonNull TimeUnit unit) { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public int compareTo(@NonNull Delayed o) { + throw new UnsupportedOperationException("Not implemented"); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/TestFutureRunner.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/TestFutureRunner.java new file mode 100644 index 000000000..ce9b1e70c --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/TestFutureRunner.java @@ -0,0 +1,24 @@ +package cz.cvut.kbss.termit.util.throttle; + +import java.util.concurrent.ExecutionException; + +public class TestFutureRunner { + + private TestFutureRunner() { + throw new AssertionError(); + } + + /** + * Executes the task inside the future and returns its result. + * + * @implNote Note that this method is intended only for testing purposes. + */ + public static T runFuture(ThrottledFuture future) { + future.run(null); + try { + return future.get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectBeanTest.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectBeanTest.java new file mode 100644 index 000000000..987414093 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectBeanTest.java @@ -0,0 +1,72 @@ +package cz.cvut.kbss.termit.util.throttle; + +import cz.cvut.kbss.termit.util.longrunning.LongRunningTasksRegistry; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.boot.test.mock.mockito.SpyBean; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +import java.lang.reflect.Method; +import java.time.Instant; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +// intentionally not enabling test profile +@ExtendWith(MockitoExtension.class) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ThrottleAspectTestContextConfig.class}, + initializers = {ConfigDataApplicationContextInitializer.class}) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) +class ThrottleAspectBeanTest { + + @Autowired + ThreadPoolTaskScheduler longRunningTaskScheduler; + + @SpyBean + ThrottleAspect throttleAspect; + + @MockBean + LongRunningTasksRegistry longRunningTasksRegistry; + + @Autowired + ThrottledService throttledService; + + @BeforeEach + void beforeEach() { + reset(longRunningTaskScheduler); + when(longRunningTaskScheduler.schedule(any(Runnable.class), any(Instant.class))).then(invocation -> { + Runnable task = invocation.getArgument(0, Runnable.class); + return new ScheduledFutureTask<>(task, null); + }); + } + + @Test + void throttleAspectIsCreated() { + assertNotNull(throttleAspect); + } + + @Test + void aspectIsCalledWhenThrottleAnnotationIsPresent() throws Throwable { + throttledService.annotatedMethod(); + + final Method method = ThrottledService.class.getMethod("annotatedMethod"); + final Throttle annotation = method.getAnnotation(Throttle.class); + assertNotNull(annotation); + + verify(throttleAspect).throttleMethodCall(any(), eq(annotation)); + } + +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectTest.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectTest.java new file mode 100644 index 000000000..3200ecd2b --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectTest.java @@ -0,0 +1,914 @@ +package cz.cvut.kbss.termit.util.throttle; + +import com.vladsch.flexmark.util.collection.OrderedMap; +import cz.cvut.kbss.termit.exception.TermItException; +import cz.cvut.kbss.termit.exception.ThrottleAspectException; +import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTasksRegistry; +import org.aspectj.lang.ProceedingJoinPoint; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.springframework.expression.spel.SpelParseException; +import org.springframework.scheduling.TaskScheduler; +import org.springframework.scheduling.support.TaskUtils; +import org.springframework.transaction.annotation.Transactional; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; +import java.time.temporal.ChronoUnit; +import java.util.Map; +import java.util.NavigableMap; +import java.util.Optional; +import java.util.TreeMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import java.util.stream.Stream; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +class ThrottleAspectTest { + private static final long THREAD_JOIN_TIMEOUT_MILLIS = 60 * 1000; + private final Configuration configuration = new Configuration(); + + /** + * Throttled futures from {@link #sut} + */ + OrderedMap> throttledFutures; + + /** + * Last run map from {@link #sut} + */ + OrderedMap lastRun; + + /** + * Scheduled futures from {@link #sut} + */ + NavigableMap> scheduledFutures; + + /** + * Mocked task scheduler. + * Does not execute tasks automatically, + * they need to be executed with {@link #executeScheduledTasks()} + * Tasks are wrapped into {@link FutureTask} and saved to {@link #taskSchedulerTasks}. + */ + TaskScheduler taskScheduler; + + SynchronousTransactionExecutor transactionExecutor; + + /** + * Tasks that were submitted to {@link #taskScheduler} + * @see #beforeEach() + */ + OrderedMap taskSchedulerTasks; + + ThrottleAspect sut; + + MockedThrottle throttleA; + + MockedThrottle throttleB; + + MockedThrottle throttleC; + + /** + * Default mock:
+ * return type: primitive {@code void}
+ * parameters: {@link Object Object paramA}, {@link Object Object paramB}
+ */ + MockedMethodSignature signatureA; + + /** + * Default mock:
+ * return type: wrapped {@link Void}
+ * parameters: {@link Map Map<String,String> paramName}
+ */ + MockedMethodSignature signatureB; + + /** + * Default mock:
+ * return type: primitive {@code void}
+ * parameters: {@link Object Object paramA}, {@link Object Object paramB}
+ */ + MockedMethodSignature signatureC; + + /** + * Default mock: returning {@code null} on proceed, + * method called with two {@link Object} arguments + * @see #signatureA + */ + ProceedingJoinPoint joinPointA; + + /** + * Default mock: returning {@code null} on proceed, + * method called with one parameter ({@link Map Map<String,String>} with two entries) + * @see #signatureB + */ + ProceedingJoinPoint joinPointB; + + /** + * Default mock: returning {@code null} on proceed, + * method called with two {@link Object} parameters + * @see #signatureC + */ + ProceedingJoinPoint joinPointC; + + LongRunningTasksRegistry longRunningTasksRegistry; + + Clock clock = Clock.fixed(Instant.now(), ZoneId.of("UTC")); + + void mockA() throws Throwable { + joinPointA = mock(ProceedingJoinPoint.class); + when(joinPointA.proceed()).thenReturn(null); + signatureA = spy(new MockedMethodSignature("methodA", Void.TYPE, new Class[]{Object.class, Object.class}, new String[]{ + "paramA", "paramB"})); + when(joinPointA.getSignature()).thenReturn(signatureA); + when(joinPointA.getArgs()).thenReturn(new Object[]{new Object(), new Object()}); + when(joinPointA.getTarget()).thenReturn(this); + + throttleA = new MockedThrottle("'string literal'", "'my.testing.group.A'"); + } + + void mockB() throws Throwable { + joinPointB = mock(ProceedingJoinPoint.class); + when(joinPointB.proceed()).thenReturn(null); + signatureB = spy(new MockedMethodSignature("methodB", Void.class, new Class[]{Map.class}, new String[]{"paramName"})); + when(joinPointB.getSignature()).thenReturn(signatureB); + + when(joinPointB.getArgs()).thenReturn(new Object[]{Map.of("first", "firstValue", "second", "secondValue")}); + when(joinPointB.getTarget()).thenReturn(this); + + throttleB = new MockedThrottle("{#paramName.get('second'), #paramName.get('first')}", "'my.testing.group.B'"); + } + + void mockC() throws Throwable { + joinPointC = mock(ProceedingJoinPoint.class); + when(joinPointC.proceed()).thenReturn(null); + signatureC = spy(new MockedMethodSignature("methodC", Void.TYPE, new Class[]{Object.class, Object.class}, new String[]{ + "paramC", "paramD"})); + when(joinPointC.getSignature()).thenReturn(signatureC); + when(joinPointC.getArgs()).thenReturn(new Object[]{new Object(), new Object()}); + when(joinPointC.getTarget()).thenReturn(this); + + throttleC = new MockedThrottle("'string literal'", "'my.testing'"); + } + + @BeforeEach + void beforeEach() throws Throwable { + mockA(); + mockB(); + mockC(); + + taskSchedulerTasks = new OrderedMap<>(); + + taskScheduler = mock(TaskScheduler.class); + + when(taskScheduler.schedule(any(Runnable.class), any(Instant.class))).then(invocation -> { + final Runnable decorated = TaskUtils.decorateTaskWithErrorHandler(invocation.getArgument(0, Runnable.class), null, false); + final ScheduledFutureTask task = new ScheduledFutureTask<>(Executors.callable(decorated)); + taskSchedulerTasks.put(task, invocation.getArgument(1, Instant.class)); + System.out.println("Scheduled task at " + invocation.getArgument(1, Instant.class)); + return task; + }); + + throttledFutures = new OrderedMap<>(); + lastRun = new OrderedMap<>(); + scheduledFutures = new TreeMap<>(); + + Clock mockedClock = mock(Clock.class); + when(mockedClock.instant()).then(invocation -> getInstant()); + + transactionExecutor = spy(SynchronousTransactionExecutor.class); + longRunningTasksRegistry = mock(LongRunningTasksRegistry.class); + + sut = new ThrottleAspect(throttledFutures, lastRun, scheduledFutures, taskScheduler, mockedClock, transactionExecutor, longRunningTasksRegistry, configuration); + } + + /** + * @return current timestamp based on mocked {@link #clock} + */ + Instant getInstant() { + return clock.instant().truncatedTo(ChronoUnit.SECONDS); + } + + void addSecond() { + clock = Clock.fixed(clock.instant().plusSeconds(1), ZoneId.of("UTC")); + } + + void skipThreshold() { + clock = Clock.fixed(clock.instant().plus(configuration.getThrottleThreshold()), ZoneId.of("UTC")); + } + + void skipDiscardThreshold() { + clock = Clock.fixed(clock.instant() + .plus(configuration.getThrottleDiscardThreshold()) + .plus(configuration.getThrottleThreshold()) + .plusSeconds(1), + ZoneId.of("UTC")); + } + + /** + * Executes all tasks in {@link #taskSchedulerTasks} and clears the map. + */ + void executeScheduledTasks() { + taskSchedulerTasks.forEach((runnable, instant) -> runnable.run()); + taskSchedulerTasks.clear(); + } + + void joinThread(Thread thread) throws InterruptedException { + thread.join(THREAD_JOIN_TIMEOUT_MILLIS); + if (thread.isAlive()) { + thread.interrupt(); + fail("task thread thread interrupted due to timeout"); + } + } + + /** + * If a task was executed more than a threshold period before, it should NOT be debounced + */ + @Test + void firstCallAfterThresholdIsScheduledImmediately() throws Throwable { + sut.throttleMethodCall(joinPointA, throttleA); // first call of the method + executeScheduledTasks(); // simulate that the task was executed + + // simulate that there was a delay before next call, + // and it was greater than threshold + skipThreshold(); + addSecond(); + + final Instant expectedTime = getInstant(); // note current time + sut.throttleMethodCall(joinPointA, throttleA); // second call of the method + + // verify that the task from the second call was scheduled immediately + // because the last time, task was executed was before more than the threshold period + assertEquals(1, taskSchedulerTasks.size()); + assertEquals(expectedTime, taskSchedulerTasks.getValue(0)); + } + + /** + * Calling the annotated method three times + * will execute it only once with the newest data. + * Task is scheduled by the first call. + */ + @Test + void threeImmediateCallsScheduleFirstCallWithLastTask() throws Throwable { + // define a future as the return type of the method + signatureA.setReturnType(Future.class); + + final Supplier methodResult = () -> "method result"; + final Supplier anotherMethodResult = () -> "another method result"; + + final ThrottledFuture methodFuture = ThrottledFuture.of(methodResult); + + // for each method call, make new future + doAnswer(invocation -> ThrottledFuture.of(anotherMethodResult)).when(joinPointA).proceed(); + + final Instant firstCall = getInstant(); + // simulate first call + sut.throttleMethodCall(joinPointA, throttleA); + + addSecond(); + // simulate second call + sut.throttleMethodCall(joinPointA, throttleA); // both tasks should return anotherMethodResult + addSecond(); + + // change the return value of the method to the prepared future + doReturn(methodFuture).when(joinPointA).proceed(); + + // simulate last call + sut.throttleMethodCall(joinPointA, throttleA); // should return methodResult + + // there should be only a single scheduled future + // threshold was not reached and no task was executed, calls should be merged + // scheduled for immediate execution from the first call with the newest data from the last call + assertEquals(1, scheduledFutures.size()); + assertEquals(1, taskSchedulerTasks.size()); + assertEquals(1, throttledFutures.size()); + + final Instant scheduledAt = taskSchedulerTasks.getValue(0); + final Runnable scheduledTask = taskSchedulerTasks.getKey(0); + assertNotNull(scheduledAt); + assertNotNull(scheduledTask); + // the task should be scheduled at the first call + assertEquals(firstCall, scheduledAt); + + final ThrottledFuture future = throttledFutures.getValue(0); + assertNotNull(future); + + // perform task execution + executeScheduledTasks(); + // the future should be completed + assertTrue(future.isDone()); + // check that the task in the future is from the last method call + assertEquals(methodResult.get(), future.get()); + } + + /** + * When method is called in the throttle interval + * calls are merged and method will be executed only once. + * Ensures that both futures returned from method calls are same. + */ + @Test + void callsInThrottleIntervalAreMerged() throws Throwable { + final String[] params = new String[]{"param1", "param2", "param3", "param4", "param5", "param6"}; + // define a future as the return type of the method + signatureA.setReturnType(Future.class); + + // for each method call, make new future with "another method task" + doAnswer(invocation -> new ThrottledFuture()).when(joinPointA).proceed(); + + // simulate first call + when(joinPointA.getArgs()).thenReturn(new Object[]{params[0], params[1]}); + final Object result1 = sut.throttleMethodCall(joinPointA, throttleA); + + addSecond(); + // simulate second call + when(joinPointA.getArgs()).thenReturn(new Object[]{params[2], params[3]}); + final Object result2 = sut.throttleMethodCall(joinPointA, throttleA); + + // both calls returned the same future + // this ensures that calls are actually merged and a single result satisfies all merged calls + assertInstanceOf(Future.class, result1); + assertInstanceOf(Future.class, result2); + assertEquals(result1, result2); + } + + /** + * Within the threshold interval, when a task from first call is already executed, during a new call, + * new future is scheduled. + */ + @Test + @SuppressWarnings("unchecked") + void schedulesNewFutureWhenTheOldOneIsCompletedDuringThreshold() throws Throwable { + // set return type as future + signatureA.setReturnType(Future.class); + // return new throttled future on each method call + when(joinPointA.proceed()).then(invocation -> ThrottledFuture.of(() -> "result")); + + // first call of the method + ThrottledFuture firstFuture = (ThrottledFuture) sut.throttleMethodCall(joinPointA, throttleA); + addSecond(); // changing time (but not more than the threshold) + + // verify that a future was returned + assertNotNull(firstFuture); + // verify that the future is pending + assertFalse(firstFuture.isDone()); + assertFalse(firstFuture.isCancelled()); + assertFalse(firstFuture.isRunning()); + + // verify that a single task was scheduled + assertEquals(1, taskSchedulerTasks.size()); + // execute the task + executeScheduledTasks(); + // verify that the task was completed + assertTrue(firstFuture.isDone()); + assertFalse(firstFuture.isCancelled()); + assertFalse(firstFuture.isRunning()); + + // perform a second call, throttled interval was not reached + ThrottledFuture secondFuture = (ThrottledFuture) sut.throttleMethodCall(joinPointA, throttleA); + addSecond(); + + // verify returned second future + assertNotNull(secondFuture); + + // verify that returned futures are not same + assertNotEquals(firstFuture, secondFuture); + + // it was not completed yet + assertFalse(secondFuture.isDone()); + assertFalse(secondFuture.isCancelled()); + assertFalse(secondFuture.isRunning()); + + // verify a new future was scheduled + assertEquals(1, scheduledFutures.size()); + assertEquals(1, taskSchedulerTasks.size()); + // execute new task + executeScheduledTasks(); + + // the new future was completed + assertTrue(secondFuture.isDone()); + assertFalse(secondFuture.isCancelled()); + assertFalse(secondFuture.isRunning()); + } + + /** + * Ensures that calling the annotated method even outside the threshold + * merges calls when no future was resolved yet (and task is not running). + */ + @SuppressWarnings("unchecked") + @Test + void callsAreMergedWhenCalledOutsideTheThresholdButNoFutureExecutedYet() throws Throwable { + // change return type to future + signatureA.setReturnType(Future.class); + + final String firstResult = "first result"; + final String secondResult = "second result"; + + // on each method call return a new throttled future with firstResult + when(joinPointA.proceed()).then(invocation -> ThrottledFuture.of(() -> firstResult)); + + // first method call + Future firstFuture = (Future) sut.throttleMethodCall(joinPointA, throttleA); + + // ensure that threshold was reached + skipThreshold(); + addSecond(); + + // change method call result to throttled future with secondResult + when(joinPointA.proceed()).then(invocation -> ThrottledFuture.of(() -> secondResult)); + + // second method call + Future secondFuture = (Future) sut.throttleMethodCall(joinPointA, throttleA); + + // verify that the returned future is not null and was not completed yet + assertNotNull(firstFuture); + assertFalse(firstFuture.isDone()); + assertFalse(firstFuture.isCancelled()); + + // verify that calls were merged and returned futures are same + assertEquals(firstFuture, secondFuture); + + // only one task was scheduled + assertEquals(1, scheduledFutures.size()); + assertEquals(1, taskSchedulerTasks.size()); + + executeScheduledTasks(); + + assertTrue(firstFuture.isDone()); + assertFalse(firstFuture.isCancelled()); + + // verify that the future was resolved with the newest call data + assertEquals(secondResult, firstFuture.get()); + + assertTrue(firstFuture.isDone()); + assertFalse(firstFuture.isCancelled()); + assertEquals(secondResult, firstFuture.get()); + } + + @Test + void cancelsAllScheduledFuturesWhenNewTaskWithLowerGroupIsScheduled() throws Throwable { + throttleA.setGroup("'the.group.identifier.first'"); + throttleB.setGroup("'the.group.identifier.second'"); + throttleC.setGroup("'the.group.identifier'"); + + sut.throttleMethodCall(joinPointA, throttleA); + sut.throttleMethodCall(joinPointB, throttleB); + + final Map> futures = Map.copyOf(scheduledFutures); + + assertEquals(2, throttledFutures.size()); + assertEquals(2, scheduledFutures.size()); + assertEquals(2, taskSchedulerTasks.size()); + + sut.throttleMethodCall(joinPointC, throttleC); + + assertEquals(1, throttledFutures.size()); + assertEquals(1, scheduledFutures.size()); + assertEquals(3, taskSchedulerTasks.size()); + + assertEquals(2, futures.size()); + futures.forEach((k, f) -> assertTrue(f.isCancelled())); + } + + @Test + void immediatelyCancelsNewFutureWhenLowerGroupIsAlreadyScheduled() throws Throwable { + throttleA.setGroup("'the.group.identifier'"); + throttleB.setGroup("'the.group.identifier.with.higher.value'"); + + signatureB.setReturnType(Future.class); + when(joinPointB.proceed()).then(invocation -> new ThrottledFuture<>()); + + sut.throttleMethodCall(joinPointA, throttleA); + + final Map> futures = Map.copyOf(scheduledFutures); + + Object result = sut.throttleMethodCall(joinPointB, throttleB); + assertNotNull(result); + assertInstanceOf(ThrottledFuture.class, result); + Future secondCall = (Future) result; + + assertEquals(1, scheduledFutures.size()); + + final Future oldFuture = futures.values().iterator().next(); + final Future currentFuture = scheduledFutures.values().iterator().next(); + assertEquals(oldFuture, currentFuture); + assertFalse(currentFuture.isDone()); + assertFalse(currentFuture.isCancelled()); + + assertTrue(secondCall.isCancelled()); + } + + /** + * When a thread is executing a task from throttled method, and it reaches another throttled method, + * no further task should be scheduled and the throttled method should be executed synchronously. + */ + @Test + void callToThrottledMethodReturningVoidFromAlreadyThrottledThreadResultsInSynchronousExecution() throws Throwable { + AtomicLong threadId = new AtomicLong(-1); + + // prepare a simulated nested throttled method + when(joinPointB.proceed()).then(invocation -> { + threadId.set(Thread.currentThread().getId()); + return null; // void return type + }); + + // when method A is executed, call throttled method B + when(joinPointA.proceed()).then(invocation -> { + sut.throttleMethodCall(joinPointB, throttleB); + return null; // void return type + }); + + sut.throttleMethodCall(joinPointA, throttleA); + + // execute a single scheduled task + Thread runThread = new Thread(taskSchedulerTasks.getKey(0)); + runThread.start(); + + joinThread(runThread); + + assertNotEquals(-1, threadId.get()); + assertEquals(runThread.getId(), threadId.get()); + } + + /** + * Same as {@link #callToThrottledMethodReturningVoidFromAlreadyThrottledThreadResultsInSynchronousExecution} + * but with method returning a future + */ + @Test + void callToThrottledMethodReturningFutureFromAlreadyThrottledThreadResultsInSynchronousExecution() throws Throwable { + AtomicLong threadId = new AtomicLong(-1); + + signatureA.setReturnType(Future.class); + signatureB.setReturnType(Future.class); + + // prepare a simulated nested throttled method + when(joinPointB.proceed()).then(invocation -> ThrottledFuture.of(()-> + threadId.set(Thread.currentThread().getId()) + )); + + // when method A is executed, call throttled method B + when(joinPointA.proceed()).then(invocation -> ThrottledFuture.of(() -> { + try { + sut.throttleMethodCall(joinPointB, throttleB); + } catch (Throwable t) { + fail(t); + } + })); + + sut.throttleMethodCall(joinPointA, throttleA); + + // execute a single scheduled task + Thread runThread = new Thread(taskSchedulerTasks.getKey(0)); + runThread.start(); + + joinThread(runThread); + + assertNotEquals(-1, threadId.get()); + assertEquals(runThread.getId(), threadId.get()); + } + + /** + * When a throttled method is annotated with {@link Transactional @Transactional} + * the asynchronous task should be executed with {@link SynchronousTransactionExecutor} + * by a throttled thread. + */ + @Test + void taskFromMethodAnnotatedWithTransactionalIsExecutedWithTransactionExecutor() throws Throwable { + // simulates a method object with transactional annotation + when(signatureA.getMethod()).thenReturn(SynchronousTransactionExecutor.class.getDeclaredMethod("execute", Runnable.class)); + signatureA.setReturnType(Future.class); + Runnable task = () -> {}; + when(joinPointA.proceed()).thenReturn(ThrottledFuture.of(task)); + + ThrottledFuture result = (ThrottledFuture) sut.throttleMethodCall(joinPointA, throttleA); + + assertNotNull(result); + verifyNoInteractions(transactionExecutor); + + executeScheduledTasks(); + + verify(transactionExecutor).execute(any()); + } + + /** + * When a task is executed, all three maps are cleared from + * entries older than {@link Configuration#throttleDiscardThreshold throttleDiscardThreshold} + * plus {@link Configuration#throttleThreshold throttleThreshold} + */ + @Test + void allMapsAreClearedAfterDiscardThreshold() throws Throwable { + sut.throttleMethodCall(joinPointA, throttleA); + sut.throttleMethodCall(joinPointB, throttleB); + sut.throttleMethodCall(joinPointC, throttleC); + skipThreshold(); + executeScheduledTasks(); + sut.throttleMethodCall(joinPointA, throttleA); + sut.throttleMethodCall(joinPointB, throttleB); + sut.throttleMethodCall(joinPointC, throttleC); + executeScheduledTasks(); + skipThreshold(); + sut.throttleMethodCall(joinPointA, throttleA); + sut.throttleMethodCall(joinPointB, throttleB); + sut.throttleMethodCall(joinPointC, throttleC); + addSecond(); + executeScheduledTasks(); + + // skip discard threshold + skipDiscardThreshold(); + sut.throttleMethodCall(joinPointA, throttleA); + + executeScheduledTasks(); + + // only single task left (the last one, which cleared the maps) + assertEquals(1, scheduledFutures.size()); + assertEquals(1, throttledFutures.size()); + assertEquals(1, lastRun.size()); + } + + + @Test + void aspectDoesNotThrowWhenMethodReturnsUnboxedVoidBySignature() throws Throwable { + signatureA.setReturnType(Void.TYPE); + when(joinPointA.proceed()).thenReturn(null); + + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDoesNotThrowWhenMethodReturnsBoxedVoidBySignature() throws Throwable { + signatureA.setReturnType(Void.class); + when(joinPointA.proceed()).thenReturn(null); + + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDoesNotThrowWhenMethodReturnsFutureBySignature() throws Throwable { + signatureA.setReturnType(Future.class); + when(joinPointA.proceed()).then(invocation -> new ThrottledFuture<>()); + + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDoesNotThrowWhenMethodReturnsThrottledFutureBySignature() throws Throwable { + signatureA.setReturnType(ThrottledFuture.class); + when(joinPointA.proceed()).then(invocation -> new ThrottledFuture<>()); + + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @ParameterizedTest + // just few sample classes + @ValueSource(classes = {String.class, Integer.class, Optional.class, FutureTask.class, CompletableFuture.class}) + void aspectThrowsWhenMethodNotReturnsVoidOrFutureBySignature(Class returnType) throws Throwable { + signatureA.setReturnType(returnType); + when(joinPointA.proceed()).thenReturn(new Object()); + + assertThrows(ThrottleAspectException.class, () -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDoesThrowWhenMethodReturnsNullByValueAndFutureBySignature() throws Throwable { + signatureA.setReturnType(Future.class); + when(joinPointA.proceed()).thenReturn(null); + + assertThrows(ThrottleAspectException.class, () -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @ParameterizedTest + @ValueSource(classes = {Future.class, ThrottledFuture.class}) + void aspectThrowsWhenMethodDoesNotReturnsThrottledFutureObject(Class returnType) throws Throwable { + signatureA.setReturnType(returnType); + when(joinPointA.proceed()).thenReturn(new FutureTask<>(() -> "")); + + assertThrows(ThrottleAspectException.class, () -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @ParameterizedTest + @ValueSource(classes = {Future.class, ThrottledFuture.class}) + void aspectThrowsWhenMethodReturnsNullWithFutureBySignature(Class returnType) throws Throwable { + signatureA.setReturnType(returnType); + when(joinPointA.proceed()).thenReturn(null); + + assertThrows(ThrottleAspectException.class, () -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectResolvesThrottleGroupProviderClassInSpEL() throws Throwable { + throttleA.setGroup("T(ThrottleGroupProvider).getTextAnalysisVocabulariesAll()"); + + sut.throttleMethodCall(joinPointA, throttleA); + + final String expectedGroup = ThrottleGroupProvider.getTextAnalysisVocabulariesAll(); + final String resolvedGroup = scheduledFutures.firstEntry().getKey().getGroup(); + assertEquals(expectedGroup, resolvedGroup); + } + + @Test + void exceptionPropagatedWhenJoinPointProceedThrows() throws Throwable { + when(joinPointA.proceed()).thenThrow(new RuntimeException()); + + sut.throttleMethodCall(joinPointA, throttleA); + + assertDoesNotThrow(() -> taskSchedulerTasks.forEach((r, i) -> r.run())); + } + + /** + * Ensures that when an exception is thrown during throttled task execution, + * it is stored within the future and rethrown on future#get. + */ + @Test + void exceptionPropagatedFromFutureTask() throws Throwable { + final String exceptionMessage = "termit exception"; + when(joinPointA.proceed()).then(invocation -> ThrottledFuture.of(() -> { + throw new TermItException(exceptionMessage); + })); + signatureA.setReturnType(Future.class); + + sut.throttleMethodCall(joinPointA, throttleA); + + assertEquals(1, taskSchedulerTasks.size()); + assertEquals(1, scheduledFutures.size()); + assertEquals(1, throttledFutures.size()); + Runnable scheduled = taskSchedulerTasks.getKey(0); + Future scheduledFuture = scheduledFutures.firstEntry().getValue(); + Future future = throttledFutures.getValue(0); + + assertNotNull(scheduled); + assertNotNull(future); + assertDoesNotThrow(scheduled::run); // exception is thrown here, but future stores it + // exception is then re-thrown during future#get() + ExecutionException e = assertThrows(ExecutionException.class, future::get); + assertEquals(exceptionMessage, e.getCause().getMessage()); + assertTrue(scheduledFuture.isDone()); + assertTrue(future.isDone()); + } + + @Test + void resolvedFutureFromMethodIsReturnedWithoutSchedule() throws Throwable { + signatureA.setReturnType(Future.class); + final String result = "result of the method"; + when(joinPointA.proceed()).then(invocation -> ThrottledFuture.done(result)); + + Future future = (Future) sut.throttleMethodCall(joinPointA, throttleA); + + assertNotNull(future); + assertTrue(future.isDone()); + assertFalse(future.isCancelled()); + assertEquals(result, future.get()); + assertTrue(scheduledFutures.isEmpty()); + assertTrue(taskSchedulerTasks.isEmpty()); + } + + @Test + void aspectThrowsOnMalformedSpel() { + throttleA.setValue("invalid spel expression"); + assertThrows(SpelParseException.class, () -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectThrowsOnInvalidSpelParamReference() { + throttleA.setValue("{#nonExistingParameter}"); + assertThrows(ThrottleAspectException.class, () -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDoesNotThrowsOnStringLiteral() { + throttleA.setValue("'valid spel'"); + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDoesNotThrowsOnEmptyIdentifier() { + throttleA.setValue(""); + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectDownNotThrowsOnEmptyGroup() { + throttleA.setGroup(""); + assertDoesNotThrow(() -> sut.throttleMethodCall(joinPointA, throttleA)); + } + + @Test + void aspectConstructsFromAutowiredConstructor() { + assertDoesNotThrow(() -> new ThrottleAspect(taskScheduler, transactionExecutor, longRunningTasksRegistry, configuration)); + } + + @Test + void futureWithHigherGroupIsNotCanceledWhenFutureWithLowerGroupIsCanceled() throws Throwable { + // future C has lower group than future A and B + sut.throttleMethodCall(joinPointC, throttleC); + // cancel the scheduled future + scheduledFutures.firstEntry().getValue().cancel(false); + + signatureA.setReturnType(Future.class); + when(joinPointA.proceed()).thenReturn(ThrottledFuture.of(() -> null)); + + Future higherFuture = (Future) sut.throttleMethodCall(joinPointA, throttleA); + + assertNotNull(higherFuture); + assertFalse(higherFuture.isCancelled()); + assertEquals(2, scheduledFutures.size()); + } + + @Test + void newScheduleWithBlankGroupDoesNotCancelsAnyOtherFuture() throws Throwable { + throttleA.setGroup(""); + throttleC.setGroup(""); + + sut.throttleMethodCall(joinPointA, throttleA); // blank group + sut.throttleMethodCall(joinPointB, throttleB); // non blank group + sut.throttleMethodCall(joinPointC, throttleC); // blank group + // no future should be canceled + + assertEquals(3, throttledFutures.size()); + assertEquals(3, scheduledFutures.size()); + + Stream.concat(throttledFutures.values().stream(), scheduledFutures.values().stream()) + .forEach(future -> assertFalse(future.isCancelled())); + } + + @Test + void mapsAreNotClearedWhenFutureIsNotDone() throws Throwable { + sut.throttleMethodCall(joinPointA, throttleA); // blank group + skipDiscardThreshold(); + sut.throttleMethodCall(joinPointB, throttleB); // non blank group + + taskSchedulerTasks.getKey(1).run(); + + assertEquals(2, scheduledFutures.size()); + assertEquals(2, throttledFutures.size()); + } + + /** + * Scenario:
+ *
    + *
  1. Method is called
  2. + *
  3. Task is scheduled
  4. + *
  5. Task execution starts
  6. + *
  7. Method is called again
  8. + *
  9. New task should be scheduled, but the old one is still executing
  10. + *
+ * This test verify that the second task won't start execution until the old one finishes. + */ + @Test + void noTwoTasksWithTheSameIdentifierShouldBeExecutedConcurrently() throws Throwable { + final AtomicBoolean taskRunning = new AtomicBoolean(false); + final AtomicBoolean allowFinish = new AtomicBoolean(false); + when(joinPointA.proceed()).then(invocation -> { + taskRunning.set(true); + while(!allowFinish.get()) { + Thread.yield(); + } + return null; + }); + + sut.throttleMethodCall(joinPointA, throttleA); + + final Thread firstTask = new Thread(taskSchedulerTasks.getKey(0)); + firstTask.start(); + + await("task execution start").atMost(Duration.ofSeconds(30)).untilTrue(taskRunning); + + assertEquals(1, taskSchedulerTasks.size()); + final ThrottledFuture oldFuture = throttledFutures.getValue(0); + + sut.throttleMethodCall(joinPointA, throttleA); + + assertEquals(1, taskSchedulerTasks.size()); + assertNotEquals(oldFuture, throttledFutures.getValue(0)); + + allowFinish.set(true); + joinThread(firstTask); + + assertEquals(2, taskSchedulerTasks.size()); // new task scheduled after the old one finished + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectTestContextConfig.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectTestContextConfig.java new file mode 100644 index 000000000..029cc0455 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottleAspectTestContextConfig.java @@ -0,0 +1,36 @@ +package cz.cvut.kbss.termit.util.throttle; + +import cz.cvut.kbss.termit.util.Configuration; +import org.mockito.Mockito; +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.EnableAspectJAutoProxy; +import org.springframework.context.annotation.ImportResource; +import org.springframework.context.annotation.aspectj.EnableSpringConfigured; +import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; + +import static org.mockito.Answers.RETURNS_SMART_NULLS; + +@TestConfiguration +@EnableSpringConfigured +@ImportResource("classpath*:spring-aop.xml") +@EnableAspectJAutoProxy(proxyTargetClass = true) +@ComponentScan(value = "cz.cvut.kbss.termit.util.throttle") +public class ThrottleAspectTestContextConfig { + + @Bean + public ThreadPoolTaskScheduler longRunningTaskScheduler() { + return Mockito.mock(ThreadPoolTaskScheduler.class, RETURNS_SMART_NULLS); + } + + @Bean + public ThrottledService throttledService() { + return new ThrottledService(); + } + + @Bean + public Configuration configuration() { + return new Configuration(); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottledFutureTest.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottledFutureTest.java new file mode 100644 index 000000000..bf8f4f4e0 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottledFutureTest.java @@ -0,0 +1,447 @@ +package cz.cvut.kbss.termit.util.throttle; + +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.test.util.ReflectionTestUtils; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static org.awaitility.Awaitility.await; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.doCallRealMethod; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class ThrottledFutureTest { + + @Test + void cancelledFactoryMethodReturnsCancelledFuture() { + final ThrottledFuture future = ThrottledFuture.canceled(); + assertTrue(future.isCancelled()); + assertTrue(future.isDone()); // future is done when it is cancelled + assertFalse(future.isRunning()); + } + + @Test + void doneFactoryMethodReturnsDoneFuture() throws Throwable { + final Object result = new Object(); + final ThrottledFuture future = ThrottledFuture.done(result); + assertFalse(future.isCancelled()); + assertTrue(future.isDone()); + assertFalse(future.isRunning()); + final Object futureResult = future.get(1, TimeUnit.SECONDS); + assertNotNull(futureResult); + assertEquals(result, futureResult); + } + + @Test + void getNowReturnsCacheWhenCacheIsAvailable() { + final Object cache = new Object(); + final ThrottledFuture future = ThrottledFuture.of(Object::new).setCachedResult(cache); + final Optional cached = future.getNow(); + assertNotNull(cached); + assertTrue(cached.isPresent()); + assertEquals(cache, cached.get()); + } + + @Test + void getNowReturnsEmptyWhenCacheIsNotAvailable() { + final ThrottledFuture future = ThrottledFuture.of(Object::new); + final Optional cached = future.getNow(); + assertNotNull(cached); + assertTrue(cached.isEmpty()); + } + + @Test + void getNowReturnsEmptyWhenCacheIsNull() { + final ThrottledFuture future = ThrottledFuture.of(Object::new).setCachedResult(null); + final Optional cached = future.getNow(); + assertNotNull(cached); + assertTrue(cached.isEmpty()); + } + + @Test + void thenActionIsExecutedSynchronouslyWhenFutureIsAlreadyDoneAndNotCanceled() { + final Object result = new Object(); + final ThrottledFuture future = ThrottledFuture.of(() -> result); + final AtomicBoolean completed = new AtomicBoolean(false); + final AtomicReference futureResult = new AtomicReference<>(null); + future.run(null); + assertTrue(future.isDone()); + assertFalse(future.isCancelled()); + future.then(fResult -> { + completed.set(true); + futureResult.set(fResult); + }); + assertTrue(completed.get()); + assertEquals(result, futureResult.get()); + } + + @Test + void thenActionIsNotExecutedWhenFutureIsAlreadyCancelled() { + final ThrottledFuture future = ThrottledFuture.of(Object::new); + final AtomicBoolean completed = new AtomicBoolean(false); + future.cancel(false); + assertTrue(future.isCancelled()); + future.then(result -> completed.set(true)); + assertFalse(completed.get()); + } + + @Test + void thenActionIsExecutedOnceFutureIsRun() { + final Object result = new Object(); + final AtomicBoolean completed = new AtomicBoolean(false); + final AtomicReference fResult = new AtomicReference<>(null); + final ThrottledFuture future = ThrottledFuture.of(() -> result); + future.then(futureResult -> { + completed.set(true); + fResult.set(futureResult); + }); + assertNull(fResult.get()); + assertFalse(completed.get()); // action was not executed yet + future.run(null); + assertTrue(completed.get()); + assertEquals(result, fResult.get()); + } + + @Test + void thenActionIsNotExecutedOnceFutureIsCancelled() { + final Object result = new Object(); + final AtomicBoolean completed = new AtomicBoolean(false); + final ThrottledFuture future = ThrottledFuture.of(() -> result); + future.then(futureResult -> completed.set(true)); + assertFalse(completed.get()); // action was not executed yet + future.cancel(false); + assertFalse(completed.get()); + } + + @Test + void callingRunWillExecuteFutureOnlyOnce() { + AtomicInteger count = new AtomicInteger(0); + final ThrottledFuture future = ThrottledFuture.of(() -> { + count.incrementAndGet(); + }); + + future.run(null); + final Optional runningSince = future.startedAt(); + assertTrue(runningSince.isPresent()); + assertTrue(future.isDone()); + assertFalse(future.isCancelled()); + assertFalse(future.isRunning()); + + future.run(null); + assertTrue(future.isDone()); + assertFalse(future.isCancelled()); + assertFalse(future.isRunning()); + + // verify that timestamp did not change + assertTrue(future.startedAt().isPresent()); + assertEquals(runningSince.get(), future.startedAt().get()); + } + + /** + * Verifies locks and that second thread exists fast when calls run on already running future. + */ + @Test + void callingRunWillExecuteFutureOnlyOnceAndWontBlockSecondThreadAsync() throws Throwable { + AtomicBoolean allowExit = new AtomicBoolean(false); + AtomicInteger count = new AtomicInteger(0); + final ThrottledFuture future = ThrottledFuture.of(() -> { + count.incrementAndGet(); + while (!allowExit.get()) { + Thread.yield(); + } + }); + final Thread threadA = new Thread(() -> future.run(null)); + final Thread threadB = new Thread(() -> future.run(null)); + threadA.start(); + + await("count incrementation").atMost(Duration.ofSeconds(30)).until(() -> count.get() > 0); + // now there is a threadA spinning in the future task + // locks in the future should be held + assertTrue(future.isRunning()); + assertFalse(future.isDone()); + assertFalse(future.isCancelled()); + + final Optional runningSince = future.startedAt(); + assertTrue(runningSince.isPresent()); + + threadB.start(); + + // thread B should not be blocked + await("threadB start").atMost(Duration.ofSeconds(30)).until(() -> threadB.getState().equals(Thread.State.TERMINATED)); + assertTrue(future.isRunning()); + + allowExit.set(true); + threadA.join(60 * 1000); + threadB.join(60 * 1000); + + assertFalse(threadA.isAlive()); + assertFalse(threadB.isAlive()); + + assertEquals(1, count.get()); + assertTrue(future.startedAt().isPresent()); + assertEquals(runningSince.get(), future.startedAt().get()); + } + + @Test + void getNowReturnsCachedResultWhenItsAvailable() { + final String futureResult = "future"; + final String cachedResult = "cached"; + ThrottledFuture future = ThrottledFuture.of(() -> futureResult).setCachedResult(cachedResult); + + Optional result = future.getNow(); + assertTrue(result.isPresent()); + assertEquals(cachedResult, result.get()); + } + + @Test + void getNowReturnsEmptyWhenCacheIsNotSet() { + final String futureResult = "future"; + ThrottledFuture future = ThrottledFuture.of(() -> futureResult); + + Optional result = future.getNow(); + assertTrue(result.isEmpty()); + } + + @Test + void getNowReturnsEmptyWhenNullCacheIsSet() { + final String futureResult = "future"; + ThrottledFuture future = ThrottledFuture.of(() -> futureResult).setCachedResult(null); + + Optional result = future.getNow(); + assertTrue(result.isEmpty()); + } + + @Test + void getNowReturnsFutureResultWhenItsDoneAndNotCancelled() { + final String futureResult = "future"; + final String cachedResult = "cached"; + ThrottledFuture future = ThrottledFuture.of(() -> futureResult).setCachedResult(cachedResult); + future.run(null); + + Optional result = future.getNow(); + assertTrue(result.isPresent()); + assertEquals(futureResult, result.get()); + } + + @Test + void getNowReturnsCachedResultWhenFutureIsCancelled() { + final String futureResult = "future"; + final String cachedResult = "cached"; + ThrottledFuture future = ThrottledFuture.of(() -> futureResult).setCachedResult(cachedResult); + future.cancel(false); + + Optional result = future.getNow(); + assertTrue(result.isPresent()); + assertEquals(cachedResult, result.get()); + } + + @Test + void onCompletionCallbacksAreNotExecutedWhenTaskIsNull() { + final AtomicBoolean callbackExecuted = new AtomicBoolean(false); + final ThrottledFuture future = new ThrottledFuture<>(); + future.then(ignored -> callbackExecuted.set(true)); + future.run(null); + assertFalse(callbackExecuted.get()); + } + + @Test + void transferUpdatesSecondFutureWithTask() { + final Supplier firstTask = () -> null; + final ThrottledFuture firstFuture = ThrottledFuture.of(firstTask); + final ThrottledFuture secondFuture = mock(ThrottledFuture.class); + + firstFuture.transfer(secondFuture); + + verify(secondFuture).update(eq(firstTask), anyList()); + + // now verifies that the task in the first future is null + Object task = ReflectionTestUtils.getField(firstFuture, "task"); + assertNull(task); + assertTrue(firstFuture.isCancelled()); + } + + @Test + void transferUpdatesSecondFutureWithCallbacks() { + final Consumer firstCallback = (result) -> {}; + final Consumer secondCallback = (result) -> {}; + final ThrottledFuture firstFuture = ThrottledFuture.of(()->"").then(firstCallback); + final ThrottledFuture secondFuture = ThrottledFuture.of(()->"").then(secondCallback); + final ThrottledFuture mocked = mock(ThrottledFuture.class); + final List> captured = new ArrayList<>(2); + + when(mocked.update(any(), any())).then(invocation -> { + captured.addAll(invocation.getArgument(1, List.class)); + return mocked; + }); + + firstFuture.transfer(secondFuture); + secondFuture.transfer(mocked); + + verify(mocked).update(notNull(), notNull()); + assertEquals(2, captured.size()); + assertTrue(captured.contains(firstCallback)); + // verifies that callbacks are added to the current ones and do not replace them + assertTrue(captured.contains(secondCallback)); + } + + @Test + void callbacksAreClearedAfterTransferring() { + final Consumer firstCallback = (result) -> {}; + final Consumer secondCallback = (result) -> {}; + final ThrottledFuture future = ThrottledFuture.of(()->"").then(firstCallback).then(secondCallback); + final ThrottledFuture mocked = mock(ThrottledFuture.class); + + future.transfer(mocked); + + final ArgumentCaptor>> captor = ArgumentCaptor.forClass(List.class); + + verify(mocked).update(notNull(), captor.capture()); + // captor takes the original list from the future + // which is cleared afterward + assertTrue(captor.getValue().isEmpty()); + } + + @Test + void transferReturnsTargetWhenFutureIsRunning() { + final ThrottledFuture future = spy(ThrottledFuture.of(()->"")); + final ThrottledFuture target = ThrottledFuture.of(()->""); + when(future.isRunning()).thenReturn(true); + doCallRealMethod().when(future).transfer(any()); + + final ThrottledFuture result = future.transfer(target); + assertEquals(target, result); + } + + @Test + void transferReturnsTargetWhenFutureIsDone() { + final ThrottledFuture future = ThrottledFuture.done(""); + final ThrottledFuture target = ThrottledFuture.of(()->""); + + final ThrottledFuture result = future.transfer(target); + assertEquals(target, result); + } + + @Test + void transferReturnsTargetWhenLockIsNotLockedForTransfer() throws Throwable { + final ReentrantLock futureLock = new ReentrantLock(); + final ThrottledFuture future = ThrottledFuture.of(()->""); + final ThrottledFuture target = ThrottledFuture.of(()->""); + + final Thread thread = new Thread(futureLock::lock); + thread.start(); + thread.join(); + + ReflectionTestUtils.setField(future, "lock", futureLock); + + final ThrottledFuture result = future.transfer(target); + assertEquals(target, result); + } + + @Test + void updateSetsTask() { + final Supplier task = ()->""; + final ThrottledFuture future = ThrottledFuture.of(() -> ""); + + future.update(task, List.of()); + + assertEquals(task, ReflectionTestUtils.getField(future, "task")); + } + + @Test + void updateAddsCallbacksToTheCurrentOnes() { + final Consumer callback = result -> {}; + final Consumer originalCallback = result -> {}; + final ThrottledFuture future = ThrottledFuture.of(() -> "").then(originalCallback); + + future.update(()->"", List.of(callback)); + + final Collection> callbacks = + (Collection>) ReflectionTestUtils.getField(future, "onCompletion"); + + assertNotNull(callbacks); + assertEquals(2, callbacks.size()); + assertTrue(callbacks.contains(originalCallback)); + assertTrue(callbacks.contains(callback)); + } + + @Test + void updateReturnsNewFutureWhenFutureIsRunning() { + final ThrottledFuture future = spy(ThrottledFuture.of(()->"")); + when(future.isRunning()).thenReturn(true); + doCallRealMethod().when(future).update(any(), any()); + + final ThrottledFuture result = future.update(()->"", List.of()); + assertNotEquals(future, result); + } + + @Test + void updateReturnsSelfWhenFutureIsNotRunningAndNotDone() { + final ThrottledFuture future = ThrottledFuture.of(()->""); + + final ThrottledFuture result = future.update(()->"", List.of()); + assertEquals(future, result); + } + + @Test + void updateReturnsNewFutureWhenFutureIsDone() { + final ThrottledFuture future = ThrottledFuture.done(""); + + final ThrottledFuture result = future.update(()->"", List.of()); + assertNotEquals(future, result); + } + + @Test + void updateReturnsNewFutureWhenLockIsNotLockedForUpdate() throws Throwable { + final ReentrantLock futureLock = new ReentrantLock(); + final ThrottledFuture future = ThrottledFuture.of(()->""); + + final Thread thread = new Thread(futureLock::lock); + thread.start(); + thread.join(); + + ReflectionTestUtils.setField(future, "lock", futureLock); + + final ThrottledFuture result = future.update(()->"", List.of()); + assertNotEquals(future, result); + } + + @Test + void runExecutionCallbackIsExecutedAfterStartedAtIsSetAndBeforeTaskExecution() { + final AtomicBoolean taskExecuted = new AtomicBoolean(false); + final ThrottledFuture future = ThrottledFuture.of(()->{ + taskExecuted.set(true); + }); + + future.run(f -> { + assertEquals(future, f); + assertTrue(f.startedAt().isPresent()); + }); + + assertTrue(taskExecuted.get()); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottledService.java b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottledService.java new file mode 100644 index 000000000..d89cd7149 --- /dev/null +++ b/src/test/java/cz/cvut/kbss/termit/util/throttle/ThrottledService.java @@ -0,0 +1,11 @@ +package cz.cvut.kbss.termit.util.throttle; + +import java.util.concurrent.Future; + +public class ThrottledService { + + @Throttle + public Future annotatedMethod() { + return ThrottledFuture.of(() -> true); + } +} diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java index da6684097..e2038aa0b 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketControllerTestRunner.java @@ -2,9 +2,13 @@ import cz.cvut.kbss.termit.environment.config.TestRestSecurityConfig; import cz.cvut.kbss.termit.environment.config.TestWebSocketConfig; +import cz.cvut.kbss.termit.service.IdentifierResolver; import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTasksRegistry; +import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler; +import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler; import cz.cvut.kbss.termit.websocket.util.CachingChannelInterceptor; -import jakarta.annotation.PostConstruct; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; @@ -14,6 +18,8 @@ import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer; +import org.springframework.boot.test.mock.mockito.MockBean; +import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.messaging.Message; import org.springframework.messaging.support.AbstractSubscribableChannel; import org.springframework.test.annotation.DirtiesContext; @@ -26,18 +32,31 @@ import java.util.UUID; import static cz.cvut.kbss.termit.websocket.util.ReturnValueCollectingSimpMessagingTemplate.MESSAGE_IDENTIFIER_HEADER; +import static org.mockito.Mockito.verifyNoMoreInteractions; @ActiveProfiles("test") @ExtendWith(SpringExtension.class) @ExtendWith(MockitoExtension.class) @EnableConfigurationProperties({Configuration.class}) -@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) @ContextConfiguration(classes = {TestRestSecurityConfig.class, TestWebSocketConfig.class}, initializers = {ConfigDataApplicationContextInitializer.class}) public abstract class BaseWebSocketControllerTestRunner { private static final Logger LOG = LoggerFactory.getLogger(BaseWebSocketControllerTestRunner.class); + @SpyBean + protected WebSocketExceptionHandler webSocketExceptionHandler; + + @SpyBean + protected StompExceptionHandler stompExceptionHandler; + + @MockBean + protected IdentifierResolver identifierResolver; + + @MockBean + protected LongRunningTasksRegistry longRunningTasksRegistry; + /** * Simulated messages from client to server */ @@ -68,8 +87,8 @@ public abstract class BaseWebSocketControllerTestRunner { protected CachingChannelInterceptor brokerChannelInterceptor; - @PostConstruct - protected void runnerPostConstruct() { + @BeforeEach + protected void runnerBeforeEach() { this.brokerChannelInterceptor = new CachingChannelInterceptor(); this.serverOutboundChannelInterceptor = new CachingChannelInterceptor(); @@ -77,11 +96,9 @@ protected void runnerPostConstruct() { this.serverOutboundChannel.addInterceptor(this.serverOutboundChannelInterceptor); } - @BeforeEach - protected void runnerBeforeEach() { - this.serverOutboundChannelInterceptor.reset(); - this.brokerChannelInterceptor.reset(); - this.returnedValuesMap.clear(); + @AfterEach + protected void runnerAfterEach() { + verifyNoMoreInteractions(webSocketExceptionHandler, stompExceptionHandler); } /** diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java index ce30493d2..a8cd731c5 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/BaseWebSocketIntegrationTestRunner.java @@ -13,7 +13,11 @@ import cz.cvut.kbss.termit.security.model.TermItUserDetails; import cz.cvut.kbss.termit.service.security.TermItUserDetailsService; import cz.cvut.kbss.termit.util.Configuration; +import cz.cvut.kbss.termit.util.longrunning.LongRunningTasksRegistry; +import cz.cvut.kbss.termit.websocket.handler.StompExceptionHandler; +import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler; import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.junit.jupiter.MockitoExtension; @@ -24,6 +28,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.test.context.ConfigDataApplicationContextInitializer; import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.context.annotation.ComponentScan; import org.springframework.context.annotation.EnableAspectJAutoProxy; @@ -43,10 +48,10 @@ import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.messaging.WebSocketStompClient; -import java.util.concurrent.Future; import java.util.concurrent.atomic.AtomicReference; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.verifyNoMoreInteractions; @ActiveProfiles("test") @EnableSpringConfigured @@ -61,12 +66,21 @@ initializers = {ConfigDataApplicationContextInitializer.class}) @ComponentScan( {"cz.cvut.kbss.termit.security", "cz.cvut.kbss.termit.websocket", "cz.cvut.kbss.termit.websocket.handler"}) -@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) @SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT) public abstract class BaseWebSocketIntegrationTestRunner { protected Logger LOG = LoggerFactory.getLogger(this.getClass()); + @SpyBean + protected WebSocketExceptionHandler webSocketExceptionHandler; + + @SpyBean + protected StompExceptionHandler stompExceptionHandler; + + @MockBean + protected LongRunningTasksRegistry longRunningTasksRegistry; + protected WebSocketStompClient stompClient; @Value("ws://localhost:${local.server.port}/ws") @@ -80,10 +94,6 @@ public abstract class BaseWebSocketIntegrationTestRunner { protected TermItUserDetails userDetails; - protected Future connect(StompSessionHandlerAdapter sessionHandler) { - return stompClient.connectAsync(url, sessionHandler); - } - protected String generateToken() { return jwtUtils.generateToken(userDetails.getUser(), userDetails.getAuthorities()); } @@ -96,6 +106,11 @@ void runnerSetup() { doReturn(userDetails).when(userDetailsService).loadUserByUsername(userDetails.getUsername()); } + @AfterEach + protected void runnerAfterEach() { + verifyNoMoreInteractions(webSocketExceptionHandler, stompExceptionHandler); + } + protected class TestWebSocketSessionHandler implements WebSocketHandler { @Override diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java index b3bbcc043..c1d2d81b0 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/IntegrationWebSocketSecurityTest.java @@ -37,6 +37,8 @@ import static org.awaitility.Awaitility.await; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.verify; class IntegrationWebSocketSecurityTest extends BaseWebSocketIntegrationTestRunner { @@ -49,10 +51,11 @@ class IntegrationWebSocketSecurityTest extends BaseWebSocketIntegrationTestRunne ObjectMapper objectMapper; /** - * @return Stream of argument pairs with StompCommand (CONNECT excluded) and true + false value for each command + * @return Stream of argument pairs with StompCommand (CONNECT & DISCONNECT excluded) and true + false value for each command */ public static Stream stompCommands() { - return Arrays.stream(StompCommand.values()).filter(c -> c != StompCommand.CONNECT).map(Enum::name) + return Arrays.stream(StompCommand.values()).filter(c -> c != StompCommand.CONNECT && c != StompCommand.DISCONNECT) + .map(Enum::name) .flatMap(name -> Stream.of(Arguments.of(name, true), Arguments.of(name, false))); } @@ -83,6 +86,7 @@ void connectionIsClosedOnAnyMessageBeforeConnect(String stompCommand, Boolean wi assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); + verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); } WebSocketHandler makeWebSocketHandler(AtomicBoolean receivedReply, AtomicBoolean receivedError) { @@ -127,6 +131,7 @@ void connectWithInvalidAuthorizationIsRejected() throws Throwable { assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); + verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); } /** @@ -161,6 +166,8 @@ void connectWithInvalidJwtAuthorizationIsRejected() throws Throwable { assertTrue(receivedError.get()); assertFalse(session.isOpen()); assertFalse(receivedReply.get()); + + verify(webSocketExceptionHandler).messageDeliveryException(notNull(), notNull()); } /** diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java index 65507b909..7e36c462f 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/VocabularySocketControllerTest.java @@ -7,6 +7,7 @@ import cz.cvut.kbss.termit.model.validation.ValidationResult; import cz.cvut.kbss.termit.service.IdentifierResolver; import cz.cvut.kbss.termit.service.business.VocabularyService; +import cz.cvut.kbss.termit.util.throttle.ThrottledFuture; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.boot.test.mock.mockito.MockBean; @@ -30,9 +31,6 @@ class VocabularySocketControllerTest extends BaseWebSocketControllerTestRunner { - @MockBean - IdentifierResolver idResolver; - @MockBean VocabularyService vocabularyService; @@ -52,8 +50,9 @@ public void setup() { vocabulary = Generator.generateVocabularyWithId(); fragment = IdentifierResolver.extractIdentifierFragment(vocabulary.getUri()).substring(1); namespace = vocabulary.getUri().toString().substring(0, vocabulary.getUri().toString().lastIndexOf('/')); - when(idResolver.resolveIdentifier(namespace, fragment)).thenReturn(vocabulary.getUri()); + when(identifierResolver.resolveIdentifier(namespace, fragment)).thenReturn(vocabulary.getUri()); when(vocabularyService.getReference(vocabulary.getUri())).thenReturn(vocabulary); + when(vocabularyService.validateContents(vocabulary.getUri())).thenReturn(ThrottledFuture.done(List.of())); messageHeaders = StompHeaderAccessor.create(StompCommand.MESSAGE); messageHeaders.setSessionId("0"); @@ -71,7 +70,7 @@ void validateVocabularyValidatesContents() { this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build()); - verify(vocabularyService).validateContents(vocabulary); + verify(vocabularyService).validateContents(vocabulary.getUri()); } @Test @@ -86,7 +85,7 @@ void validateVocabularyReturnsValidationResults() { .setSeverity(Generator.generateUri()) .setIssueCauseUri(Generator.generateUri()); final List validationResults = List.of(validationResult); - when(vocabularyService.validateContents(vocabulary)).thenReturn(validationResults); + when(vocabularyService.validateContents(vocabulary.getUri())).thenReturn(ThrottledFuture.done(validationResults)); this.serverInboundChannel.send(MessageBuilder.withPayload("").setHeaders(messageHeaders).build()); diff --git a/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java b/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java index 52df68045..a833a953b 100644 --- a/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java +++ b/src/test/java/cz/cvut/kbss/termit/websocket/WebSocketExceptionHandlerTest.java @@ -3,11 +3,9 @@ import cz.cvut.kbss.termit.environment.Environment; import cz.cvut.kbss.termit.environment.Generator; import cz.cvut.kbss.termit.exception.PersistenceException; -import cz.cvut.kbss.termit.websocket.handler.WebSocketExceptionHandler; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.boot.test.mock.mockito.MockBean; -import org.springframework.boot.test.mock.mockito.SpyBean; import org.springframework.messaging.simp.stomp.StompCommand; import org.springframework.messaging.simp.stomp.StompHeaderAccessor; import org.springframework.messaging.support.MessageBuilder; @@ -18,14 +16,11 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.notNull; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; class WebSocketExceptionHandlerTest extends BaseWebSocketControllerTestRunner { - @SpyBean - WebSocketExceptionHandler sut; - @MockBean VocabularySocketController controller; @@ -51,9 +46,9 @@ void sendMessage() { @Test void handlerIsCalledForPersistenceException() { final PersistenceException e = new PersistenceException(new Exception("mocked exception")); - when(controller.validateVocabulary(any(), any())).thenThrow(e); + doThrow(e).when(controller).validateVocabulary(any(), any(), any()); sendMessage(); - verify(sut).persistenceException(notNull(), eq(e)); + verify(webSocketExceptionHandler).persistenceException(notNull(), eq(e)); } }