From da185eaa1fe70a7a0185f07c838a1d73d1749e1c Mon Sep 17 00:00:00 2001 From: Paulius Peciura Date: Thu, 10 Sep 2020 12:02:40 +0100 Subject: [PATCH] Poll the count of running step executions In a combination of short poll interval and large number of step executions that take a long time to run, memory consumption can go high - each step execution has its own reference to a job execution, which refers step executions for the job. This means that the job executions are different instances, resulting in creating a lot of objects. Instead, we query the database to get the number of step executions that are still running. Once all of them are finished, we then would fetch all steps and assign the same job execution that can be a shared instance. --- .../batch/core/BatchStatus.java | 7 +- .../batch/core/explore/JobExplorer.java | 23 +++++- .../explore/support/SimpleJobExplorer.java | 23 ++++++ .../repository/dao/JdbcStepExecutionDao.java | 64 +++++++++++----- .../repository/dao/MapStepExecutionDao.java | 40 ++++++---- .../core/repository/dao/StepExecutionDao.java | 21 +++++- .../support/CommandLineJobRunnerTests.java | 33 ++++---- .../MessageChannelPartitionHandler.java | 75 +++++++------------ .../MessageChannelPartitionHandlerTests.java | 28 +++---- 9 files changed, 195 insertions(+), 119 deletions(-) diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java b/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java index fab23ada7e..1a155d6d96 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/BatchStatus.java @@ -16,6 +16,9 @@ package org.springframework.batch.core; +import java.util.Arrays; +import java.util.List; + /** * Enumeration representing the status of an Execution. * @@ -39,6 +42,8 @@ public enum BatchStatus { */ COMPLETED, STARTING, STARTED, STOPPING, STOPPED, FAILED, ABANDONED, UNKNOWN; + public static final List RUNNING_STATUSES = Arrays.asList(STARTING, STARTED); + public static BatchStatus max(BatchStatus status1, BatchStatus status2) { return status1.isGreaterThan(status2) ? status1 : status2; } @@ -49,7 +54,7 @@ public static BatchStatus max(BatchStatus status1, BatchStatus status2) { * @return true if the status is STARTING, STARTED */ public boolean isRunning() { - return this == STARTING || this == STARTED; + return RUNNING_STATUSES.contains(this); } /** diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java index 586ad67be1..fa636accbe 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/JobExplorer.java @@ -15,9 +15,7 @@ */ package org.springframework.batch.core.explore; -import java.util.List; -import java.util.Set; - +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; @@ -25,6 +23,10 @@ import org.springframework.batch.item.ExecutionContext; import org.springframework.lang.Nullable; +import java.util.Collection; +import java.util.List; +import java.util.Set; + /** * Entry point for browsing executions of running or historical jobs and steps. * Since the data may be re-hydrated from persistent storage, it may not contain @@ -89,6 +91,14 @@ default JobInstance getLastJobInstance(String jobName) { @Nullable StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable Long stepExecutionId); + /** + * Retrieve number of step executions that match the step execution ids and the batch statuses + * @param stepExecutionIds given step execution ids + * @param matchingBatchStatuses given batch statuses to match against + * @return number of {@link StepExecution} matching the criteria + */ + int getStepExecutionCount(Collection stepExecutionIds, Collection matchingBatchStatuses); + /** * @param instanceId {@link Long} id for the jobInstance to obtain. * @return the {@link JobInstance} with this id, or null @@ -164,4 +174,11 @@ default JobExecution getLastJobExecution(JobInstance jobInstance) { */ int getJobInstanceCount(@Nullable String jobName) throws NoSuchJobException; + /** + * Find step executions in bulk + * @param jobExecutionId given job execution id + * @param stepExecutionIds given step execution ids + * @return collection of {@link StepExecution} + */ + Collection getStepExecutions(Long jobExecutionId, Collection stepExecutionIds); } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java index 81a1d9c333..f3b80c5515 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/explore/support/SimpleJobExplorer.java @@ -16,6 +16,7 @@ package org.springframework.batch.core.explore.support; +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; @@ -27,6 +28,7 @@ import org.springframework.batch.core.repository.dao.StepExecutionDao; import org.springframework.lang.Nullable; +import java.util.Collection; import java.util.List; import java.util.Set; @@ -165,6 +167,14 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L return stepExecution; } + @Override + public int getStepExecutionCount(Collection stepExecutionIds, Collection matchingBatchStatuses) { + if (stepExecutionIds.isEmpty() || matchingBatchStatuses.isEmpty()) { + return 0; + } + return stepExecutionDao.countStepExecutions(stepExecutionIds, matchingBatchStatuses); + } + /* * (non-Javadoc) * @@ -221,6 +231,19 @@ public int getJobInstanceCount(@Nullable String jobName) throws NoSuchJobExcepti return jobInstanceDao.getJobInstanceCount(jobName); } + @Override + @Nullable + public Collection getStepExecutions(Long jobExecutionId, Collection stepExecutionIds) { + JobExecution jobExecution = jobExecutionDao.getJobExecution(jobExecutionId); + if (jobExecution == null) { + return null; + } + getJobExecutionDependencies(jobExecution); + Collection stepExecutions = stepExecutionDao.getStepExecutions(jobExecution, stepExecutionIds); + stepExecutions.forEach(this::getStepExecutionDependencies); + return stepExecutions; + } + /* * Find all dependencies for a JobExecution, including JobInstance (which * requires JobParameters) plus StepExecutions diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java index d5712fa227..c63c9be6c4 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/JdbcStepExecutionDao.java @@ -16,25 +16,9 @@ package org.springframework.batch.core.repository.dao; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Timestamp; -import java.sql.Types; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; - -import org.springframework.batch.core.BatchStatus; -import org.springframework.batch.core.ExitStatus; -import org.springframework.batch.core.JobExecution; -import org.springframework.batch.core.JobInstance; -import org.springframework.batch.core.StepExecution; +import org.springframework.batch.core.*; import org.springframework.beans.factory.InitializingBean; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.jdbc.core.BatchPreparedStatementSetter; @@ -43,6 +27,11 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import java.sql.*; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.Stream; + /** * JDBC implementation of {@link StepExecutionDao}.
* @@ -114,6 +103,15 @@ public class JdbcStepExecutionDao extends AbstractJdbcBatchMetadataDao implement " and SE.JOB_EXECUTION_ID = JE.JOB_EXECUTION_ID " + " and SE.STEP_NAME = ?"; + // need to replace the %STEP_EXECUTION_IDS% with a known number of ?s + private static final String GET_STEP_EXECUTIONS_BY_IDS = GET_RAW_STEP_EXECUTIONS + " and STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%)"; + + // need to replace the %STEP_EXECUTION_IDS% and %STEP_STATUSES% with a known number of ?s + private static final String COUNT_STEP_EXECUTIONS_MATCHING_IDS_AND_STATUSES = "SELECT COUNT(*) " + + "from %PREFIX%STEP_EXECUTION SE " + + "where SE.STEP_EXECUTION_ID IN (%STEP_EXECUTION_IDS%) " + + "and SE.STATUS IN (%STEP_STATUSES%)"; + private int exitMessageLength = DEFAULT_EXIT_MESSAGE_LENGTH; private DataFieldMaxValueIncrementer stepExecutionIncrementer; @@ -350,12 +348,31 @@ public StepExecution getLastStepExecution(JobInstance jobInstance, String stepNa } } + @Override + public Collection getStepExecutions(JobExecution jobExecution, Collection stepExecutionIds) { + String sql = createParameterizedQuery(GET_STEP_EXECUTIONS_BY_IDS, "%STEP_EXECUTION_IDS%", stepExecutionIds); + return getJdbcTemplate().query(getQuery(sql), + new StepExecutionRowMapper(jobExecution), + Stream.concat(Stream.of(jobExecution.getId()), stepExecutionIds.stream()).toArray()); + } + @Override public void addStepExecutions(JobExecution jobExecution) { getJdbcTemplate().query(getQuery(GET_STEP_EXECUTIONS), new StepExecutionRowMapper(jobExecution), jobExecution.getId()); } + @Override + public int countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses) { + String sql = createParameterizedQuery(COUNT_STEP_EXECUTIONS_MATCHING_IDS_AND_STATUSES, "%STEP_EXECUTION_IDS%", stepExecutionIds); + sql = createParameterizedQuery(sql, "%STEP_STATUSES%", matchingBatchStatuses); + Object[] args = Stream.concat(stepExecutionIds.stream(), + matchingBatchStatuses.stream().map(BatchStatus::name)).toArray(); + return getJdbcTemplate().queryForObject(getQuery(sql), + Integer.class, + args); + } + @Override public int countStepExecutions(JobInstance jobInstance, String stepName) { return getJdbcTemplate().queryForObject(getQuery(COUNT_STEP_EXECUTIONS), new Object[] { jobInstance.getInstanceId(), stepName }, Integer.class); @@ -391,4 +408,17 @@ public StepExecution mapRow(ResultSet rs, int rowNum) throws SQLException { } + /** + * Replaces a given placeholder with a number of parameters (i.e. "?"). + * + * @param sqlTemplate given sql template + * @param placeholder placeholder that is being used for parameters + * @param parameters collection of parameters with variable size + * + * @return sql query replaced with a number of parameters + */ + private static String createParameterizedQuery(String sqlTemplate, String placeholder, Collection parameters) { + String params = parameters.stream().map(p -> "?").collect(Collectors.joining(", ")); + return sqlTemplate.replace(placeholder, params); + } } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java index 2e3bed2466..c74865861d 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/MapStepExecutionDao.java @@ -15,26 +15,19 @@ */ package org.springframework.batch.core.repository.dao; -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; - -import org.springframework.batch.core.Entity; -import org.springframework.batch.core.JobExecution; -import org.springframework.batch.core.JobInstance; -import org.springframework.batch.core.StepExecution; +import org.springframework.batch.core.*; import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ReflectionUtils; import org.springframework.util.SerializationUtils; +import java.lang.reflect.Field; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.stream.Collectors; + /** * In-memory implementation of {@link StepExecutionDao}. * @@ -189,4 +182,23 @@ public int countStepExecutions(JobInstance jobInstance, String stepName) { } return count; } + + @Override + public int countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses) { + int count = 0; + + for (Long id: stepExecutionIds) { + if (executionsByStepExecutionId.containsKey(id) && matchingBatchStatuses.contains(executionsByStepExecutionId.get(id).getStatus())) { + count++; + } + } + return count; + } + + @Override + public Collection getStepExecutions(JobExecution jobExecution, Collection stepExecutionIds) { + return executionsByStepExecutionId.values().stream() + .filter(se -> stepExecutionIds.contains(se.getId()) && se.getJobExecutionId().equals(jobExecution.getId())) + .collect(Collectors.toList()); + } } diff --git a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java index 107b44e717..3b451ba601 100644 --- a/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java +++ b/spring-batch-core/src/main/java/org/springframework/batch/core/repository/dao/StepExecutionDao.java @@ -16,13 +16,14 @@ package org.springframework.batch.core.repository.dao; -import java.util.Collection; - +import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobInstance; import org.springframework.batch.core.StepExecution; import org.springframework.lang.Nullable; +import java.util.Collection; + public interface StepExecutionDao { /** @@ -86,6 +87,22 @@ default StepExecution getLastStepExecution(JobInstance jobInstance, String stepN */ void addStepExecutions(JobExecution jobExecution); + /** + * Count {@link StepExecution} that match the ids and statuses of them - avoid loading them into memory + * @param stepExecutionIds given step execution ids + * @param matchingBatchStatuses + * @return + */ + int countStepExecutions(Collection stepExecutionIds, Collection matchingBatchStatuses); + + /** + * Get a collection of {@link StepExecution} matching job execution and step execution ids. + * @param jobExecution the parent job execution + * @param stepExecutionIds the step execution ids + * @return collection of {@link StepExecution} + */ + @Nullable + Collection getStepExecutions(JobExecution jobExecution, Collection stepExecutionIds); /** * Counts all the {@link StepExecution} for a given step name. * diff --git a/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java b/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java index 8b068fc2bf..34e3a0bbb4 100644 --- a/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java +++ b/spring-batch-core/src/test/java/org/springframework/batch/core/launch/support/CommandLineJobRunnerTests.java @@ -15,28 +15,10 @@ */ package org.springframework.batch.core.launch.support; -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Date; -import java.util.HashSet; -import java.util.List; -import java.util.Properties; -import java.util.Set; - import org.junit.After; import org.junit.Before; import org.junit.Test; - -import org.springframework.batch.core.BatchStatus; -import org.springframework.batch.core.ExitStatus; -import org.springframework.batch.core.Job; -import org.springframework.batch.core.JobExecution; -import org.springframework.batch.core.JobInstance; -import org.springframework.batch.core.JobParameters; -import org.springframework.batch.core.JobParametersBuilder; -import org.springframework.batch.core.StepExecution; +import org.springframework.batch.core.*; import org.springframework.batch.core.converter.DefaultJobParametersConverter; import org.springframework.batch.core.converter.JobParametersConverter; import org.springframework.batch.core.explore.JobExplorer; @@ -49,6 +31,10 @@ import org.springframework.lang.Nullable; import org.springframework.util.ClassUtils; +import java.io.IOException; +import java.io.InputStream; +import java.util.*; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -538,6 +524,11 @@ public StepExecution getStepExecution(@Nullable Long jobExecutionId, @Nullable L throw new UnsupportedOperationException(); } + @Override + public int getStepExecutionCount(Collection stepExecutionIds, Collection matchingBatchStatuses) { + throw new UnsupportedOperationException(); + } + @Override public List getJobNames() { throw new UnsupportedOperationException(); @@ -566,6 +557,10 @@ public int getJobInstanceCount(@Nullable String jobName) } } + @Override + public Collection getStepExecutions(Long jobExecutionId, Collection stepExecutionIds) { + throw new UnsupportedOperationException(); + } } public static class StubJobParametersConverter implements JobParametersConverter { diff --git a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java index bac0462b61..58dc69a228 100644 --- a/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java +++ b/spring-batch-integration/src/main/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandler.java @@ -1,19 +1,9 @@ package org.springframework.batch.integration.partition; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Iterator; -import java.util.List; -import java.util.Set; -import java.util.concurrent.Callable; -import java.util.concurrent.Future; -import java.util.concurrent.TimeUnit; - -import javax.sql.DataSource; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; - +import org.springframework.batch.core.BatchStatus; +import org.springframework.batch.core.Entity; import org.springframework.batch.core.Step; import org.springframework.batch.core.StepExecution; import org.springframework.batch.core.explore.JobExplorer; @@ -37,6 +27,15 @@ import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; +import javax.sql.DataSource; +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + /** * A {@link PartitionHandler} that uses {@link MessageChannel} instances to send instructions to remote workers and * receive their responses. The {@link MessageChannel} provides a nice abstraction so that the location of the workers @@ -236,49 +235,31 @@ public Collection handle(StepExecutionSplitter stepExecutionSplit } } - private Collection pollReplies(final StepExecution masterStepExecution, final Set split) throws Exception { - final Collection result = new ArrayList<>(split.size()); - - Callable> callback = new Callable>() { - @Override - public Collection call() throws Exception { + private Collection pollReplies(final StepExecution masterStepExecution, final Set split) throws Exception { + Collection ids = split.stream().map(Entity::getId).collect(Collectors.toList()); - for(Iterator stepExecutionIterator = split.iterator(); stepExecutionIterator.hasNext(); ) { - StepExecution curStepExecution = stepExecutionIterator.next(); - - if(!result.contains(curStepExecution)) { - StepExecution partitionStepExecution = - jobExplorer.getStepExecution(masterStepExecution.getJobExecutionId(), curStepExecution.getId()); - - if(!partitionStepExecution.getStatus().isRunning()) { - result.add(partitionStepExecution); - } - } - } + Callable> callback = () -> { + int runningStepExecutions = jobExplorer.getStepExecutionCount(ids, BatchStatus.RUNNING_STATUSES); + if(runningStepExecutions > 0 && split.size() > 0) { if(logger.isDebugEnabled()) { - logger.debug(String.format("Currently waiting on %s partitions to finish", split.size())); - } - - if(result.size() == split.size()) { - return result; - } - else { - return null; + logger.debug(String.format("Currently waiting on %s out of %s partitions to finish", runningStepExecutions, split.size())); } + return null; + } else { + return jobExplorer.getStepExecutions(masterStepExecution.getJobExecutionId(), ids); } }; - Poller> poller = new DirectPoller<>(pollInterval); - Future> resultsFuture = poller.poll(callback); + Poller> poller = new DirectPoller<>(pollInterval); + Future> resultsFuture = poller.poll(callback); - if(timeout >= 0) { - return resultsFuture.get(timeout, TimeUnit.MILLISECONDS); - } - else { - return resultsFuture.get(); - } - } + if(timeout >= 0) { + return resultsFuture.get(timeout, TimeUnit.MILLISECONDS); + } else { + return resultsFuture.get(); + } + } private Collection receiveReplies(PollableChannel currentReplyChannel) { @SuppressWarnings("unchecked") diff --git a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java index fd9170412f..fa0d83817c 100644 --- a/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java +++ b/spring-batch-integration/src/test/java/org/springframework/batch/integration/partition/MessageChannelPartitionHandlerTests.java @@ -1,12 +1,6 @@ package org.springframework.batch.integration.partition; -import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; -import java.util.concurrent.TimeoutException; - import org.junit.Test; - import org.springframework.batch.core.BatchStatus; import org.springframework.batch.core.JobExecution; import org.springframework.batch.core.JobParameters; @@ -18,15 +12,16 @@ import org.springframework.messaging.Message; import org.springframework.messaging.PollableChannel; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.concurrent.TimeoutException; + +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; /** * @@ -154,8 +149,8 @@ public void testHandleWithJobRepositoryPolling() throws Exception { stepExecutions.add(partition2); stepExecutions.add(partition3); when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions); - when(jobExplorer.getStepExecution(eq(5L), any(Long.class))).thenReturn(partition2, partition1, partition3, partition3, partition3, partition3, partition4); - + when(jobExplorer.getStepExecutionCount(any(), any())).thenReturn(3, 2, 0); + when(jobExplorer.getStepExecutions(eq(5L), any())).thenReturn(Arrays.asList(partition1, partition2, partition4)); //set messageChannelPartitionHandler.setMessagingOperations(operations); messageChannelPartitionHandler.setJobExplorer(jobExplorer); @@ -198,7 +193,8 @@ public void testHandleWithJobRepositoryPollingTimeout() throws Exception { stepExecutions.add(partition2); stepExecutions.add(partition3); when(stepExecutionSplitter.split(any(StepExecution.class), eq(1))).thenReturn(stepExecutions); - when(jobExplorer.getStepExecution(eq(5L), any(Long.class))).thenReturn(partition2, partition1, partition3); + when(jobExplorer.getStepExecutionCount(any(), any())).thenReturn(2); + when(jobExplorer.getStepExecutions(eq(5L), any())).thenReturn(Arrays.asList(partition1, partition2, partition3)); //set messageChannelPartitionHandler.setMessagingOperations(operations);