Skip to content

Commit

Permalink
Added lease extension to java sdk and fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ystxn committed Dec 24, 2024
1 parent e0780c3 commit 5f6f138
Show file tree
Hide file tree
Showing 13 changed files with 206 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ dependencies {
testImplementation "org.junit.jupiter:junit-jupiter-api:${versions.junit}"
testRuntimeOnly "org.junit.jupiter:junit-jupiter-engine:${versions.junit}"

testImplementation "org.powermock:powermock-module-junit4:2.0.9"
testImplementation "org.powermock:powermock-api-mockito2:2.0.9"
testImplementation 'org.mockito:mockito-inline:5.2.0'

testImplementation 'org.spockframework:spock-core:2.3-groovy-3.0'
testImplementation 'org.codehaus.groovy:groovy:3.0.15'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
Expand Down Expand Up @@ -68,6 +72,10 @@ class TaskRunner {
private final EventDispatcher<TaskRunnerEvent> eventDispatcher;
private final LinkedBlockingQueue<Task> tasksTobeExecuted;
private final boolean enableUpdateV2;
private static final int LEASE_EXTEND_RETRY_COUNT = 3;
private static final double LEASE_EXTEND_DURATION_FACTOR = 0.8;
private final ScheduledExecutorService leaseExtendExecutorService;
private Map<String, ScheduledFuture<?>> leaseExtendMap = new HashMap<>();

TaskRunner(Worker worker,
TaskClient taskClient,
Expand Down Expand Up @@ -122,6 +130,15 @@ class TaskRunner {
pollingIntervalInMillis,
domain);
LOGGER.info("Polling errors for taskType {} will be printed at every {} occurrence.", taskType, errorAt);

LOGGER.info("Initialized the task lease extend executor");
leaseExtendExecutorService = Executors.newSingleThreadScheduledExecutor(
new BasicThreadFactory.Builder()
.namingPattern("workflow-lease-extend-%d")
.daemon(true)
.uncaughtExceptionHandler(uncaughtExceptionHandler)
.build()
);
}

public void pollAndExecute() {
Expand All @@ -145,7 +162,25 @@ public void pollAndExecute() {
LOGGER.trace("Poller for task {} waited for {} ms before getting {} tasks to execute", taskType, stopwatch.elapsed(TimeUnit.MILLISECONDS), tasks.size());
stopwatch = null;
}
tasks.forEach(task -> this.executorService.submit(() -> this.processTask(task)));
tasks.forEach(task -> {
Future<Task> taskFuture = this.executorService.submit(() -> this.processTask(task));

if (task.getResponseTimeoutSeconds() > 0 && worker.leaseExtendEnabled()) {
ScheduledFuture<?> scheduledFuture = leaseExtendMap.get(task.getTaskId());
if (scheduledFuture != null) {
scheduledFuture.cancel(false);
}

long delay = Math.round(task.getResponseTimeoutSeconds() * LEASE_EXTEND_DURATION_FACTOR);
ScheduledFuture<?> leaseExtendFuture = leaseExtendExecutorService.scheduleWithFixedDelay(
extendLease(task, taskFuture),
delay,
delay,
TimeUnit.SECONDS
);
leaseExtendMap.put(task.getTaskId(), leaseExtendFuture);
}
});
} catch (Throwable t) {
LOGGER.error(t.getMessage(), t);
}
Expand Down Expand Up @@ -251,7 +286,7 @@ private List<Task> pollTask(int count) {
LOGGER.error("Uncaught exception. Thread {} will exit now", thread, error);
};

private void processTask(Task task) {
private Task processTask(Task task) {
eventDispatcher.publish(new TaskExecutionStarted(taskType, task.getTaskId(), worker.getIdentity()));
LOGGER.trace("Executing task: {} of type: {} in worker: {} at {}", task.getTaskId(), taskType, worker.getClass().getSimpleName(), worker.getIdentity());
LOGGER.trace("task {} is getting executed after {} ms of getting polled", task.getTaskId(), (System.currentTimeMillis() - task.getStartTime()));
Expand All @@ -271,6 +306,7 @@ private void processTask(Task task) {
} finally {
permits.release();
}
return task;
}

private void executeTask(Worker worker, Task task) {
Expand Down Expand Up @@ -400,4 +436,30 @@ private void handleException(Throwable t, TaskResult result, Worker worker, Task
result.log(stringWriter.toString());
updateTaskResult(updateRetryCount, task, result, worker);
}

private Runnable extendLease(Task task, Future<Task> taskCompletableFuture) {
return () -> {
if (taskCompletableFuture.isDone()) {
LOGGER.warn(
"Task processing for {} completed, but its lease extend was not cancelled",
task.getTaskId());
return;
}
LOGGER.info("Attempting to extend lease for {}", task.getTaskId());
try {
TaskResult result = new TaskResult(task);
result.setExtendLease(true);
retryOperation(
(TaskResult taskResult) -> {
taskClient.updateTask(taskResult);
return null;
},
LEASE_EXTEND_RETRY_COUNT,
result,
"extend lease");
} catch (Exception e) {
LOGGER.error("Failed to extend lease for {}", task.getTaskId(), e);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ public TaskRunnerConfigurer.Builder withTaskToDomain(Map<String, String> taskToD
public TaskRunnerConfigurer.Builder withTaskThreadCount(
Map<String, Integer> taskToThreadCount) {
this.taskToThreadCount = taskToThreadCount;
if (taskToThreadCount.values().stream().anyMatch(v -> v < 1)) {
throw new IllegalArgumentException("No. of threads cannot be less than 1");
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public interface Worker {
String PROP_ALL_WORKERS = "all";
String PROP_LOG_INTERVAL = "log_interval";
String PROP_POLL_INTERVAL = "poll_interval";
String PROP_LEASE_EXTEND_ENABLED = "leaseExtendEnabled";
String PROP_PAUSED = "paused";

/**
Expand Down Expand Up @@ -91,6 +92,10 @@ default int getPollingInterval() {
return PropertyFactory.getInteger(getTaskDefName(), PROP_POLL_INTERVAL, 1000);
}

default boolean leaseExtendEnabled() {
return PropertyFactory.getBoolean(getTaskDefName(), PROP_LEASE_EXTEND_ENABLED, false);
}

static Worker create(String taskType, Function<Task, TaskResult> executor) {
return new Worker() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,58 +15,70 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.MethodOrderer;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestMethodOrder;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

import com.netflix.conductor.client.exception.ConductorClientException;
import com.netflix.conductor.client.http.TaskClient;
import com.netflix.conductor.client.worker.Worker;
import com.netflix.conductor.common.metadata.tasks.Task;
import com.netflix.conductor.common.metadata.tasks.TaskResult;

import static com.netflix.conductor.common.metadata.tasks.TaskResult.Status.COMPLETED;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

@TestMethodOrder(MethodOrderer.MethodName.class)
public class TaskRunnerConfigurerTest {

private static final String TEST_TASK_DEF_NAME = "test";

private TaskClient client;

@Before
@BeforeEach
public void setup() {
client = Mockito.mock(TaskClient.class);
}

@Test(expected = NullPointerException.class)
@Test
public void testNoWorkersException() {
new TaskRunnerConfigurer.Builder(null, null).build();
assertThrows(NullPointerException.class, () -> new TaskRunnerConfigurer.Builder(null, null).build());
}

@Test(expected = ConductorClientException.class)
@Test
public void testInvalidThreadConfig() {
Worker worker1 = Worker.create("task1", TaskResult::new);
Worker worker2 = Worker.create("task2", TaskResult::new);
Map<String, Integer> taskThreadCount = new HashMap<>();
taskThreadCount.put(worker1.getTaskDefName(), 2);
taskThreadCount.put(worker1.getTaskDefName(), 0);
taskThreadCount.put(worker2.getTaskDefName(), 3);
new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withThreadCount(10)
.withTaskThreadCount(taskThreadCount)
.build();

assertThrows(IllegalArgumentException.class, () -> new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withThreadCount(-1)
.withTaskThreadCount(taskThreadCount)
.build());

assertThrows(IllegalArgumentException.class, () -> new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker1, worker2))
.withTaskThreadCount(taskThreadCount)
.build());
}

@Test
Expand All @@ -81,12 +93,12 @@ public void testMissingTaskThreadConfig() {
.build();

assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(2, configurer.getTaskThreadCount().size());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(2, configurer.getTaskThreadCount().get("task1").intValue());
assertEquals(1, configurer.getTaskThreadCount().get("task2").intValue());
}

@Test
@SuppressWarnings("deprecation")
public void testPerTaskThreadPool() {
Worker worker1 = Worker.create("task1", TaskResult::new);
Worker worker2 = Worker.create("task2", TaskResult::new);
Expand All @@ -104,19 +116,18 @@ public void testPerTaskThreadPool() {
}

@Test
@SuppressWarnings("deprecation")
public void testSharedThreadPool() {
Worker worker = Worker.create(TEST_TASK_DEF_NAME, TaskResult::new);
TaskRunnerConfigurer configurer =
new TaskRunnerConfigurer.Builder(client, Arrays.asList(worker, worker, worker))
.build();
configurer.init();
assertEquals(3, configurer.getThreadCount());
assertEquals(-1, configurer.getThreadCount());
assertEquals(500, configurer.getSleepWhenRetry());
assertEquals(3, configurer.getUpdateRetryCount());
assertEquals(10, configurer.getShutdownGracePeriodSeconds());
assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(3, configurer.getTaskThreadCount().get(TEST_TASK_DEF_NAME).intValue());
assertTrue(configurer.getTaskThreadCount().isEmpty());

configurer =
new TaskRunnerConfigurer.Builder(client, Collections.singletonList(worker))
Expand All @@ -133,9 +144,7 @@ public void testSharedThreadPool() {
assertEquals(10, configurer.getUpdateRetryCount());
assertEquals(15, configurer.getShutdownGracePeriodSeconds());
assertEquals("test-worker-", configurer.getWorkerNamePrefix());
assertFalse(configurer.getTaskThreadCount().isEmpty());
assertEquals(1, configurer.getTaskThreadCount().size());
assertEquals(100, configurer.getTaskThreadCount().get(TEST_TASK_DEF_NAME).intValue());
assertTrue(configurer.getTaskThreadCount().isEmpty());
}

@Test
Expand Down Expand Up @@ -186,9 +195,9 @@ public void testMultipleWorkersExecution() throws Exception {
Object[] args = invocation.getArguments();
String taskName = args[0].toString();
if (taskName.equals(task1Name)) {
return Arrays.asList(task1);
return List.of(task1);
} else if (taskName.equals(task2Name)) {
return Arrays.asList(task2);
return List.of(task2);
} else {
return Collections.emptyList();
}
Expand Down Expand Up @@ -220,6 +229,58 @@ public void testMultipleWorkersExecution() throws Exception {
assertEquals(1, task2Counter.get());
}

@Test
public void testLeaseExtension() throws Exception {
TaskClient taskClient = mock(TaskClient.class);
String taskName = "task1";

Worker worker = mock(Worker.class);
when(worker.getTaskDefName()).thenReturn(taskName);
when(worker.leaseExtendEnabled()).thenReturn(true);

doAnswer(invocation -> {
TaskResult result = new TaskResult(invocation.getArgument(0));
result.setStatus(TaskResult.Status.IN_PROGRESS);
return result;
}).when(worker).execute(any(Task.class));

Task task = new Task();
task.setTaskId("task123");
task.setTaskDefName(taskName);
task.setStatus(Task.Status.IN_PROGRESS);
task.setResponseTimeoutSeconds(2000);

when(taskClient.batchPollTasksInDomain(any(), any(), any(), anyInt(), anyInt()))
.thenAnswer((invocation) -> List.of(task));
when(taskClient.ack(any(), any())).thenReturn(true);

CountDownLatch latch = new CountDownLatch(1);
doAnswer(invocation -> {
latch.countDown();
return null;
}).when(taskClient).updateTask(any(TaskResult.class));

TaskRunnerConfigurer configurer = new TaskRunnerConfigurer.Builder(taskClient, List.of(worker))
.withSleepWhenRetry(100)
.withUpdateRetryCount(3)
.withThreadCount(1)
.build();

configurer.init();
latch.await();

ArgumentCaptor<TaskResult> taskResultCaptor = ArgumentCaptor.forClass(TaskResult.class);
verify(taskClient, atLeastOnce()).updateTask(taskResultCaptor.capture());

TaskResult capturedResult = taskResultCaptor.getValue();
assertNotNull(capturedResult);
assertEquals("task123", capturedResult.getTaskId());
assertEquals(TaskResult.Status.IN_PROGRESS, capturedResult.getStatus());

verify(worker, atLeastOnce()).execute(task);
assertTrue(worker.leaseExtendEnabled(), "Worker lease extension should be enabled");
}

private Task testTask(String taskDefName) {
Task task = new Task();
task.setTaskId(UUID.randomUUID().toString());
Expand Down
Loading

0 comments on commit 5f6f138

Please sign in to comment.