diff --git a/datashare-app/src/main/java/org/icij/datashare/CliApp.java b/datashare-app/src/main/java/org/icij/datashare/CliApp.java index a522966ff..bf466649b 100644 --- a/datashare-app/src/main/java/org/icij/datashare/CliApp.java +++ b/datashare-app/src/main/java/org/icij/datashare/CliApp.java @@ -20,7 +20,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.nio.file.Path; import java.util.List; import java.util.Properties; @@ -112,37 +111,37 @@ private static void runTaskWorker(CommonMode mode, Properties properties) throws logger.info("executing {}", pipeline); if (pipeline.has(Stage.DEDUPLICATE)) { taskManager.startTask( - new Task<>(DeduplicateTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(DeduplicateTask.class.getName(), propertiesToMap(properties)), nullUser()); } if (pipeline.has(Stage.SCANIDX)) { taskManager.startTask( - new Task<>(ScanIndexTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(ScanIndexTask.class.getName(), propertiesToMap(properties)), nullUser()); } if (pipeline.has(Stage.SCAN)) { taskManager.startTask( - new Task<>(ScanTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(ScanTask.class.getName(), propertiesToMap(properties)), nullUser()); } if (pipeline.has(Stage.INDEX)) { taskManager.startTask( - new Task<>(IndexTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(IndexTask.class.getName(), propertiesToMap(properties)), nullUser()); } if (pipeline.has(Stage.ENQUEUEIDX)) { taskManager.startTask( - new Task<>(EnqueueFromIndexTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(EnqueueFromIndexTask.class.getName(), propertiesToMap(properties)), nullUser()); } if (pipeline.has(Stage.NLP)) { taskManager.startTask( - new Task<>(ExtractNlpTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(ExtractNlpTask.class.getName(), propertiesToMap(properties)), nullUser()); } if (pipeline.has(Stage.ARTIFACT)) { taskManager.startTask( - new Task<>(ArtifactTask.class.getName(), nullUser(), propertiesToMap(properties))); + new Task<>(ArtifactTask.class.getName(), propertiesToMap(properties)), nullUser()); } taskManager.shutdownAndAwaitTermination(Integer.MAX_VALUE, SECONDS); indexer.close(); diff --git a/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java b/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java index b50568802..25c999e94 100644 --- a/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java +++ b/datashare-app/src/main/java/org/icij/datashare/mode/CommonMode.java @@ -199,7 +199,7 @@ private void configureBatchQueuesMemory() { } private void configureBatchQueuesRedis(RedissonClient redissonClient) { - bind(new TypeLiteral>>(){}).toInstance(new RedisBlockingQueue<>(redissonClient, DS_TASKS_QUEUE_NAME, new org.icij.datashare.asynctasks.TaskManagerRedis.TaskViewCodec())); + bind(new TypeLiteral>>(){}).toInstance(new RedisBlockingQueue<>(redissonClient, DS_TASKS_QUEUE_NAME, new org.icij.datashare.asynctasks.TaskManagerRedis.RedisCodec<>(Task.class))); } public Properties properties() { diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java index aaa8fe8aa..f6482d417 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ArtifactTask.java @@ -7,11 +7,11 @@ import org.icij.datashare.asynctasks.Task; import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.extract.DocumentCollectionFactory; -import org.icij.datashare.function.Pair; import org.icij.datashare.text.Document; import org.icij.datashare.text.Project; import org.icij.datashare.text.indexing.Indexer; import org.icij.datashare.text.indexing.elasticsearch.SourceExtractor; +import org.icij.datashare.user.User; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -20,7 +20,6 @@ import java.util.concurrent.TimeUnit; import java.util.function.Function; -import static java.util.Optional.ofNullable; import static org.icij.datashare.cli.DatashareCliOptions.ARTIFACT_DIR_OPT; import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_DEFAULT_PROJECT; import static org.icij.datashare.cli.DatashareCliOptions.DEFAULT_PROJECT_OPT; @@ -33,8 +32,8 @@ public class ArtifactTask extends PipelineTask { private final Path artifactDir; @Inject - public ArtifactTask(DocumentCollectionFactory factory, Indexer indexer, PropertiesProvider propertiesProvider, @Assisted Task taskView, @Assisted final Function updateCallback) { - super(Stage.ARTIFACT, taskView.getUser(), factory, propertiesProvider, String.class); + public ArtifactTask(DocumentCollectionFactory factory, Indexer indexer, PropertiesProvider propertiesProvider, @Assisted Task taskView, @Assisted final User user, @Assisted final Function updateCallback) { + super(Stage.ARTIFACT, user, factory, propertiesProvider, String.class); this.indexer = indexer; project = Project.project(propertiesProvider.get(DEFAULT_PROJECT_OPT).orElse(DEFAULT_DEFAULT_PROJECT)); artifactDir = Path.of(propertiesProvider.get(ARTIFACT_DIR_OPT).orElseThrow(() -> new IllegalArgumentException(String.format("cannot create artifact task with empty %s", ARTIFACT_DIR_OPT)))); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java index 709506632..5fed628c0 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/BatchSearchRunner.java @@ -66,22 +66,24 @@ public class BatchSearchRunner implements CancellableTask, UserTask, Callable taskView; + private final User user; protected volatile boolean cancelAsked = false; protected volatile Thread callThread; protected volatile boolean requeueCancel; @Inject public BatchSearchRunner(Indexer indexer, PropertiesProvider propertiesProvider, BatchSearchRepository repository, - @Assisted Task taskView, @Assisted Function updateCallback) { - this(indexer, propertiesProvider, repository, taskView, updateCallback, new CountDownLatch(1)); + @Assisted Task taskView, @Assisted final User user, @Assisted Function updateCallback) { + this(indexer, propertiesProvider, repository, taskView, user, updateCallback, new CountDownLatch(1)); } BatchSearchRunner(Indexer indexer, PropertiesProvider propertiesProvider, BatchSearchRepository repository, - Task taskView, Function updateCallback, CountDownLatch latch) { + Task taskView, User user, Function updateCallback, CountDownLatch latch) { this.indexer = indexer; this.propertiesProvider = propertiesProvider; this.repository = repository; this.taskView = (Task) taskView; + this.user = user; this.updateCallback = updateCallback; this.callWaiterLatch = latch; } @@ -100,7 +102,7 @@ public Integer call() throws Exception { int scrollSize = min(scrollSizeFromParams, MAX_SCROLL_SIZE); callThread = Thread.currentThread(); callWaiterLatch.countDown(); // for tests - BatchSearch batchSearch = repository.get(taskView.getUser(), taskView.id); + BatchSearch batchSearch = repository.get(user, taskView.id); if (batchSearch == null) { logger.warn("batch search {} not found in database (check that database url is the same as datashare backend)", taskView.id); return 0; @@ -164,7 +166,7 @@ public Integer call() throws Exception { @Override public User getUser() { - return taskView.getUser(); + return user; } /** diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java b/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java index d59343609..ac72c3692 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/DatashareTaskFactory.java @@ -7,16 +7,16 @@ public interface DatashareTaskFactory extends org.icij.datashare.asynctasks.TaskFactory { - BatchSearchRunner createBatchSearchRunner(Task taskView, Function updateCallback); - BatchDownloadRunner createBatchDownloadRunner(Task taskView, Function updateCallback); + BatchSearchRunner createBatchSearchRunner(Task taskView, User user, Function updateCallback); + BatchDownloadRunner createBatchDownloadRunner(Task taskView, User user, Function updateCallback); - ScanTask createScanTask(Task taskView, Function updateCallback); - IndexTask createIndexTask(Task taskView, Function updateCallback); - ScanIndexTask createScanIndexTask(Task taskView, Function updateCallback); - ExtractNlpTask createExtractNlpTask(Task taskView, Function updateCallback); - EnqueueFromIndexTask createEnqueueFromIndexTask(Task taskView, Function updateCallback); - DeduplicateTask createDeduplicateTask(Task taskView, Function updateCallback); - ArtifactTask createArtifactTask(Task taskView, Function updateCallback); + ScanTask createScanTask(Task taskView, User user, Function updateCallback); + IndexTask createIndexTask(Task taskView, User user, Function updateCallback); + ScanIndexTask createScanIndexTask(Task taskView, User user, Function updateCallback); + ExtractNlpTask createExtractNlpTask(Task taskView, User user, Function updateCallback); + EnqueueFromIndexTask createEnqueueFromIndexTask(Task taskView, User user, Function updateCallback); + DeduplicateTask createDeduplicateTask(Task taskView, User user, Function updateCallback); + ArtifactTask createArtifactTask(Task taskView, User user, Function updateCallback); GenApiKeyTask createGenApiKey(User user); DelApiKeyTask createDelApiKey(User user); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java index 481dfaaef..e2be22e24 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/DeduplicateTask.java @@ -7,6 +7,7 @@ import org.icij.datashare.asynctasks.Task; import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.extract.DocumentCollectionFactory; +import org.icij.datashare.user.User; import org.icij.extract.queue.DocumentQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -24,8 +25,8 @@ public class DeduplicateTask extends PipelineTask { private final DocumentCollectionFactory factory; @Inject - public DeduplicateTask(final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function updateCallback) { - super(Stage.DEDUPLICATE, taskView.getUser(), factory, new PropertiesProvider(taskView.args), Path.class); + public DeduplicateTask(final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final User user, @Assisted final Function updateCallback) { + super(Stage.DEDUPLICATE, user, factory, new PropertiesProvider(taskView.args), Path.class); this.factory = factory; } diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java index 7f41538ff..2810d4ac6 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/EnqueueFromIndexTask.java @@ -3,7 +3,6 @@ import com.google.inject.Inject; import com.google.inject.assistedinject.Assisted; -import java.util.Optional; import java.util.function.Function; import org.icij.datashare.Entity; import org.icij.datashare.PropertiesProvider; @@ -12,10 +11,10 @@ import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.extract.DocumentCollectionFactory; import org.icij.datashare.text.Document; -import org.icij.datashare.text.ProjectProxy; import org.icij.datashare.text.indexing.Indexer; import org.icij.datashare.text.indexing.SearchQuery; import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.user.User; import org.icij.extract.queue.DocumentQueue; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -47,8 +46,8 @@ public class EnqueueFromIndexTask extends PipelineTask { @Inject public EnqueueFromIndexTask(final DocumentCollectionFactory factory, final Indexer indexer, - @Assisted Task taskView, @Assisted final Function updateCallback) { - super(Stage.ENQUEUEIDX, taskView.getUser(), factory, new PropertiesProvider(taskView.args), String.class); + @Assisted Task taskView, @Assisted final User user, @Assisted final Function updateCallback) { + super(Stage.ENQUEUEIDX, user, factory, new PropertiesProvider(taskView.args), String.class); this.factory = factory; this.indexer = indexer; this.nlpPipeline = Pipeline.Type.parse((String) taskView.args.getOrDefault(NLP_PIPELINE_OPT, Pipeline.Type.CORENLP.name())); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java index a182a163e..85bc641c7 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ExtractNlpTask.java @@ -16,6 +16,7 @@ import org.icij.datashare.text.Project; import org.icij.datashare.text.indexing.Indexer; import org.icij.datashare.text.nlp.Pipeline; +import org.icij.datashare.user.User; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,13 +42,14 @@ public class ExtractNlpTask extends PipelineTask implements Monitorable private final int maxContentLengthChars; @Inject - public ExtractNlpTask(Indexer indexer, PipelineRegistry registry, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function updateCallback) { - this(indexer, registry.get(Pipeline.Type.parse((String)taskView.args.get(NLP_PIPELINE_OPT))), factory, taskView, updateCallback); + public ExtractNlpTask(Indexer indexer, PipelineRegistry registry, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final User user, @Assisted final Function updateCallback) { + this(indexer, registry.get(Pipeline.Type.parse((String)taskView.args.get(NLP_PIPELINE_OPT))), factory, taskView, user, updateCallback); } - ExtractNlpTask(Indexer indexer, Pipeline pipeline, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function updateCallback) { - super(Stage.NLP, taskView.getUser(), factory, new PropertiesProvider(taskView.args), String.class); + ExtractNlpTask(Indexer indexer, Pipeline pipeline, final DocumentCollectionFactory factory, @Assisted Task taskView, + @Assisted User user, @Assisted final Function updateCallback) { + super(Stage.NLP, user, factory, new PropertiesProvider(taskView.args), String.class); this.nlpPipeline = pipeline; project = Project.project(ofNullable((String)taskView.args.get(DEFAULT_PROJECT_OPT)).orElse(DEFAULT_DEFAULT_PROJECT)); maxContentLengthChars = (int) HumanReadableSize.parse(ofNullable((String)taskView.args.get(MAX_CONTENT_LENGTH_OPT)).orElse(valueOf(DEFAULT_MAX_CONTENT_LENGTH))); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java index d9d6c7980..112d79673 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/IndexTask.java @@ -10,6 +10,7 @@ import org.icij.datashare.extract.DocumentCollectionFactory; import org.icij.datashare.monitoring.Monitorable; import org.icij.datashare.text.indexing.elasticsearch.ElasticsearchSpewer; +import org.icij.datashare.user.User; import org.icij.extract.document.DocumentFactory; import org.icij.extract.extractor.DocumentConsumer; import org.icij.extract.extractor.Extractor; @@ -43,8 +44,9 @@ public class IndexTask extends PipelineTask implements Monitorable{ private final Integer parallelism; @Inject - public IndexTask(final ElasticsearchSpewer spewer, final DocumentCollectionFactory factory, @Assisted Task taskView, @Assisted final Function updateCallback) throws IOException { - super(Stage.INDEX, taskView.getUser(), factory, new PropertiesProvider(taskView.args), Path.class); + public IndexTask(final ElasticsearchSpewer spewer, final DocumentCollectionFactory factory, @Assisted Task taskView, + @Assisted final User user, @Assisted final Function updateCallback) throws IOException { + super(Stage.INDEX, user, factory, new PropertiesProvider(taskView.args), Path.class); parallelism = propertiesProvider.get(PARALLELISM_OPT).map(Integer::parseInt).orElse(Runtime.getRuntime().availableProcessors()); Options allTaskOptions = options().createFrom(Options.from(taskView.args)); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java index 08e87dc75..bc50c9357 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanIndexTask.java @@ -11,6 +11,7 @@ import org.icij.datashare.extract.DocumentCollectionFactory; import org.icij.datashare.text.Document; import org.icij.datashare.text.indexing.Indexer; +import org.icij.datashare.user.User; import org.icij.extract.extractor.ExtractionStatus; import org.icij.extract.report.Report; import org.icij.extract.report.ReportMap; @@ -53,8 +54,8 @@ public class ScanIndexTask extends PipelineTask { @Inject public ScanIndexTask(DocumentCollectionFactory factory, final Indexer indexer, - @Assisted Task taskView, @Assisted Function updateCallback) { - super(Stage.SCANIDX, taskView.getUser(), factory, new PropertiesProvider(taskView.args), Path.class); + @Assisted Task taskView, @Assisted final User user, @Assisted Function updateCallback) { + super(Stage.SCANIDX, user, factory, new PropertiesProvider(taskView.args), Path.class); this.scrollDuration = propertiesProvider.get(SCROLL_DURATION_OPT).orElse(DEFAULT_SCROLL_DURATION); this.scrollSize = parseInt(propertiesProvider.get(SCROLL_SIZE_OPT).orElse(valueOf(DEFAULT_SCROLL_SIZE))); this.scrollSlices = parseInt(propertiesProvider.get(SCROLL_SLICES_OPT).orElse(valueOf(DEFAULT_SCROLL_SLICES))); diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java index ed76db18e..719442e70 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/ScanTask.java @@ -9,6 +9,7 @@ import org.icij.datashare.asynctasks.TaskGroup; import org.icij.datashare.cli.DatashareCliOptions; import org.icij.datashare.extract.DocumentCollectionFactory; +import org.icij.datashare.user.User; import org.icij.extract.Scanner; import org.icij.extract.ScannerVisitor; import org.icij.task.Options; @@ -24,8 +25,8 @@ public class ScanTask extends PipelineTask { private final Path path; @Inject - public ScanTask(DocumentCollectionFactory factory, @Assisted Task task, @Assisted Function updateCallback) { - super(Stage.SCAN, task.getUser(), factory, new PropertiesProvider(task.args), Path.class); + public ScanTask(DocumentCollectionFactory factory, @Assisted Task task, @Assisted final User user, @Assisted Function updateCallback) { + super(Stage.SCAN, user, factory, new PropertiesProvider(task.args), Path.class); scanner = new Scanner(outputQueue).configure(options().createFrom(Options.from(task.args))); path = Paths.get((String)task.args.get(DatashareCliOptions.DATA_DIR_OPT)); } diff --git a/datashare-app/src/main/java/org/icij/datashare/tasks/TaskManagerAmqp.java b/datashare-app/src/main/java/org/icij/datashare/tasks/TaskManagerAmqp.java index b85eed1ba..48e1f1f0a 100644 --- a/datashare-app/src/main/java/org/icij/datashare/tasks/TaskManagerAmqp.java +++ b/datashare-app/src/main/java/org/icij/datashare/tasks/TaskManagerAmqp.java @@ -6,11 +6,8 @@ import org.icij.datashare.PropertiesProvider; import org.icij.datashare.asynctasks.TaskManagerRedis; -import org.icij.datashare.asynctasks.Task; import org.icij.datashare.asynctasks.bus.amqp.AmqpInterlocutor; -import org.icij.datashare.cli.DatashareCliOptions; import org.icij.datashare.mode.CommonMode; -import org.jetbrains.annotations.NotNull; import org.redisson.Redisson; import org.redisson.RedissonMap; import org.redisson.api.RedissonClient; @@ -32,8 +29,8 @@ public TaskManagerAmqp(AmqpInterlocutor amqp, RedissonClient redissonClient, Pro super(amqp, createTaskQueue(redissonClient), Utils.getRoutingStrategy(propertiesProvider), eventCallback); } - private static RedissonMap> createTaskQueue(RedissonClient redissonClient) { - return new RedissonMap<>(new TaskManagerRedis.TaskViewCodec(), + private static RedissonMap> createTaskQueue(RedissonClient redissonClient) { + return new RedissonMap<>(new TaskManagerRedis.RedisCodec<>(TaskMetadata.class), new CommandSyncService(((Redisson) redissonClient).getConnectionManager(), new RedissonObjectBuilder(redissonClient)), CommonMode.DS_TASK_MANAGER_QUEUE_NAME, diff --git a/datashare-app/src/main/java/org/icij/datashare/web/TaskResource.java b/datashare-app/src/main/java/org/icij/datashare/web/TaskResource.java index 753a3877e..1c232d56c 100644 --- a/datashare-app/src/main/java/org/icij/datashare/web/TaskResource.java +++ b/datashare-app/src/main/java/org/icij/datashare/web/TaskResource.java @@ -197,8 +197,8 @@ public Payload indexFile(@Parameter(name = "filePath", description = "path of th // TODO remove taskFactory.createScanIndexTask would allow to get rid of taskfactory dependency in taskresource // problem for now is that if we call taskManager.startTask(ScanIndexTask.class.getName(), user, propertiesToMap(properties)) // the task will be run as a background task that will have race conditions with indexTask report loading - scanIndex = new Task<>(ScanIndexTask.class.getName(), user, propertiesToMap(properties)); - taskFactory.createScanIndexTask(scanIndex, (p) -> null).call(); + scanIndex = new Task<>(ScanIndexTask.class.getName(), propertiesToMap(properties)); + taskFactory.createScanIndexTask(scanIndex, user, (p) -> null).call(); taskIds.add(scanIndex.id); } else { properties.remove(MAP_NAME_OPTION); // avoid use of reportMap to override ES docs @@ -317,8 +317,8 @@ public static String getReportMapNameFor(Properties properties) { return "extract:report:" + projectName; } - private static Task forbiddenIfNotSameUser(Context context, Task task) { - if (!task.getUser().equals(context.currentUser())) throw new ForbiddenException(); + private Task forbiddenIfNotSameUser(Context context, Task task) { + if (!taskManager.getTaskUser(task.id).equals(context.currentUser())) throw new ForbiddenException(); return task; } diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java index 3ff984cf2..e907d2040 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/Task.java @@ -10,13 +10,10 @@ import org.icij.datashare.asynctasks.bus.amqp.Event; import org.icij.datashare.asynctasks.bus.amqp.TaskError; import org.icij.datashare.asynctasks.bus.amqp.UriResult; -import org.icij.datashare.user.User; -import java.io.Serial; import java.io.Serializable; import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashMap; import java.util.Map; import java.util.Objects; import java.util.concurrent.Callable; @@ -29,12 +26,10 @@ @JsonInclude(JsonInclude.Include.NON_NULL) public class Task extends Event implements Entity { - public static final String USER_KEY = "user"; - public static final String GROUP_KEY = "group"; @JsonIgnore private StateLatch stateLatch; @JsonIgnore private final Object lock = new Object(); - public enum State {CREATED, QUEUED, RUNNING, CANCELLED, ERROR, DONE;} + public enum State {CREATED, QUEUED, RUNNING, CANCELLED, ERROR, DONE} @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, property = "@type") public final Map args; public final String id; @@ -50,20 +45,16 @@ public enum State {CREATED, QUEUED, RUNNING, CANCELLED, ERROR, DONE;} @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "@type") private volatile V result; - public Task(String name, User user, Map args) { - this(randomUUID().toString(), name, user, args); + public Task(String name, Map args) { + this(randomUUID().toString(), name, State.CREATED, 0, null, args); } - public Task(String name, User user, Group group, Map args) { - this(randomUUID().toString(), name, State.CREATED, 0, null, addTo(args, user, group)); + public Task(String id, String name) { + this(id, name, new HashMap<>()); } - public Task(String id, String name, User user, Group group) { - this(id, name, user,addTo(new HashMap<>(), user, group)); - } - - public Task(String id, String name, User user, Map args) { - this(id, name, State.CREATED, 0, null, addTo(args, user)); + public Task(String id, String name, Map args) { + this(id, name, State.CREATED, 0, null, args); } @JsonCreator @@ -179,16 +170,6 @@ public boolean isNull() { return id == null; } - @JsonIgnore - public User getUser() { - return (User) args.get(USER_KEY); - } - - @JsonIgnore - public Group getGroup() { - return (Group) args.get(GROUP_KEY); - } - public static String getId(Callable task) { return task.toString(); } @@ -210,17 +191,4 @@ private void setState(State state) { this.state = state; ofNullable(stateLatch).ifPresent(sl -> sl.setTaskState(state)); } - - private static Map addTo(Map properties, User user) { - LinkedHashMap result = new LinkedHashMap<>(properties); - result.put(USER_KEY, user); - return result; - } - - private static Map addTo(Map properties, User user, Group group) { - LinkedHashMap result = new LinkedHashMap<>(properties); - result.put(USER_KEY, user); - result.put(GROUP_KEY, group); - return result; - } } \ No newline at end of file diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskAlreadyExists.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskAlreadyExists.java new file mode 100644 index 000000000..b7aa671c7 --- /dev/null +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskAlreadyExists.java @@ -0,0 +1,9 @@ +package org.icij.datashare.asynctasks; + +public class TaskAlreadyExists extends Exception { + final String taskId; + + public TaskAlreadyExists(String taskId) { + this.taskId = taskId; + } +} diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManager.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManager.java index 839cd0a8c..b139bea85 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManager.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManager.java @@ -1,5 +1,7 @@ package org.icij.datashare.asynctasks; +import static java.util.stream.Collectors.toMap; + import org.icij.datashare.asynctasks.bus.amqp.CancelledEvent; import org.icij.datashare.asynctasks.bus.amqp.ErrorEvent; import org.icij.datashare.asynctasks.bus.amqp.ProgressEvent; @@ -18,34 +20,44 @@ import java.util.regex.Pattern; import java.util.stream.Stream; -import static java.util.stream.Collectors.toList; -import static java.util.stream.Collectors.toMap; public interface TaskManager extends Closeable { Logger logger = LoggerFactory.getLogger(TaskManager.class); + record TaskMetadata(Task task, User user, Group group) { + String taskId() { + return task.id; + } + TaskMetadata withTask(Task task) { + return new TaskMetadata<>(task, this.user, this.group); + } + } + boolean stopTask(String taskId) throws IOException; Task clearTask(String taskId) throws IOException; boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throws InterruptedException, IOException; Task getTask(String taskId) throws IOException; List> getTasks() throws IOException; List> getTasks(User user, Pattern pattern) throws IOException; + User getTaskUser(String taskId); + Group getTaskGroup(String taskId); List> clearDoneTasks() throws IOException; void clear() throws IOException; - boolean save(Task task) throws IOException; + void saveMetadata(TaskMetadata taskMetadata) throws IOException, TaskAlreadyExists; + void persistUpdate(Task task) throws IOException, UnknownTask; void enqueue(Task task) throws IOException; - static List> getTasks(Stream> stream, User user, Pattern pattern) { - return stream. - filter(t -> user.equals(t.getUser())). - filter(t -> pattern.matcher(t.name).matches()). - collect(toList()); + default List> getTasks(Stream> stream, User user, Pattern pattern) { + return stream + .filter(t -> user.equals(getTaskUser(t.id))) + .filter(t -> pattern.matcher(t.name).matches()) + .toList(); } default Map stopAllTasks(User user) throws IOException { - return getTasks().stream(). - filter(t -> user.equals(t.getUser())). - filter(t -> t.getState() == Task.State.RUNNING || t.getState() == Task.State.QUEUED).collect( + return getTasks().stream() + .filter(t -> user.equals(getTaskUser(t.id))) + .filter(t -> t.getState() == Task.State.RUNNING || t.getState() == Task.State.QUEUED).collect( toMap(t -> t.id, t -> { try { return stopTask(t.id); @@ -59,22 +71,22 @@ default Map stopAllTasks(User user) throws IOException { // for tests default String startTask(String taskName, User user, Map properties) throws IOException { - return startTask(new Task<>(taskName, user, properties)); + return startTask(new Task<>(taskName, properties), user, null); } // TaskResource and pipeline tasks default String startTask(Class taskClass, User user, Map properties) throws IOException { - return startTask(new Task<>(taskClass.getName(), user, new Group(taskClass.getAnnotation(TaskGroup.class).value()), properties)); + return startTask(new Task<>(taskClass.getName(), properties), user, new Group(taskClass.getAnnotation(TaskGroup.class).value())); } // for tests default String startTask(String taskName, User user, Group group, Map properties) throws IOException { - return startTask(new Task<>(taskName, user, group, properties)); + return startTask(new Task<>(taskName, properties), user, group); } // BatchSearchResource and WebApp for batch searches default String startTask(String id, Class taskClass, User user) throws IOException { - return startTask(new Task<>(id, taskClass.getName(), user, new Group(taskClass.getAnnotation(TaskGroup.class).value()))); + return startTask(new Task<>(id, taskClass.getName()), user, new Group(taskClass.getAnnotation(TaskGroup.class).value())); } /** @@ -82,16 +94,44 @@ default String startTask(String id, Class taskClass, User user) throws IOExc * it in the memory/redis/AMQP queue and return the id. Else it will not enqueue the task and return null. * * @param taskView: the task description. + * @param user: task user + * @param group: task group * @return task id if it was new and has been saved else null * @throws IOException in case of communication failure with Redis or AMQP broker */ + default String startTask(Task taskView, User user, Group group) throws IOException { + try { + save(taskView, user, group); + } catch (TaskAlreadyExists ignored) { + throw new RuntimeException("task with id " + taskView.id + " was already save !"); + } + taskView.queue(); + enqueue(taskView); + return taskView.id; + } + + default String startTask(Task taskView, User user) throws IOException { + return startTask(taskView, user, null); + } + default String startTask(Task taskView) throws IOException { - boolean saved = save(taskView); - if (saved) { - taskView.queue(); - enqueue(taskView); + return startTask(taskView, null, null); + } + + default String startTask(Task taskView, Group group) throws IOException { + return startTask(taskView, null, group); + } + + default void save(Task task, User user, Group group) throws IOException, TaskAlreadyExists { + saveMetadata(new TaskMetadata<>(task, user, group)); + } + + default void update(Task task) throws IOException { + try { + persistUpdate(task); + } catch (UnknownTask e) { + throw new RuntimeException("task " + task.id + " is unknown, save it first !"); } - return saved ? taskView.id: null; } default Task setResult(ResultEvent e) throws IOException { @@ -99,7 +139,7 @@ default Task setResult(ResultEvent e) throws IOEx if (taskView != null) { logger.info("result event for {}", e.taskId); taskView.setResult(e.result); - save(taskView); + update(taskView); } else { logger.warn("no task found for result event {}", e.taskId); } @@ -111,7 +151,7 @@ default Task setError(ErrorEvent e) throws IOExcepti if (taskView != null) { logger.info("error event for {}", e.taskId); taskView.setError(e.error); - save(taskView); + update(taskView); } else { logger.warn("no task found for error event {}", e.taskId); } @@ -123,7 +163,7 @@ default Task setCanceled(CancelledEvent e) throws IOException { if (taskView != null) { logger.info("canceled event for {}", e.taskId); taskView.cancel(); - save(taskView); + update(taskView); if (e.requeue) { try { enqueue(taskView); @@ -142,7 +182,7 @@ default Task setProgress(ProgressEvent e) throws IOException { Task taskView = getTask(e.taskId); if (taskView != null) { taskView.setProgress(e.progress); - save(taskView); + update(taskView); } return taskView; } diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerAmqp.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerAmqp.java index 8462ba95d..90acf5341 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerAmqp.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerAmqp.java @@ -1,5 +1,6 @@ package org.icij.datashare.asynctasks; +import java.util.Optional; import org.icij.datashare.asynctasks.bus.amqp.AmqpConsumer; import org.icij.datashare.asynctasks.bus.amqp.AmqpInterlocutor; import org.icij.datashare.asynctasks.bus.amqp.AmqpQueue; @@ -11,7 +12,6 @@ import org.icij.datashare.user.User; import java.io.IOException; -import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -22,22 +22,22 @@ import static java.util.stream.Collectors.toList; public class TaskManagerAmqp implements TaskManager { - private final Map> tasks; + private final Map> taskMetas; private final RoutingStrategy routingStrategy; private final AmqpInterlocutor amqp; private final AmqpConsumer> eventConsumer; - public TaskManagerAmqp(AmqpInterlocutor amqp, Map> tasks) throws IOException { - this(amqp, tasks, RoutingStrategy.UNIQUE); + public TaskManagerAmqp(AmqpInterlocutor amqp, Map> taskMetas) throws IOException { + this(amqp, taskMetas, RoutingStrategy.UNIQUE); } - public TaskManagerAmqp(AmqpInterlocutor amqp, Map> tasks, RoutingStrategy routingStrategy) throws IOException { - this(amqp, tasks, routingStrategy, null); + public TaskManagerAmqp(AmqpInterlocutor amqp, Map> taskMetas, RoutingStrategy routingStrategy) throws IOException { + this(amqp, taskMetas, routingStrategy, null); } - public TaskManagerAmqp(AmqpInterlocutor amqp, Map> tasks, RoutingStrategy routingStrategy, Runnable eventCallback) throws IOException { + public TaskManagerAmqp(AmqpInterlocutor amqp, Map> taskMetas, RoutingStrategy routingStrategy, Runnable eventCallback) throws IOException { this.amqp = amqp; - this.tasks = tasks; + this.taskMetas = taskMetas; this.routingStrategy = routingStrategy; eventConsumer = new AmqpConsumer<>(amqp, event -> ofNullable(TaskManager.super.handleAck(event)).flatMap(t -> @@ -46,7 +46,7 @@ public TaskManagerAmqp(AmqpInterlocutor amqp, Map> tasks, Routin @Override public boolean stopTask(String taskId) { - Task taskView = tasks.get(taskId); + Task taskView = this.getTask(taskId); if (taskView != null) { try { logger.info("sending cancel event for {}", taskId); @@ -63,11 +63,11 @@ public boolean stopTask(String taskId) { @Override public Task clearTask(String taskId) { - if (tasks.get(taskId).getState() == Task.State.RUNNING) { + if (this.getTask(taskId).getState() == Task.State.RUNNING) { throw new IllegalStateException(String.format("task id <%s> is already in RUNNING state", taskId)); } logger.info("deleting task id <{}>", taskId); - return (Task) tasks.remove(taskId); + return (Task) taskMetas.remove(taskId).task(); } @Override @@ -76,15 +76,29 @@ public boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throw return true; } - public boolean save(Task task) { - Task oldVal = tasks.put(task.id, task); - return oldVal == null; + @Override + public void saveMetadata(TaskMetadata taskMetadata) throws TaskAlreadyExists { + String taskId = taskMetadata.taskId(); + if (taskMetas.containsKey(taskId)) { + throw new TaskAlreadyExists(taskId); + } + this.taskMetas.put(taskId, taskMetadata); + } + + @Override + public void persistUpdate(Task task) throws UnknownTask { + TaskMetadata updated = (TaskMetadata) taskMetas.get(task.id); + if (updated == null) { + throw new UnknownTask(task.id); + } + updated = updated.withTask(task); + this.taskMetas.put(task.id, updated); } @Override public void enqueue(Task task) throws IOException { switch (routingStrategy) { - case GROUP -> amqp.publish(AmqpQueue.TASK, task.getGroup().id(), task); + case GROUP -> amqp.publish(AmqpQueue.TASK, this.taskMetas.get(task.id).group().id(), task); case NAME -> amqp.publish(AmqpQueue.TASK, task.name, task); default -> amqp.publish(AmqpQueue.TASK, task); } @@ -92,22 +106,33 @@ public void enqueue(Task task) throws IOException { @Override public Task getTask(String taskId) { - return (Task) tasks.get(taskId); + return (Task) Optional.ofNullable(taskMetas.get(taskId)).map(TaskMetadata::task).orElse(null); } @Override public List> getTasks() { - return new LinkedList<>(tasks.values()); + return taskMetas.values().stream().map(TaskMetadata::task).collect(toList()); + } + + @Override + public List> getTasks(User user, Pattern pattern) throws IOException { + return this.getTasks(taskMetas.values().stream().map(TaskMetadata::task), user, pattern); + } + + @Override + public User getTaskUser(String taskId) { + return taskMetas.get(taskId).user(); } @Override - public List> getTasks(User user, Pattern pattern) { - return TaskManager.getTasks(tasks.values().stream(), user, pattern); + public Group getTaskGroup(String taskId) { + return taskMetas.get(taskId).group(); } @Override public List> clearDoneTasks() { - return tasks.values().stream().filter(f -> f.getState() != Task.State.RUNNING).map(t -> tasks.remove(t.id)).collect(toList()); + return taskMetas.values().stream().map(TaskMetadata::task).filter(Task::isFinished) + .map(t -> taskMetas.remove(t.id).task()).collect(toList()); } public void close() throws IOException { @@ -117,6 +142,6 @@ public void close() throws IOException { @Override public void clear() { - tasks.clear(); + taskMetas.clear(); } } diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerMemory.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerMemory.java index 18cc09758..1614efbf0 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerMemory.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerMemory.java @@ -1,17 +1,16 @@ package org.icij.datashare.asynctasks; +import java.util.Optional; import org.apache.commons.lang3.NotImplementedException; import org.icij.datashare.PropertiesProvider; import org.icij.datashare.asynctasks.bus.amqp.Event; import org.icij.datashare.asynctasks.bus.amqp.TaskError; -import org.icij.datashare.asynctasks.bus.amqp.TaskEvent; import org.icij.datashare.user.User; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.Serializable; -import java.util.LinkedList; import java.util.List; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; @@ -22,13 +21,12 @@ import static java.lang.Integer.parseInt; import static java.util.stream.Collectors.toList; -import static org.icij.datashare.asynctasks.Task.State.RUNNING; public class TaskManagerMemory implements TaskManager, TaskSupplier { private final Logger logger = LoggerFactory.getLogger(getClass()); private final ExecutorService executor; - private final ConcurrentMap> tasks = new ConcurrentHashMap<>(); + private final ConcurrentMap> taskMetas = new ConcurrentHashMap<>(); private final BlockingQueue> taskQueue; private final List loops; private final AtomicInteger executedTasks = new AtomicInteger(0); @@ -47,22 +45,32 @@ public TaskManagerMemory(BlockingQueue> taskQueue, TaskFactory taskFacto } public Task getTask(final String taskId) { - return (Task) tasks.get(taskId); + return (Task) Optional.ofNullable(taskMetas.get(taskId)).map(TaskMetadata::task).orElse(null); } @Override public List> getTasks() { - return new LinkedList<>(tasks.values()); + return taskMetas.values().stream().map(TaskMetadata::task).collect(toList()); } @Override - public List> getTasks(User user, Pattern pattern) { - return TaskManager.getTasks(tasks.values().stream(), user, pattern); + public List> getTasks(User user, Pattern pattern) throws IOException { + return this.getTasks(taskMetas.values().stream().map(TaskMetadata::task), user, pattern); + } + + @Override + public User getTaskUser(String taskId) { + return taskMetas.get(taskId).user(); + } + + @Override + public Group getTaskGroup(String taskId) { + return taskMetas.get(taskId).group(); } @Override public Void progress(String taskId, double rate) { - Task taskView = tasks.get(taskId); + Task taskView = getTask(taskId); if (taskView != null) { taskView.setProgress(rate); } else { @@ -73,7 +81,7 @@ public Void progress(String taskId, double rate) { @Override public void result(String taskId, V result) { - Task taskView = (Task) tasks.get(taskId); + Task taskView = getTask(taskId); if (taskView != null) { taskView.setResult(result); executedTasks.incrementAndGet(); @@ -84,7 +92,7 @@ public void result(String taskId, V result) { @Override public void canceled(Task task, boolean requeue) { - Task taskView = tasks.get(task.id); + Task taskView = taskMetas.get(task.id).task(); if (taskView != null) { taskView.cancel(); if (requeue) { @@ -95,7 +103,7 @@ public void canceled(Task task, boolean requeue) { @Override public void error(String taskId, TaskError reason) { - Task taskView = tasks.get(taskId); + Task taskView = taskMetas.get(taskId).task(); if (taskView != null) { taskView.setError(reason); executedTasks.incrementAndGet(); @@ -104,9 +112,23 @@ public void error(String taskId, TaskError reason) { } } - public boolean save(Task taskView) { - Task oldTask = tasks.put(taskView.id, taskView); - return oldTask == null; + @Override + public void saveMetadata(TaskMetadata taskMetadata) throws TaskAlreadyExists { + String taskId = taskMetadata.taskId(); + if (taskMetas.containsKey(taskId)) { + throw new TaskAlreadyExists(taskId); + } + this.taskMetas.put(taskId, taskMetadata); + } + + @Override + public void persistUpdate(Task task) throws UnknownTask { + TaskMetadata updated = (TaskMetadata) taskMetas.get(task.id); + if (updated == null) { + throw new UnknownTask(task.id); + } + updated = updated.withTask(task); + this.taskMetas.put(task.id, updated); } @Override @@ -121,30 +143,31 @@ public boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) throw } public List> waitTasksToBeDone(int timeout, TimeUnit timeUnit) { - return tasks.values().stream().peek(taskView -> { + return taskMetas.values().stream().peek(m -> { try { - taskView.getResult(timeout, timeUnit); + m.task().getResult(timeout, timeUnit); } catch (InterruptedException | CancellationException e) { logger.error("task interrupted while running", e); } - }).collect(toList()); + }).map(TaskMetadata::task).collect(toList()); } public List> clearDoneTasks() { - return tasks.values().stream().filter(taskView -> taskView.getState() != RUNNING).map(t -> tasks.remove(t.id)).collect(toList()); + return taskMetas.values().stream().map(TaskMetadata::task).filter(Task::isFinished) + .map(t -> taskMetas.remove(t.id).task()).collect(toList()); } @Override - public Task clearTask(String taskName) { - if (tasks.get(taskName).getState() == Task.State.RUNNING) { - throw new IllegalStateException(String.format("task id <%s> is already in RUNNING state", taskName)); + public Task clearTask(String taskId) { + if (getTask(taskId).getState() == Task.State.RUNNING) { + throw new IllegalStateException(String.format("task id <%s> is already in RUNNING state", taskId)); } - logger.info("deleting task id <{}>", taskName); - return (Task) tasks.remove(taskName); + logger.info("deleting task id <{}>", taskId); + return (Task) taskMetas.remove(taskId).task(); } public boolean stopTask(String taskId) { - Task taskView = tasks.get(taskId); + Task taskView = getTask(taskId); if (taskView != null) { switch (taskView.getState()) { case QUEUED: @@ -185,7 +208,7 @@ int numberOfExecutedTasks() { public void clear() { executedTasks.set(0); taskQueue.clear(); - tasks.clear(); + taskMetas.clear(); } @Override diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerRedis.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerRedis.java index c61e75d12..66b6a4ada 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerRedis.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskManagerRedis.java @@ -5,6 +5,8 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.buffer.ByteBufInputStream; import io.netty.buffer.ByteBufOutputStream; +import java.util.Optional; +import java.util.stream.Collectors; import org.icij.datashare.asynctasks.bus.amqp.AmqpQueue; import org.icij.datashare.asynctasks.bus.amqp.CancelEvent; import org.icij.datashare.asynctasks.bus.amqp.ShutdownEvent; @@ -30,11 +32,9 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.Charset; -import java.util.LinkedList; import java.util.List; import java.util.concurrent.BlockingQueue; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.regex.Pattern; import java.util.stream.StreamSupport; @@ -44,7 +44,7 @@ public class TaskManagerRedis implements TaskManager { private final Runnable eventCallback; // for test public static final String EVENT_CHANNEL_NAME = "EVENT"; - private final RedissonMap> tasks; + private final RedissonMap> taskMetas; private final RTopic eventTopic; private final RedissonClient redissonClient; private final RoutingStrategy routingStrategy; @@ -57,39 +57,49 @@ public TaskManagerRedis(RedissonClient redissonClient, String taskMapName, Routi this.redissonClient = redissonClient; this.routingStrategy = routingStrategy; CommandSyncService commandSyncService = getCommandSyncService(); - this.tasks = new RedissonMap<>(new TaskViewCodec(), commandSyncService, taskMapName, redissonClient, null, null); + this.taskMetas = new RedissonMap<>(new RedisCodec<>(TaskMetadata.class), commandSyncService, taskMapName, redissonClient, null, null); this.eventTopic = redissonClient.getTopic(EVENT_CHANNEL_NAME); this.eventCallback = eventCallback; eventTopic.addListener(TaskEvent.class, (channelString, message) -> handleEvent(message)); } - @Override - public Task getTask(String id) { - return (Task) tasks.get(id); + public Task getTask(final String taskId) { + return (Task) Optional.ofNullable(taskMetas.get(taskId)).map(TaskMetadata::task).orElse(null); } @Override public List> getTasks() { - return new LinkedList<>(tasks.values()); + return taskMetas.values().stream().map(TaskMetadata::task).collect(Collectors.toList()); + } + + @Override + public List> getTasks(User user, Pattern pattern) throws IOException { + return this.getTasks(taskMetas.values().stream().map(TaskMetadata::task), user, pattern); + } + + @Override + public User getTaskUser(String taskId) { + return taskMetas.get(taskId).user(); } @Override - public List> getTasks(User user, Pattern pattern) { - return TaskManager.getTasks(tasks.values().stream(), user, pattern); + public Group getTaskGroup(String taskId) { + return taskMetas.get(taskId).group(); } @Override public List> clearDoneTasks() { - return tasks.values().stream().filter(Task::isFinished).map(t -> tasks.remove(t.id)).collect(toList()); + return taskMetas.values().stream().map(TaskMetadata::task).filter(Task::isFinished) + .map(t -> taskMetas.remove(t.id).task()).collect(toList()); } @Override - public Task clearTask(String taskId) { - if (tasks.get(taskId).getState() == Task.State.RUNNING) { + public Task clearTask(String taskId) { + if (getTask(taskId).getState() == Task.State.RUNNING) { throw new IllegalStateException(String.format("task id <%s> is already in RUNNING state", taskId)); } logger.info("deleting task id <{}>", taskId); - return tasks.remove(taskId); + return (Task) taskMetas.remove(taskId).task(); } @Override @@ -116,13 +126,13 @@ public boolean shutdownAndAwaitTermination(int timeout, TimeUnit timeUnit) { BlockingQueue> taskQueue(Task task) { switch (routingStrategy) { case GROUP -> { - return new RedissonBlockingQueue<>(new TaskViewCodec(), getCommandSyncService(), String.format("%s.%s", AmqpQueue.TASK.name(), task.getGroup().id()), redissonClient); + return new RedissonBlockingQueue<>(new RedisCodec<>(Task.class), getCommandSyncService(), String.format("%s.%s", AmqpQueue.TASK.name(), this.taskMetas.get(task.id).group().id()), redissonClient); } case NAME -> { - return new RedissonBlockingQueue<>(new TaskViewCodec(), getCommandSyncService(), String.format("%s.%s", AmqpQueue.TASK.name(), task.name), redissonClient); + return new RedissonBlockingQueue<>(new RedisCodec<>(Task.class), getCommandSyncService(), String.format("%s.%s", AmqpQueue.TASK.name(), task.name), redissonClient); } default -> { - return new RedissonBlockingQueue<>(new TaskViewCodec(), getCommandSyncService(), AmqpQueue.TASK.name(), redissonClient); + return new RedissonBlockingQueue<>(new RedisCodec<>(Task.class), getCommandSyncService(), AmqpQueue.TASK.name(), redissonClient); } } } @@ -140,7 +150,7 @@ public void close() throws IOException { @Override public void clear() { - tasks.clear(); + taskMetas.clear(); clearTaskQueues(); } @@ -153,9 +163,23 @@ private void clearTaskQueues() { .forEach(k -> redissonClient.getQueue(k).delete()); } - public boolean save(Task task) { - Task oldVal = tasks.put(task.id, task); - return oldVal == null; + @Override + public void saveMetadata(TaskMetadata taskMetadata) throws TaskAlreadyExists { + String taskId = taskMetadata.taskId(); + if (taskMetas.containsKey(taskId)) { + throw new TaskAlreadyExists(taskId); + } + this.taskMetas.put(taskId, taskMetadata); + } + + @Override + public void persistUpdate(Task task) throws UnknownTask { + TaskMetadata updated = (TaskMetadata) taskMetas.get(task.id); + if (updated == null) { + throw new UnknownTask(task.id); + } + updated = updated.withTask(task); + this.taskMetas.put(task.id, updated); } @Override @@ -163,14 +187,16 @@ public void enqueue(Task task) { taskQueue(task).add(task); } - public static class TaskViewCodec extends BaseCodec { + public static class RedisCodec extends BaseCodec { + private final Class clazz; private final Encoder keyEncoder; private final Decoder keyDecoder; protected final ObjectMapper mapObjectMapper; - public TaskViewCodec() { + public RedisCodec(Class clazz) { + // Ugly but this doesn't work with type ref directly + this.clazz = clazz; this.mapObjectMapper = JsonObjectMapper.MAPPER; - this.keyEncoder = in -> { ByteBuf out = ByteBufAllocator.DEFAULT.buffer(); out.writeCharSequence(in.toString(), Charset.defaultCharset()); @@ -203,8 +229,8 @@ public ByteBuf encode(Object in) throws IOException { private final Decoder decoder = new Decoder<>() { @Override - public Object decode(ByteBuf buf, State state) throws IOException { - return mapObjectMapper.readValue((InputStream) new ByteBufInputStream(buf), Task.class); + public T decode(ByteBuf buf, State state) throws IOException { + return mapObjectMapper.readValue((InputStream) new ByteBufInputStream(buf), clazz); } }; diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskSupplierRedis.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskSupplierRedis.java index 703fc869a..adeca4b66 100644 --- a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskSupplierRedis.java +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/TaskSupplierRedis.java @@ -1,5 +1,6 @@ package org.icij.datashare.asynctasks; +import java.util.Optional; import org.apache.commons.lang3.NotImplementedException; import org.icij.datashare.asynctasks.bus.amqp.AmqpQueue; import org.icij.datashare.asynctasks.bus.amqp.CancelledEvent; @@ -83,8 +84,8 @@ public void waitForConsumer() {} private BlockingQueue> taskQueue() { return this.taskQueueKey == null ? - new RedissonBlockingQueue<>(new TaskManagerRedis.TaskViewCodec(), getCommandSyncService(), AmqpQueue.TASK.name(), redissonClient): - new RedissonBlockingQueue<>(new TaskManagerRedis.TaskViewCodec(), getCommandSyncService(), String.format("%s.%s", AmqpQueue.TASK.name(), taskQueueKey), redissonClient); + new RedissonBlockingQueue<>(new TaskManagerRedis.RedisCodec<>(Task.class), getCommandSyncService(), AmqpQueue.TASK.name(), redissonClient): + new RedissonBlockingQueue<>(new TaskManagerRedis.RedisCodec<>(Task.class), getCommandSyncService(), String.format("%s.%s", AmqpQueue.TASK.name(), taskQueueKey), redissonClient); } private CommandSyncService getCommandSyncService() { diff --git a/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/UnknownTask.java b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/UnknownTask.java new file mode 100644 index 000000000..6a75fad0f --- /dev/null +++ b/datashare-tasks/src/main/java/org/icij/datashare/asynctasks/UnknownTask.java @@ -0,0 +1,9 @@ +package org.icij.datashare.asynctasks; + +public class UnknownTask extends Exception { + final String taskId; + + public UnknownTask(String taskId) { + this.taskId = taskId; + } +} diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerAmqpTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerAmqpTest.java index cdf3111e8..371f3189a 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerAmqpTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerAmqpTest.java @@ -7,26 +7,25 @@ import org.icij.datashare.asynctasks.bus.amqp.TaskError; import org.icij.datashare.tasks.RoutingStrategy; import org.icij.datashare.user.User; -import org.icij.extract.redis.RedissonClientFactory; -import org.icij.task.Options; -import org.junit.*; -import org.redisson.Redisson; -import org.redisson.RedissonMap; -import org.redisson.api.RedissonClient; -import org.redisson.command.CommandSyncService; -import org.redisson.liveobject.core.RedissonObjectBuilder; import java.io.IOException; import java.io.Serializable; import java.util.HashMap; import java.util.Map; import java.util.concurrent.*; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Test; import static org.fest.assertions.Assertions.assertThat; public class TaskManagerAmqpTest { private static AmqpInterlocutor AMQP; - @ClassRule static public AmqpServerRule qpid = new AmqpServerRule(5672); + @ClassRule + static public AmqpServerRule qpid = new AmqpServerRule(5672); BlockingQueue> taskQueue = new LinkedBlockingQueue<>(); TaskManagerAmqp taskManager; TaskSupplierAmqp taskSupplier; @@ -53,7 +52,7 @@ public void test_new_task_with_group_routing() throws Exception { assertThat(groupTaskManager.getTask(expectedTaskViewId)).isNotNull(); Task actualTaskView = taskQueue.poll(1, TimeUnit.SECONDS); assertThat(actualTaskView).isNotNull(); - assertThat(actualTaskView.getGroup()).isEqualTo(new Group(key)); + assertThat(groupTaskManager.getTaskGroup(actualTaskView.id)).isEqualTo(new Group(key)); } } @@ -168,6 +167,30 @@ public void test_task_canceled() throws Exception { assertThat(taskManager.getTask(task.id).getState()).isEqualTo(Task.State.CANCELLED); } + @Test + public void test_save_task() throws TaskAlreadyExists, IOException { + Task task = new Task<>("name", new HashMap<>()); + + taskManager.save(task, User.local(), null); + + assertThat(taskManager.getTasks()).hasSize(1); + assertThat(taskManager.getTask(task.id)).isNotNull(); + } + + @Test + public void test_update_task() throws TaskAlreadyExists, IOException { + // Given + Task task = new Task<>("HelloWorld", Map.of("greeted", "world")); + TaskManager.TaskMetadata meta = new TaskManager.TaskMetadata<>(task, User.local(), null); + Task update = new Task<>(task.id, task.name, task.getState(), 0.5, null, task.args); + // When + taskManager.saveMetadata(meta); + taskManager.update(update); + Task updated = taskManager.getTask(task.id); + // Then + assertThat(updated).isEqualTo(update); + } + @BeforeClass public static void beforeClass() throws Exception { AMQP = new AmqpInterlocutor(new PropertiesProvider(new HashMap<>() {{ @@ -180,16 +203,7 @@ public static void beforeClass() throws Exception { @Before public void setUp() throws IOException { nextMessage = new CountDownLatch(1); - final RedissonClient redissonClient = new RedissonClientFactory().withOptions( - Options.from(new PropertiesProvider(Map.of("redisAddress", "redis://redis:6379")).getProperties())).create(); - Map> tasks = new RedissonMap<>(new TaskManagerRedis.TaskViewCodec(), - new CommandSyncService(((Redisson) redissonClient).getConnectionManager(), - new RedissonObjectBuilder(redissonClient)), - "tasks:queue:test", - redissonClient, - null, - null - ); + Map> tasks = new ConcurrentHashMap<>(); taskManager = new TaskManagerAmqp(AMQP, tasks, RoutingStrategy.UNIQUE, () -> nextMessage.countDown()); taskSupplier = new TaskSupplierAmqp(AMQP); taskSupplier.consumeTasks(t -> taskQueue.add(t)); diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerMemoryTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerMemoryTest.java index 2201c6f93..92cab7e01 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerMemoryTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerMemoryTest.java @@ -1,5 +1,7 @@ package org.icij.datashare.asynctasks; +import java.io.IOException; +import java.util.HashMap; import org.icij.datashare.PropertiesProvider; import org.icij.datashare.test.LogbackCapturingRule; import org.icij.datashare.user.User; @@ -37,9 +39,9 @@ public void setUp() throws Exception { @Test public void test_run_task() throws Exception { - Task task = new Task<>(TestFactory.HelloWorld.class.getName(), User.local(), Map.of("greeted", "world")); + Task task = new Task<>(TestFactory.HelloWorld.class.getName(), Map.of("greeted", "world")); - String tid = taskManager.startTask(task); + String tid = taskManager.startTask(task, User.local()); taskManager.shutdownAndAwaitTermination(100, TimeUnit.MILLISECONDS); assertThat(taskManager.getTask(tid).getState()).isEqualTo(Task.State.DONE); @@ -49,7 +51,7 @@ public void test_run_task() throws Exception { @Test public void test_stop_current_task() throws Exception { - Task task = new Task<>(TestFactory.SleepForever.class.getName(), User.local(), Map.of("intParameter", 2000)); + Task task = new Task<>(TestFactory.SleepForever.class.getName(), Map.of("intParameter", 2000)); String taskId = taskManager.startTask(task); taskInspector.awaitToBeStarted(taskId, 10000); @@ -62,8 +64,8 @@ public void test_stop_current_task() throws Exception { @Test public void test_stop_queued_task() throws Exception { - Task t1 = new Task<>(TestFactory.SleepForever.class.getName(), User.local(), Map.of()); - Task t2 = new Task<>(TestFactory.HelloWorld.class.getName(), User.local(), Map.of("greeted", "stucked task")); + Task t1 = new Task<>(TestFactory.SleepForever.class.getName(), Map.of()); + Task t2 = new Task<>(TestFactory.HelloWorld.class.getName(), Map.of("greeted", "stucked task")); taskManager.startTask(t1); taskManager.startTask(t2); @@ -80,7 +82,7 @@ public void test_stop_queued_task() throws Exception { @Test public void test_clear_the_only_task() throws Exception { - Task task = new Task<>("sleep", User.local(), Map.of("intParameter", 12)); + Task task = new Task<>("sleep", Map.of("intParameter", 12)); taskManager.startTask(task); taskManager.shutdownAndAwaitTermination(1, TimeUnit.SECONDS); @@ -93,9 +95,9 @@ public void test_clear_the_only_task() throws Exception { @Test(expected = IllegalStateException.class) public void test_clear_running_task_should_throw_exception() throws Exception { - Task task = new Task<>("sleep", User.local(), Map.of("intParameter", 12)); + Task task = new Task<>("sleep", Map.of("intParameter", 12)); - taskManager.startTask(task); + taskManager.startTask(task, User.local()); taskManager.shutdownAndAwaitTermination(1, TimeUnit.SECONDS); taskManager.progress(task.id, 0.5); assertThat(taskManager.getTask(task.id).getState()).isEqualTo(Task.State.RUNNING); @@ -119,6 +121,29 @@ public void test_result_on_unknown_task() throws InterruptedException { "unknown task id for result=0.5 call"); } + @Test + public void test_save_task() throws TaskAlreadyExists, IOException { + Task task = new Task<>("name", new HashMap<>()); + + taskManager.save(task, User.local(), null); + + assertThat(taskManager.getTasks()).hasSize(1); + assertThat(taskManager.getTask(task.id)).isNotNull(); + } + + @Test + public void test_update_task() throws TaskAlreadyExists, IOException { + // Given + Task task = new Task<>("HelloWorld", Map.of("greeted", "world")); + TaskManager.TaskMetadata meta = new TaskManager.TaskMetadata<>(task, User.local(), null); + Task update = new Task<>(task.id, task.name, task.getState(), 0.5, null, task.args); + // When + taskManager.saveMetadata(meta); + taskManager.update(update); + Task updated = taskManager.getTask(task.id); + // Then + assertThat(updated).isEqualTo(update); + } @After public void tearDown() throws Exception { diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisCodecTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisCodecTest.java index 6722e9abb..6ecd986e3 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisCodecTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisCodecTest.java @@ -3,7 +3,6 @@ import io.netty.buffer.Unpooled; import org.fest.assertions.Assertions; import org.icij.datashare.asynctasks.bus.amqp.UriResult; -import org.icij.datashare.user.User; import org.junit.Test; import org.redisson.client.handler.State; @@ -15,14 +14,13 @@ import static org.fest.assertions.Assertions.assertThat; import static org.fest.assertions.MapAssert.entry; -import static org.icij.datashare.json.JsonObjectMapper.MAPPER; public class TaskManagerRedisCodecTest { - TaskManagerRedis.TaskViewCodec codec = new TaskManagerRedis.TaskViewCodec(); + TaskManagerRedis.RedisCodec codec = new TaskManagerRedis.RedisCodec<>(Task.class); @Test public void test_json_serialize_deserialize_with_inline_properties_map() throws Exception { - Task taskView = new Task<>("name", User.local(), Map.of("key", "value")); + Task taskView = new Task<>("name", Map.of("key", "value")); String json = codec.getValueEncoder().encode(taskView).toString(Charset.defaultCharset()); assertThat(json).contains("\"key\":\"value\""); @@ -30,14 +28,13 @@ public void test_json_serialize_deserialize_with_inline_properties_map() throws Task actualTask = (Task) codec.getValueDecoder().decode(Unpooled.wrappedBuffer(json.getBytes()), new State()); Assertions.assertThat(actualTask.name).isEqualTo("name"); - Assertions.assertThat(actualTask.args).hasSize(2); + Assertions.assertThat(actualTask.args).hasSize(1); Assertions.assertThat(actualTask.args).includes(entry("key", "value")); - Assertions.assertThat(actualTask.getUser()).isEqualTo(User.local()); } @Test public void test_uri_result() throws Exception { - Task task = new Task<>("name", User.local(), new HashMap<>()); + Task task = new Task<>("name", new HashMap<>()); task.setResult(new UriResult(new URI("file://uri"), 123L)); assertThat(encodeDecode(task).getResult()).isInstanceOf(UriResult.class); @@ -45,7 +42,7 @@ public void test_uri_result() throws Exception { @Test public void test_simple_results() throws Exception { - Task task = new Task<>("name", User.local(), new HashMap<>()); + Task task = new Task<>("name", new HashMap<>()); task.setResult(123L); assertThat(encodeDecode(task).getResult()).isInstanceOf(Long.class); diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisTest.java index 5b73a2d83..ff5fe6ca4 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagerRedisTest.java @@ -38,21 +38,35 @@ public class TaskManagerRedisTest { taskSupplier = new TaskSupplierRedis(redissonClient); @Test - public void test_save_task() { - Task task = new Task<>("name", User.local(), new HashMap<>()); + public void test_save_task() throws TaskAlreadyExists, IOException { + Task task = new Task<>("name", new HashMap<>()); - taskManager.save(task); + taskManager.save(task, User.local(), null); assertThat(taskManager.getTasks()).hasSize(1); assertThat(taskManager.getTask(task.id)).isNotNull(); } @Test - public void test_start_task() throws IOException { - assertThat(taskManager.startTask("HelloWorld", User.local(), - new HashMap<>() {{ put("greeted", "world"); }})).isNotNull(); + public void test_update_task() throws TaskAlreadyExists, IOException { + // Given + Task task = new Task<>("HelloWorld", Map.of("greeted", "world")); + TaskManager.TaskMetadata meta = new TaskManager.TaskMetadata<>(task, User.local(), null); + Task update = new Task<>(task.id, task.name, task.getState(), 0.5, null, task.args); + // When + taskManager.saveMetadata(meta); + taskManager.update(update); + Task updated = taskManager.getTask(task.id); + // Then + assertThat(updated).isEqualTo(update); + } + + @Test + public void test_start_task() throws IOException, TaskAlreadyExists { + String taskId = taskManager.startTask("HelloWorld", User.local(), Map.of("greeted", "world")); + assertThat(taskId).isNotNull(); assertThat(taskManager.getTasks()).hasSize(1); - assertThat(taskManager.getTasks().get(0).getUser()).isEqualTo(User.local()); + assertThat(taskManager.getTaskUser(taskId)).isEqualTo(User.local()); } @Test @@ -66,7 +80,7 @@ public void test_start_task_with_group_routing() throws Exception { assertThat(groupTaskManager.startTask("HelloWorld", User.local(), new Group("Group"),Map.of("greeted", "world"))).isNotNull(); Task task = taskSupplier.get(2, TimeUnit.SECONDS); - assertThat(task.getGroup()).isEqualTo(new Group("Group")); + assertThat(groupTaskManager.getTaskGroup(task.id)).isEqualTo(new Group("Group")); assertThat(((RedissonBlockingQueue) groupTaskManager.taskQueue(task)).getName()).isEqualTo("TASK.Group"); } } diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagersIntTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagersIntTest.java index 5e173e726..35a452cd6 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagersIntTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskManagersIntTest.java @@ -58,7 +58,7 @@ public static Collection taskServices() throws Exception { "messageBusAddress", "amqp://admin:admin@rabbitmq")); final RedissonClient redissonClient = new RedissonClientFactory().withOptions( Options.from(propertiesProvider.getProperties())).create(); - Map> amqpTasks = new RedissonMap<>(new TaskManagerRedis.TaskViewCodec(), + Map> amqpTasks = new RedissonMap<>(new TaskManagerRedis.RedisCodec<>(TaskManager.TaskMetadata.class), new CommandSyncService(((Redisson) redissonClient).getConnectionManager(), new RedissonObjectBuilder(redissonClient)), "tasks:queue:test", diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskTest.java index c5226c284..fe893fca3 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskTest.java @@ -2,7 +2,6 @@ import org.fest.assertions.Assertions; import org.icij.datashare.json.JsonObjectMapper; -import org.icij.datashare.user.User; import org.junit.Test; import java.util.HashMap; @@ -12,14 +11,13 @@ import java.util.concurrent.TimeUnit; import static org.fest.assertions.Assertions.assertThat; -import static org.fest.assertions.MapAssert.entry; public class TaskTest { private final ExecutorService executor = Executors.newSingleThreadExecutor(); @Test public void test_get_result_sync_when_task_is_running() throws InterruptedException { - Task taskView = new Task<>("name", User.local(), new HashMap<>()); + Task taskView = new Task<>("name", new HashMap<>()); executor.execute(() -> { try { taskView.getResult(1, TimeUnit.SECONDS); @@ -44,15 +42,9 @@ public void test_get_result_sync_when_task_is_not_local() { assertThat(taskView.getProgress()).isEqualTo(1); } - @Test - public void test_user_group_parameters() { - Task taskView = new Task<>("foo", User.local(), new Group("bar"), Map.of("baz", "qux")); - assertThat(taskView.args).includes(entry("group", new Group("bar")), entry("user", User.local()), entry("baz", "qux")); - } - @Test public void test_progress() { - Task taskView = new Task<>("name", User.local(), new HashMap<>()); + Task taskView = new Task<>("name", new HashMap<>()); assertThat(taskView.getProgress()).isEqualTo(0); assertThat(taskView.getState()).isEqualTo(Task.State.CREATED); @@ -90,10 +82,9 @@ public void test_json_deserialize() throws Exception { @Test public void test_serialize_deserialize() throws Exception { - Task taskView = new Task<>("name", User.local(), Map.of("key", "value")); + Task taskView = new Task<>("name", Map.of("key", "value")); String json = JsonObjectMapper.MAPPER.writeValueAsString(taskView); assertThat(json).contains("\"@type\":\"Task\""); - assertThat(json).contains("\"user\":{\"@type\":\"org.icij.datashare.user.User\""); Task taskCreation = JsonObjectMapper.MAPPER.readValue(json, Task.class); assertThat(taskCreation).isEqualTo(taskView); diff --git a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskWorkerLoopTest.java b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskWorkerLoopTest.java index 4cda5e2a9..bc189e044 100644 --- a/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskWorkerLoopTest.java +++ b/datashare-tasks/src/test/java/org/icij/datashare/asynctasks/TaskWorkerLoopTest.java @@ -1,6 +1,5 @@ package org.icij.datashare.asynctasks; -import org.icij.datashare.user.User; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentMatchers; @@ -28,7 +27,7 @@ public class TaskWorkerLoopTest { @Test(timeout = 2000) public void test_loop() throws Exception { TaskWorkerLoop app = new TaskWorkerLoop(registry, supplier); - Task taskView = new Task<>(TestFactory.HelloWorld.class.getName(), User.local(), Map.of("greeted", "world")); + Task taskView = new Task<>(TestFactory.HelloWorld.class.getName(), Map.of("greeted", "world")); Mockito.when(supplier.get(ArgumentMatchers.anyInt(), ArgumentMatchers.any())).thenReturn(taskView); CountDownLatch taskStarted = whenTaskHasStarted(taskView.id); @@ -42,9 +41,9 @@ public void test_loop() throws Exception { } @Test(timeout = 2000) - public void test_unknown_task() throws Exception { + public void test_unknown_task() { TaskWorkerLoop app = new TaskWorkerLoop(registry, supplier); - Task taskView = new Task<>("unknown_task", User.local(), Map.of()); + Task taskView = new Task<>("unknown_task", Map.of()); try { app.handle(taskView); @@ -57,7 +56,7 @@ public void test_unknown_task() throws Exception { @Test(timeout = 2000) public void test_cancel_task() throws Exception { TaskWorkerLoop app = new TaskWorkerLoop(registry, supplier); - Task taskView = new Task<>(TestFactory.SleepForever.class.getName(), User.local(), Map.of()); + Task taskView = new Task<>(TestFactory.SleepForever.class.getName(), Map.of()); Mockito.when(supplier.get(ArgumentMatchers.anyInt(), ArgumentMatchers.any())).thenReturn(taskView); boolean requeue = false; CountDownLatch taskStarted = whenTaskHasStarted(taskView.id); @@ -75,7 +74,7 @@ public void test_cancel_task() throws Exception { @Test(timeout = 2000) public void test_cancel_task_and_requeue() throws Exception { TaskWorkerLoop app = new TaskWorkerLoop(registry, supplier); - Task taskView = new Task<>(TestFactory.SleepForever.class.getName(), User.local(), Map.of()); + Task taskView = new Task<>(TestFactory.SleepForever.class.getName(), Map.of()); Mockito.when(supplier.get(ArgumentMatchers.anyInt(), ArgumentMatchers.any())).thenReturn(taskView); boolean requeue = true; CountDownLatch taskStarted = whenTaskHasStarted(taskView.id); @@ -93,7 +92,7 @@ public void test_cancel_task_and_requeue() throws Exception { @Test(timeout = 2000) public void test_task_interrupted() throws Exception { TaskWorkerLoop app = new TaskWorkerLoop(registry, supplier); - Task taskView = new Task<>(TestFactory.SleepForever.class.getName(), User.local(), Map.of()); + Task taskView = new Task<>(TestFactory.SleepForever.class.getName(), Map.of()); Mockito.when(supplier.get(ArgumentMatchers.anyInt(), ArgumentMatchers.any())).thenReturn(taskView); CountDownLatch taskStarted = whenTaskHasStarted(taskView.id);