Skip to content

Commit

Permalink
adds unit test for task queue producer
Browse files Browse the repository at this point in the history
  • Loading branch information
eduwercamacaro committed Nov 19, 2024
1 parent 1272b18 commit 96e874b
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io.littlehorse.common.model.getable.objectId.TaskRunIdModel;
import io.littlehorse.common.proto.TaskClaimEventPb;
import io.littlehorse.common.util.LHUtil;
import io.littlehorse.server.streams.taskqueue.PollTaskRequestObserver;
import io.littlehorse.server.streams.topology.core.ExecutionContext;
import io.littlehorse.server.streams.topology.core.ProcessorExecutionContext;
import java.util.Date;
Expand Down Expand Up @@ -40,11 +39,11 @@ public class TaskClaimEvent extends CoreSubCommand<TaskClaimEventPb> {

public TaskClaimEvent() {}

public TaskClaimEvent(ScheduledTaskModel task, PollTaskRequestObserver taskClaimer) {
public TaskClaimEvent(ScheduledTaskModel task, String taskWorkerVersion, String taskWorkerId) {
this.taskRunId = task.getTaskRunId();
this.time = new Date();
this.taskWorkerId = taskClaimer.getClientId();
this.taskWorkerVersion = taskClaimer.getTaskWorkerVersion();
this.taskWorkerId = taskWorkerId;
this.taskWorkerVersion = taskWorkerVersion;
}

public Class<TaskClaimEventPb> getProtoBaseClass() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,24 @@
import java.util.concurrent.Future;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.Producer;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.common.header.Header;
import org.apache.kafka.common.utils.Bytes;

public class LHProducer implements Closeable {

private final KafkaProducer<String, Bytes> prod;
private final Producer<String, Bytes> prod;

public LHProducer(Properties configs) {
prod = new KafkaProducer<>(configs);
}

public LHProducer(Producer<String, Bytes> prod) {
this.prod = prod;
}

public Future<RecordMetadata> send(String key, AbstractCommand<?> t, String topic, Callback cb, Header... headers) {
return sendRecord(new ProducerRecord<>(topic, null, key, new Bytes(t.toBytes()), List.of(headers)), cb);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@ public class TaskClaimEventProducerCallback implements Callback {
private final ScheduledTaskModel scheduledTask;
private final PollTaskRequestObserver client;

public TaskClaimEventProducerCallback(ScheduledTaskModel scheduledTask, PollTaskRequestObserver client) {
public TaskClaimEventProducerCallback(
final ScheduledTaskModel scheduledTask, final PollTaskRequestObserver client) {
this.scheduledTask = scheduledTask;
this.client = client;
}

@Override
public void onCompletion(RecordMetadata metadata, Exception exception) {
public void onCompletion(final RecordMetadata metadata, final Exception exception) {
if (exception == null) {
client.sendResponse(scheduledTask);
} else {
client.onError(exception);
log.error("error", exception);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import java.util.concurrent.locks.ReentrantLock;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.apache.kafka.clients.producer.Callback;
import org.apache.kafka.clients.producer.RecordMetadata;
import org.apache.kafka.streams.processor.TaskId;

// One instance of this class is responsible for coordinating the grpc backend for
Expand Down Expand Up @@ -289,21 +287,4 @@ public int size() {
}

private record QueueItem(TaskId streamsTaskId, ScheduledTaskModel scheduledTask) {}

private final class TaskClaimCallback implements Callback {
private final ScheduledTaskModel scheduledTask;
private final PollTaskRequestObserver luckyClient;

public TaskClaimCallback(ScheduledTaskModel scheduledTask, PollTaskRequestObserver luckyClient) {
this.scheduledTask = scheduledTask;
this.luckyClient = luckyClient;
}

@Override
public void onCompletion(RecordMetadata metadata, Exception exception) {
if (exception == null) {
luckyClient.sendResponse(scheduledTask);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public TaskQueueCommandProducer(LHProducer producer, String commandTopic) {
* infers the request context from the GRPC Context.
*/
public void returnTaskToClient(ScheduledTaskModel scheduledTask, PollTaskRequestObserver client) {
TaskClaimEvent claimEvent = new TaskClaimEvent(scheduledTask, client);
TaskClaimEvent claimEvent =
new TaskClaimEvent(scheduledTask, client.getTaskWorkerVersion(), client.getClientId());
TaskClaimEventProducerCallback callback = new TaskClaimEventProducerCallback(scheduledTask, client);
processCommand(claimEvent, client.getPrincipalId(), client.getTenantId(), callback);
}
Expand Down
28 changes: 28 additions & 0 deletions server/src/test/java/io/littlehorse/common/MockLHProducer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package io.littlehorse.common;

import io.littlehorse.common.util.LHProducer;
import org.apache.kafka.clients.producer.MockProducer;
import org.apache.kafka.common.serialization.Serdes;
import org.apache.kafka.common.utils.Bytes;

public class MockLHProducer extends LHProducer {
private final MockProducer<String, Bytes> mockProducer;

private MockLHProducer(MockProducer<String, Bytes> mockProducer) {
super(mockProducer);
this.mockProducer = mockProducer;
}

public MockProducer<String, Bytes> getKafkaProducer() {
return mockProducer;
}

public static MockLHProducer create() {
return create(true);
}

public static MockLHProducer create(boolean autoComplete) {
return new MockLHProducer(new MockProducer<>(
autoComplete, Serdes.String().serializer(), Serdes.Bytes().serializer()));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package io.littlehorse.server.streams.taskqueue;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

import com.google.protobuf.Empty;
import io.grpc.stub.StreamObserver;
import io.littlehorse.common.AuthorizationContext;
import io.littlehorse.common.LHSerializable;
import io.littlehorse.common.MockLHProducer;
import io.littlehorse.common.model.ScheduledTaskModel;
import io.littlehorse.common.model.corecommand.subcommand.ReportTaskRunModel;
import io.littlehorse.common.model.getable.core.taskrun.TaskRunSourceModel;
import io.littlehorse.common.model.getable.objectId.PrincipalIdModel;
import io.littlehorse.common.model.getable.objectId.TaskDefIdModel;
import io.littlehorse.common.model.getable.objectId.TaskRunIdModel;
import io.littlehorse.common.model.getable.objectId.TenantIdModel;
import io.littlehorse.common.model.getable.objectId.WfRunIdModel;
import io.littlehorse.sdk.common.LHLibUtil;
import io.littlehorse.sdk.common.proto.ReportTaskRun;
import io.littlehorse.sdk.common.proto.TaskRunId;
import io.littlehorse.sdk.common.proto.VariableValue;
import io.littlehorse.sdk.common.proto.WfRunId;
import io.littlehorse.server.streams.topology.core.ProcessorExecutionContext;
import java.util.Date;
import java.util.List;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.junit.jupiter.MockitoExtension;

@ExtendWith(MockitoExtension.class)
class TaskQueueCommandProducerTest {

private final MockLHProducer autoCompleteMockProducer = MockLHProducer.create();
private final MockLHProducer mockProducer = MockLHProducer.create(false);
private final String commandTopic = "test";
private final PollTaskRequestObserver requestObserver = mock();
private final ProcessorExecutionContext processorContext = mock();
private final AuthorizationContext mockAuthContext = mock();
private final StreamObserver<Empty> reportTaskRunObserver = mock();

@BeforeEach
void setup() {
when(requestObserver.getPrincipalId()).thenReturn(new PrincipalIdModel("test-principal"));
when(requestObserver.getTenantId()).thenReturn(new TenantIdModel("test-tenant"));
when(requestObserver.getTaskWorkerVersion()).thenReturn("1.0.0");
when(requestObserver.getClientId()).thenReturn("test-worker");
when(mockAuthContext.principalId()).thenReturn(new PrincipalIdModel("test-principal"));
when(mockAuthContext.tenantId()).thenReturn(new TenantIdModel("test-tenant"));
}

@Test
void shouldReturnTaskToClientAfterProducerRecordSent() {
final TaskQueueCommandProducer taskQueueProducer =
new TaskQueueCommandProducer(autoCompleteMockProducer, commandTopic);
ScheduledTaskModel scheduledTask = buildScheduledTask();
taskQueueProducer.returnTaskToClient(scheduledTask, requestObserver);
verify(requestObserver).sendResponse(scheduledTask);
}

@Test
void shouldCloseResponseObserverWithErrorOnProducerFailures() {
final RuntimeException expectedException = new RuntimeException("oops");
final TaskQueueCommandProducer taskQueueProducer = new TaskQueueCommandProducer(mockProducer, commandTopic);
ScheduledTaskModel scheduledTask = buildScheduledTask();
taskQueueProducer.returnTaskToClient(scheduledTask, requestObserver);
mockProducer.getKafkaProducer().errorNext(expectedException);
verify(requestObserver).onError(expectedException);
}

@Test
void shouldSendReportTaskRunCommands() {
final TaskQueueCommandProducer taskQueueProducer =
new TaskQueueCommandProducer(autoCompleteMockProducer, commandTopic);
ReportTaskRunModel reportTaskRun = buildReportTaskRunModel();
taskQueueProducer.send(reportTaskRun, mockAuthContext, reportTaskRunObserver);
verify(reportTaskRunObserver).onNext(any());
verify(reportTaskRunObserver).onCompleted();
}

private ScheduledTaskModel buildScheduledTask() {
final TaskDefIdModel taskDefId = new TaskDefIdModel("task-1");
final WfRunIdModel wfRunId = new WfRunIdModel("wf-run-1");
final TaskRunIdModel taskRunId = new TaskRunIdModel(wfRunId, "task-run-1");
ScheduledTaskModel scheduledTask = new ScheduledTaskModel();
scheduledTask.setVariables(List.of());
scheduledTask.setAttemptNumber(1);
scheduledTask.setCreatedAt(new Date());
scheduledTask.setSource(new TaskRunSourceModel());
scheduledTask.setTaskDefId(taskDefId);
scheduledTask.setTaskRunId(taskRunId);
return scheduledTask;
}

private ReportTaskRunModel buildReportTaskRunModel() {
ReportTaskRun reportTaskRun = ReportTaskRun.newBuilder()
.setAttemptNumber(1)
.setOutput(VariableValue.newBuilder().setInt(10))
.setTaskRunId(TaskRunId.newBuilder()
.setWfRunId(WfRunId.newBuilder().setId("test"))
.setTaskGuid("test-guid"))
.setTime(LHLibUtil.fromDate(new Date()))
.build();
return LHSerializable.fromProto(reportTaskRun, ReportTaskRunModel.class, processorContext);
}
}

0 comments on commit 96e874b

Please sign in to comment.