diff --git a/docs/documentation/advanced/postgresql.md b/docs/documentation/advanced/postgresql.md new file mode 100644 index 000000000..fb0e80f57 --- /dev/null +++ b/docs/documentation/advanced/postgresql.md @@ -0,0 +1,50 @@ +# PostgreSQL + +By default conductor runs with an in-memory Redis mock. However, you +can run Conductor against PostgreSQL which provides workflow management, queues and indexing. +There are a number of configuration options that enable you to use more or less of PostgreSQL functionality for your needs. +It has the benefit of requiring fewer moving parts for the infrastructure, but does not scale as well to handle high volumes of workflows. +You should benchmark Conductor with Postgres against your specific workload to be sure. + + +## Configuration + +To enable the basic use of PostgreSQL to manage workflow metadata, set the following property: + +```properties +conductor.db.type=postgres +spring.datasource.url=jdbc:postgresql://postgres:5432/conductor +spring.datasource.username=conductor +spring.datasource.password=password +# optional +conductor.postgres.schema=public +``` + +To also use PostgreSQL for queues, you can set: + +```properties +conductor.queue.type=postgres +``` + +You can also use PostgreSQL to index workflows, configure this as follows: + +```properties +conductor.indexing.enabled=true +conductor.indexing.type=postgres +conductor.elasticsearch.version=0 +``` + +By default, Conductor writes the latest poll for tasks to the database so that it can be used to determine which tasks and domains are active. This creates a lot of database traffic. +To avoid some of this traffic you can configure the PollDataDAO with a write buffer so that it only flushes every x milliseconds. If you keep this value around 5s then there should be no impact on behaviour. Conductor uses a default duration of 10s to determine whether a queue for a domain is active or not (also configurable using `conductor.app.activeWorkerLastPollTimeout`) so this will ensure that there is plenty of time for the data to get to the database to be shared by other instances: + +```properties +# Flush the data every 5 seconds +conductor.postgres.pollDataFlushInterval=5000 +``` + +You can also configure a duration when the cached poll data will be considered stale. This means that the PollDataDAO will try to use the cached data, but if it is older than the configured period, it will check against the database. There is no downside to setting this as if this Conductor node already can confirm that the queue is active then there's no need to go to the database. If the record in the cache is out of date, then we still go to the database to check. + +```properties +# Data older than 5 seconds is considered stale +conductor.postgres.pollDataCacheValidityPeriod=5000 +``` diff --git a/mkdocs.yml b/mkdocs.yml index f94c96a8b..96f414fe5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -101,6 +101,7 @@ nav: - documentation/advanced/azureblob-storage.md - documentation/advanced/externalpayloadstorage.md - documentation/advanced/redis.md + - documentation/advanced/postgresql.md - Client SDKs: - documentation/clientsdks/index.md - documentation/clientsdks/java-sdk.md diff --git a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresConfiguration.java b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresConfiguration.java index ecb62941e..25eb5ad74 100644 --- a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresConfiguration.java +++ b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresConfiguration.java @@ -30,10 +30,7 @@ import org.springframework.retry.policy.SimpleRetryPolicy; import org.springframework.retry.support.RetryTemplate; -import com.netflix.conductor.postgres.dao.PostgresExecutionDAO; -import com.netflix.conductor.postgres.dao.PostgresIndexDAO; -import com.netflix.conductor.postgres.dao.PostgresMetadataDAO; -import com.netflix.conductor.postgres.dao.PostgresQueueDAO; +import com.netflix.conductor.postgres.dao.*; import com.fasterxml.jackson.databind.ObjectMapper; import jakarta.annotation.*; @@ -85,6 +82,15 @@ public PostgresExecutionDAO postgresExecutionDAO( return new PostgresExecutionDAO(retryTemplate, objectMapper, dataSource); } + @Bean + @DependsOn({"flywayForPrimaryDb"}) + public PostgresPollDataDAO postgresPollDataDAO( + @Qualifier("postgresRetryTemplate") RetryTemplate retryTemplate, + ObjectMapper objectMapper, + PostgresProperties properties) { + return new PostgresPollDataDAO(retryTemplate, objectMapper, dataSource, properties); + } + @Bean @DependsOn({"flywayForPrimaryDb"}) public PostgresQueueDAO postgresQueueDAO( diff --git a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresProperties.java b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresProperties.java index 68108c8ff..4f4338d37 100644 --- a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresProperties.java +++ b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/config/PostgresProperties.java @@ -27,6 +27,12 @@ public class PostgresProperties { private Integer deadlockRetryMax = 3; + @DurationUnit(ChronoUnit.MILLIS) + private Duration pollDataFlushInterval = Duration.ofMillis(0); + + @DurationUnit(ChronoUnit.MILLIS) + private Duration pollDataCacheValidityPeriod = Duration.ofMillis(0); + public String schema = "public"; public boolean allowFullTextQueries = true; @@ -94,4 +100,20 @@ public int getAsyncMaxPoolSize() { public void setAsyncMaxPoolSize(int asyncMaxPoolSize) { this.asyncMaxPoolSize = asyncMaxPoolSize; } + + public Duration getPollDataFlushInterval() { + return pollDataFlushInterval; + } + + public void setPollDataFlushInterval(Duration interval) { + this.pollDataFlushInterval = interval; + } + + public Duration getPollDataCacheValidityPeriod() { + return pollDataCacheValidityPeriod; + } + + public void setPollDataCacheValidityPeriod(Duration period) { + this.pollDataCacheValidityPeriod = period; + } } diff --git a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresExecutionDAO.java b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresExecutionDAO.java index cf6afae4e..5ac642e93 100644 --- a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresExecutionDAO.java +++ b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresExecutionDAO.java @@ -14,7 +14,6 @@ import java.sql.Connection; import java.sql.Date; -import java.sql.SQLException; import java.text.SimpleDateFormat; import java.util.*; import java.util.concurrent.Executors; @@ -27,12 +26,10 @@ import org.springframework.retry.support.RetryTemplate; import com.netflix.conductor.common.metadata.events.EventExecution; -import com.netflix.conductor.common.metadata.tasks.PollData; import com.netflix.conductor.common.metadata.tasks.TaskDef; import com.netflix.conductor.core.exception.NonTransientException; import com.netflix.conductor.dao.ConcurrentExecutionLimitDAO; import com.netflix.conductor.dao.ExecutionDAO; -import com.netflix.conductor.dao.PollDataDAO; import com.netflix.conductor.dao.RateLimitingDAO; import com.netflix.conductor.metrics.Monitors; import com.netflix.conductor.model.TaskModel; @@ -47,7 +44,7 @@ import jakarta.annotation.*; public class PostgresExecutionDAO extends PostgresBaseDAO - implements ExecutionDAO, RateLimitingDAO, PollDataDAO, ConcurrentExecutionLimitDAO { + implements ExecutionDAO, RateLimitingDAO, ConcurrentExecutionLimitDAO { private final ScheduledExecutorService scheduledExecutorService; @@ -558,45 +555,6 @@ public List getEventExecutions( } } - @Override - public void updateLastPollData(String taskDefName, String domain, String workerId) { - Preconditions.checkNotNull(taskDefName, "taskDefName name cannot be null"); - PollData pollData = new PollData(taskDefName, domain, workerId, System.currentTimeMillis()); - String effectiveDomain = (domain == null) ? "DEFAULT" : domain; - withTransaction(tx -> insertOrUpdatePollData(tx, pollData, effectiveDomain)); - } - - @Override - public PollData getPollData(String taskDefName, String domain) { - Preconditions.checkNotNull(taskDefName, "taskDefName name cannot be null"); - String effectiveDomain = (domain == null) ? "DEFAULT" : domain; - return getWithRetriedTransactions(tx -> readPollData(tx, taskDefName, effectiveDomain)); - } - - @Override - public List getPollData(String taskDefName) { - Preconditions.checkNotNull(taskDefName, "taskDefName name cannot be null"); - return readAllPollData(taskDefName); - } - - @Override - public List getAllPollData() { - try (Connection tx = dataSource.getConnection()) { - boolean previousAutoCommitMode = tx.getAutoCommit(); - tx.setAutoCommit(true); - try { - String GET_ALL_POLL_DATA = "SELECT json_data FROM poll_data ORDER BY queue_name"; - return query(tx, GET_ALL_POLL_DATA, q -> q.executeAndFetch(PollData.class)); - } catch (Throwable th) { - throw new NonTransientException(th.getMessage(), th); - } finally { - tx.setAutoCommit(previousAutoCommitMode); - } - } catch (SQLException ex) { - throw new NonTransientException(ex.getMessage(), ex); - } - } - private List getTasks(Connection connection, List taskIds) { if (taskIds.isEmpty()) { return Lists.newArrayList(); @@ -1027,56 +985,6 @@ private EventExecution readEventExecution( .executeAndFetchFirst(EventExecution.class)); } - private void insertOrUpdatePollData(Connection connection, PollData pollData, String domain) { - /* - * Most times the row will be updated so let's try the update first. This used to be an 'INSERT/ON CONFLICT do update' sql statement. The problem with that - * is that if we try the INSERT first, the sequence will be increased even if the ON CONFLICT happens. Since polling happens *a lot*, the sequence can increase - * dramatically even though it won't be used. - */ - String UPDATE_POLL_DATA = - "UPDATE poll_data SET json_data=?, modified_on=CURRENT_TIMESTAMP WHERE queue_name=? AND domain=?"; - int rowsUpdated = - query( - connection, - UPDATE_POLL_DATA, - q -> - q.addJsonParameter(pollData) - .addParameter(pollData.getQueueName()) - .addParameter(domain) - .executeUpdate()); - - if (rowsUpdated == 0) { - String INSERT_POLL_DATA = - "INSERT INTO poll_data (queue_name, domain, json_data, modified_on) VALUES (?, ?, ?, CURRENT_TIMESTAMP) ON CONFLICT (queue_name,domain) DO UPDATE SET json_data=excluded.json_data, modified_on=excluded.modified_on"; - execute( - connection, - INSERT_POLL_DATA, - q -> - q.addParameter(pollData.getQueueName()) - .addParameter(domain) - .addJsonParameter(pollData) - .executeUpdate()); - } - } - - private PollData readPollData(Connection connection, String queueName, String domain) { - String GET_POLL_DATA = - "SELECT json_data FROM poll_data WHERE queue_name = ? AND domain = ?"; - return query( - connection, - GET_POLL_DATA, - q -> - q.addParameter(queueName) - .addParameter(domain) - .executeAndFetchFirst(PollData.class)); - } - - private List readAllPollData(String queueName) { - String GET_ALL_POLL_DATA = "SELECT json_data FROM poll_data WHERE queue_name = ?"; - return queryWithTransaction( - GET_ALL_POLL_DATA, q -> q.addParameter(queueName).executeAndFetch(PollData.class)); - } - private List findAllTasksInProgressInOrderOfArrival(TaskModel task, int limit) { String GET_IN_PROGRESS_TASKS_WITH_LIMIT = "SELECT task_id FROM task_in_progress WHERE task_def_name = ? ORDER BY created_on LIMIT ?"; diff --git a/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAO.java b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAO.java new file mode 100644 index 000000000..d79bdc5da --- /dev/null +++ b/postgres-persistence/src/main/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAO.java @@ -0,0 +1,218 @@ +/* + * Copyright 2024 Conductor Authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package com.netflix.conductor.postgres.dao; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import javax.sql.DataSource; + +import org.springframework.retry.support.RetryTemplate; + +import com.netflix.conductor.common.metadata.tasks.PollData; +import com.netflix.conductor.core.exception.NonTransientException; +import com.netflix.conductor.dao.PollDataDAO; +import com.netflix.conductor.postgres.config.PostgresProperties; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Preconditions; +import jakarta.annotation.PostConstruct; + +public class PostgresPollDataDAO extends PostgresBaseDAO implements PollDataDAO { + + private ConcurrentHashMap> pollDataCache = + new ConcurrentHashMap<>(); + + private long pollDataFlushInterval; + + private long cacheValidityPeriod; + + private long lastFlushTime = 0; + + private boolean useReadCache; + + public PostgresPollDataDAO( + RetryTemplate retryTemplate, + ObjectMapper objectMapper, + DataSource dataSource, + PostgresProperties properties) { + super(retryTemplate, objectMapper, dataSource); + this.pollDataFlushInterval = properties.getPollDataFlushInterval().toMillis(); + if (this.pollDataFlushInterval > 0) { + logger.info("Using Postgres pollData write cache"); + } + this.cacheValidityPeriod = properties.getPollDataCacheValidityPeriod().toMillis(); + this.useReadCache = cacheValidityPeriod > 0; + if (this.useReadCache) { + logger.info("Using Postgres pollData read cache"); + } + } + + @PostConstruct + public void schedulePollDataRefresh() { + if (pollDataFlushInterval > 0) { + Executors.newSingleThreadScheduledExecutor() + .scheduleWithFixedDelay( + this::flushData, + pollDataFlushInterval, + pollDataFlushInterval, + TimeUnit.MILLISECONDS); + } + } + + @Override + public void updateLastPollData(String taskDefName, String domain, String workerId) { + Preconditions.checkNotNull(taskDefName, "taskDefName name cannot be null"); + + String effectiveDomain = domain == null ? "DEFAULT" : domain; + PollData pollData = new PollData(taskDefName, domain, workerId, System.currentTimeMillis()); + + if (pollDataFlushInterval > 0) { + ConcurrentHashMap domainPollData = pollDataCache.get(taskDefName); + if (domainPollData == null) { + domainPollData = new ConcurrentHashMap<>(); + pollDataCache.put(taskDefName, domainPollData); + } + domainPollData.put(effectiveDomain, pollData); + } else { + withTransaction(tx -> insertOrUpdatePollData(tx, pollData, effectiveDomain)); + } + } + + @Override + public PollData getPollData(String taskDefName, String domain) { + PollData result; + + if (useReadCache) { + ConcurrentHashMap domainPollData = pollDataCache.get(taskDefName); + if (domainPollData == null) { + return null; + } + result = domainPollData.get(domain == null ? "DEFAULT" : domain); + long diffSeconds = System.currentTimeMillis() - result.getLastPollTime(); + if (diffSeconds < cacheValidityPeriod) { + return result; + } + } + + Preconditions.checkNotNull(taskDefName, "taskDefName name cannot be null"); + String effectiveDomain = (domain == null) ? "DEFAULT" : domain; + return getWithRetriedTransactions(tx -> readPollData(tx, taskDefName, effectiveDomain)); + } + + @Override + public List getPollData(String taskDefName) { + Preconditions.checkNotNull(taskDefName, "taskDefName name cannot be null"); + return readAllPollData(taskDefName); + } + + @Override + public List getAllPollData() { + try (Connection tx = dataSource.getConnection()) { + boolean previousAutoCommitMode = tx.getAutoCommit(); + tx.setAutoCommit(true); + try { + String GET_ALL_POLL_DATA = "SELECT json_data FROM poll_data ORDER BY queue_name"; + return query(tx, GET_ALL_POLL_DATA, q -> q.executeAndFetch(PollData.class)); + } catch (Throwable th) { + throw new NonTransientException(th.getMessage(), th); + } finally { + tx.setAutoCommit(previousAutoCommitMode); + } + } catch (SQLException ex) { + throw new NonTransientException(ex.getMessage(), ex); + } + } + + public long getLastFlushTime() { + return lastFlushTime; + } + + private void insertOrUpdatePollData(Connection connection, PollData pollData, String domain) { + try { + /* + * Most times the row will be updated so let's try the update first. This used to be an 'INSERT/ON CONFLICT do update' sql statement. The problem with that + * is that if we try the INSERT first, the sequence will be increased even if the ON CONFLICT happens. Since polling happens *a lot*, the sequence can increase + * dramatically even though it won't be used. + */ + String UPDATE_POLL_DATA = + "UPDATE poll_data SET json_data=?, modified_on=CURRENT_TIMESTAMP WHERE queue_name=? AND domain=?"; + int rowsUpdated = + query( + connection, + UPDATE_POLL_DATA, + q -> + q.addJsonParameter(pollData) + .addParameter(pollData.getQueueName()) + .addParameter(domain) + .executeUpdate()); + + if (rowsUpdated == 0) { + String INSERT_POLL_DATA = + "INSERT INTO poll_data (queue_name, domain, json_data, modified_on) VALUES (?, ?, ?, CURRENT_TIMESTAMP) ON CONFLICT (queue_name,domain) DO UPDATE SET json_data=excluded.json_data, modified_on=excluded.modified_on"; + execute( + connection, + INSERT_POLL_DATA, + q -> + q.addParameter(pollData.getQueueName()) + .addParameter(domain) + .addJsonParameter(pollData) + .executeUpdate()); + } + } catch (NonTransientException e) { + if (!e.getMessage().startsWith("ERROR: lastPollTime cannot be set to a lower value")) { + throw e; + } + } + } + + private PollData readPollData(Connection connection, String queueName, String domain) { + String GET_POLL_DATA = + "SELECT json_data FROM poll_data WHERE queue_name = ? AND domain = ?"; + return query( + connection, + GET_POLL_DATA, + q -> + q.addParameter(queueName) + .addParameter(domain) + .executeAndFetchFirst(PollData.class)); + } + + private List readAllPollData(String queueName) { + String GET_ALL_POLL_DATA = "SELECT json_data FROM poll_data WHERE queue_name = ?"; + return queryWithTransaction( + GET_ALL_POLL_DATA, q -> q.addParameter(queueName).executeAndFetch(PollData.class)); + } + + private void flushData() { + try { + for (Map.Entry> queue : + pollDataCache.entrySet()) { + for (Map.Entry domain : queue.getValue().entrySet()) { + withTransaction( + tx -> { + insertOrUpdatePollData(tx, domain.getValue(), domain.getKey()); + }); + } + } + lastFlushTime = System.currentTimeMillis(); + } catch (Exception e) { + logger.error("Postgres pollData cache flush failed ", e); + } + } +} diff --git a/postgres-persistence/src/main/resources/db/migration_postgres/V10__poll_data_check.sql b/postgres-persistence/src/main/resources/db/migration_postgres/V10__poll_data_check.sql new file mode 100644 index 000000000..8bdbebe7c --- /dev/null +++ b/postgres-persistence/src/main/resources/db/migration_postgres/V10__poll_data_check.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE FUNCTION poll_data_update_check () + RETURNS TRIGGER + AS $$ +BEGIN + IF(NEW.json_data::json ->> 'lastPollTime')::BIGINT < (OLD.json_data::json ->> 'lastPollTime')::BIGINT THEN + RAISE EXCEPTION 'lastPollTime cannot be set to a lower value'; + END IF; + RETURN NEW; +END; +$$ +LANGUAGE plpgsql; + +CREATE TRIGGER poll_data_update_check_trigger BEFORE UPDATE ON poll_data FOR EACH ROW EXECUTE FUNCTION poll_data_update_check (); \ No newline at end of file diff --git a/postgres-persistence/src/test/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAOCacheTest.java b/postgres-persistence/src/test/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAOCacheTest.java new file mode 100644 index 000000000..8b1c41efe --- /dev/null +++ b/postgres-persistence/src/test/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAOCacheTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2023 Conductor Authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package com.netflix.conductor.postgres.dao; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; + +import javax.sql.DataSource; + +import org.flywaydb.core.Flyway; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.autoconfigure.flyway.FlywayAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestPropertySource; +import org.springframework.test.context.junit4.SpringRunner; + +import com.netflix.conductor.common.config.TestObjectMapperConfiguration; +import com.netflix.conductor.common.metadata.tasks.PollData; +import com.netflix.conductor.dao.PollDataDAO; +import com.netflix.conductor.postgres.config.PostgresConfiguration; +import com.netflix.conductor.postgres.util.Query; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import static org.junit.Assert.*; + +@ContextConfiguration( + classes = { + TestObjectMapperConfiguration.class, + PostgresConfiguration.class, + FlywayAutoConfiguration.class + }) +@RunWith(SpringRunner.class) +@TestPropertySource( + properties = { + "conductor.app.asyncIndexingEnabled=false", + "conductor.elasticsearch.version=0", + "conductor.indexing.type=postgres", + "conductor.postgres.pollDataFlushInterval=200", + "conductor.postgres.pollDataCacheValidityPeriod=100", + "spring.flyway.clean-disabled=false" + }) +@SpringBootTest +public class PostgresPollDataDAOCacheTest { + + @Autowired private PollDataDAO pollDataDAO; + + @Autowired private ObjectMapper objectMapper; + + @Qualifier("dataSource") + @Autowired + private DataSource dataSource; + + @Autowired Flyway flyway; + + // clean the database between tests. + @Before + public void before() { + try (Connection conn = dataSource.getConnection()) { + conn.setAutoCommit(true); + conn.prepareStatement("truncate table poll_data").executeUpdate(); + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private List> queryDb(String query) throws SQLException { + try (Connection c = dataSource.getConnection()) { + try (Query q = new Query(objectMapper, c, query)) { + return q.executeAndFetchMap(); + } + } + } + + private void waitForCacheFlush() throws InterruptedException { + long startTime = System.currentTimeMillis(); + long lastDiff = + System.currentTimeMillis() - ((PostgresPollDataDAO) pollDataDAO).getLastFlushTime(); + + if (lastDiff == 0) { + return; + } + + while (true) { + long currentDiff = + System.currentTimeMillis() + - ((PostgresPollDataDAO) pollDataDAO).getLastFlushTime(); + + if (currentDiff < lastDiff || System.currentTimeMillis() - startTime > 1000) { + return; + } + + lastDiff = currentDiff; + + Thread.sleep(1); + } + } + + @Test + public void cacheFlushTest() + throws SQLException, JsonProcessingException, InterruptedException { + waitForCacheFlush(); + pollDataDAO.updateLastPollData("dummy-task", "dummy-domain", "dummy-worker-id"); + + List> records = + queryDb("SELECT * FROM poll_data WHERE queue_name = 'dummy-task'"); + + assertEquals("Poll data records returned", 0, records.size()); + + waitForCacheFlush(); + + records = queryDb("SELECT * FROM poll_data WHERE queue_name = 'dummy-task'"); + assertEquals("Poll data records returned", 1, records.size()); + assertEquals("Wrong domain set", "dummy-domain", records.get(0).get("domain")); + + JsonNode jsonData = objectMapper.readTree(records.get(0).get("json_data").toString()); + assertEquals( + "Poll data is incorrect", "dummy-worker-id", jsonData.get("workerId").asText()); + } + + @Test + public void getCachedPollDataByDomainTest() throws InterruptedException, SQLException { + waitForCacheFlush(); + pollDataDAO.updateLastPollData("dummy-task2", "dummy-domain2", "dummy-worker-id2"); + + PollData pollData = pollDataDAO.getPollData("dummy-task2", "dummy-domain2"); + assertNotNull("pollData is null", pollData); + assertEquals("dummy-worker-id2", pollData.getWorkerId()); + + List> records = + queryDb("SELECT * FROM poll_data WHERE queue_name = 'dummy-task2'"); + + assertEquals("Poll data records returned", 0, records.size()); + } +} diff --git a/postgres-persistence/src/test/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAONoCacheTest.java b/postgres-persistence/src/test/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAONoCacheTest.java new file mode 100644 index 000000000..527bf3943 --- /dev/null +++ b/postgres-persistence/src/test/java/com/netflix/conductor/postgres/dao/PostgresPollDataDAONoCacheTest.java @@ -0,0 +1,200 @@ +/* + * Copyright 2023 Conductor Authors. + *

+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + *

+ * http://www.apache.org/licenses/LICENSE-2.0 + *

+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package com.netflix.conductor.postgres.dao; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.*; +import java.util.stream.Collectors; + +import javax.sql.DataSource; + +import org.flywaydb.core.Flyway; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.boot.autoconfigure.flyway.FlywayAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.TestPropertySource; +import org.springframework.test.context.junit4.SpringRunner; + +import com.netflix.conductor.common.config.TestObjectMapperConfiguration; +import com.netflix.conductor.common.metadata.tasks.PollData; +import com.netflix.conductor.dao.PollDataDAO; +import com.netflix.conductor.postgres.config.PostgresConfiguration; +import com.netflix.conductor.postgres.util.Query; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import static org.junit.Assert.*; + +@ContextConfiguration( + classes = { + TestObjectMapperConfiguration.class, + PostgresConfiguration.class, + FlywayAutoConfiguration.class + }) +@RunWith(SpringRunner.class) +@TestPropertySource( + properties = { + "conductor.app.asyncIndexingEnabled=false", + "conductor.elasticsearch.version=0", + "conductor.indexing.type=postgres", + "conductor.postgres.pollDataFlushInterval=0", + "conductor.postgres.pollDataCacheValidityPeriod=0", + "spring.flyway.clean-disabled=false" + }) +@SpringBootTest +public class PostgresPollDataDAONoCacheTest { + + @Autowired private PollDataDAO pollDataDAO; + + @Autowired private ObjectMapper objectMapper; + + @Qualifier("dataSource") + @Autowired + private DataSource dataSource; + + @Autowired Flyway flyway; + + // clean the database between tests. + @Before + public void before() { + try (Connection conn = dataSource.getConnection()) { + conn.setAutoCommit(true); + conn.prepareStatement("truncate table poll_data").executeUpdate(); + } catch (Exception e) { + e.printStackTrace(); + throw new RuntimeException(e); + } + } + + private List> queryDb(String query) throws SQLException { + try (Connection c = dataSource.getConnection()) { + try (Query q = new Query(objectMapper, c, query)) { + return q.executeAndFetchMap(); + } + } + } + + @Test + public void updateLastPollDataTest() throws SQLException, JsonProcessingException { + pollDataDAO.updateLastPollData("dummy-task", "dummy-domain", "dummy-worker-id"); + + List> records = + queryDb("SELECT * FROM poll_data WHERE queue_name = 'dummy-task'"); + + assertEquals("More than one poll data records returned", 1, records.size()); + assertEquals("Wrong domain set", "dummy-domain", records.get(0).get("domain")); + + JsonNode jsonData = objectMapper.readTree(records.get(0).get("json_data").toString()); + assertEquals( + "Poll data is incorrect", "dummy-worker-id", jsonData.get("workerId").asText()); + } + + @Test + public void updateLastPollDataNullDomainTest() throws SQLException, JsonProcessingException { + pollDataDAO.updateLastPollData("dummy-task", null, "dummy-worker-id"); + + List> records = + queryDb("SELECT * FROM poll_data WHERE queue_name = 'dummy-task'"); + + assertEquals("More than one poll data records returned", 1, records.size()); + assertEquals("Wrong domain set", "DEFAULT", records.get(0).get("domain")); + + JsonNode jsonData = objectMapper.readTree(records.get(0).get("json_data").toString()); + assertEquals( + "Poll data is incorrect", "dummy-worker-id", jsonData.get("workerId").asText()); + } + + @Test + public void getPollDataByDomainTest() { + pollDataDAO.updateLastPollData("dummy-task", "dummy-domain", "dummy-worker-id"); + + PollData pollData = pollDataDAO.getPollData("dummy-task", "dummy-domain"); + assertEquals("dummy-task", pollData.getQueueName()); + assertEquals("dummy-domain", pollData.getDomain()); + assertEquals("dummy-worker-id", pollData.getWorkerId()); + } + + @Test + public void getPollDataByNullDomainTest() { + pollDataDAO.updateLastPollData("dummy-task", null, "dummy-worker-id"); + + PollData pollData = pollDataDAO.getPollData("dummy-task", null); + assertEquals("dummy-task", pollData.getQueueName()); + assertNull(pollData.getDomain()); + assertEquals("dummy-worker-id", pollData.getWorkerId()); + } + + @Test + public void getPollDataByTaskTest() { + pollDataDAO.updateLastPollData("dummy-task1", "domain1", "dummy-worker-id1"); + pollDataDAO.updateLastPollData("dummy-task1", "domain2", "dummy-worker-id2"); + pollDataDAO.updateLastPollData("dummy-task1", null, "dummy-worker-id3"); + pollDataDAO.updateLastPollData("dummy-task2", "domain2", "dummy-worker-id4"); + + List pollData = pollDataDAO.getPollData("dummy-task1"); + assertEquals("Wrong number of records returned", 3, pollData.size()); + + List queueNames = + pollData.stream().map(x -> x.getQueueName()).collect(Collectors.toList()); + assertEquals(3, Collections.frequency(queueNames, "dummy-task1")); + + List domains = + pollData.stream().map(x -> x.getDomain()).collect(Collectors.toList()); + assertTrue(domains.contains("domain1")); + assertTrue(domains.contains("domain2")); + assertTrue(domains.contains(null)); + + List workerIds = + pollData.stream().map(x -> x.getWorkerId()).collect(Collectors.toList()); + assertTrue(workerIds.contains("dummy-worker-id1")); + assertTrue(workerIds.contains("dummy-worker-id2")); + assertTrue(workerIds.contains("dummy-worker-id3")); + } + + @Test + public void getAllPollDataTest() { + pollDataDAO.updateLastPollData("dummy-task1", "domain1", "dummy-worker-id1"); + pollDataDAO.updateLastPollData("dummy-task1", "domain2", "dummy-worker-id2"); + pollDataDAO.updateLastPollData("dummy-task1", null, "dummy-worker-id3"); + pollDataDAO.updateLastPollData("dummy-task2", "domain2", "dummy-worker-id4"); + + List pollData = pollDataDAO.getAllPollData(); + assertEquals("Wrong number of records returned", 4, pollData.size()); + + List queueNames = + pollData.stream().map(x -> x.getQueueName()).collect(Collectors.toList()); + assertEquals(3, Collections.frequency(queueNames, "dummy-task1")); + assertEquals(1, Collections.frequency(queueNames, "dummy-task2")); + + List domains = + pollData.stream().map(x -> x.getDomain()).collect(Collectors.toList()); + assertEquals(1, Collections.frequency(domains, "domain1")); + assertEquals(2, Collections.frequency(domains, "domain2")); + assertEquals(1, Collections.frequency(domains, null)); + + List workerIds = + pollData.stream().map(x -> x.getWorkerId()).collect(Collectors.toList()); + assertTrue(workerIds.contains("dummy-worker-id1")); + assertTrue(workerIds.contains("dummy-worker-id2")); + assertTrue(workerIds.contains("dummy-worker-id3")); + assertTrue(workerIds.contains("dummy-worker-id4")); + } +}