From 68580e5b31eeffa01f8d89c0612e7dad0f86ba58 Mon Sep 17 00:00:00 2001 From: Enrico Olivelli Date: Mon, 6 May 2024 13:58:14 +0200 Subject: [PATCH] Improve JMSPriority handling on the consumer side (#139) --- pulsar-jms-integration-tests/pom.xml | 4 +- pulsar-jms/pom.xml | 8 +- ...agePriorityGrowableArrayBlockingQueue.java | 202 ++++++++++++------ .../pulsar/jms/PulsarConnectionFactory.java | 45 +++- .../oss/pulsar/jms/PulsarMessage.java | 15 +- ...riorityGrowableArrayBlockingQueueTest.java | 77 +++++++ 6 files changed, 272 insertions(+), 79 deletions(-) create mode 100644 pulsar-jms/src/test/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueueTest.java diff --git a/pulsar-jms-integration-tests/pom.xml b/pulsar-jms-integration-tests/pom.xml index 8ac6fce5..dff7fc1c 100644 --- a/pulsar-jms-integration-tests/pom.xml +++ b/pulsar-jms-integration-tests/pom.xml @@ -103,8 +103,8 @@ copy filters - - + + diff --git a/pulsar-jms/pom.xml b/pulsar-jms/pom.xml index 00136c60..d6f93692 100644 --- a/pulsar-jms/pom.xml +++ b/pulsar-jms/pom.xml @@ -128,10 +128,10 @@ copy filters - - - - + + + + diff --git a/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueue.java b/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueue.java index 6e8b4f6d..5398d60f 100644 --- a/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueue.java +++ b/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueue.java @@ -19,147 +19,225 @@ import java.util.concurrent.PriorityBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; +import lombok.AllArgsConstructor; +import lombok.extern.slf4j.Slf4j; import org.apache.pulsar.client.api.Message; import org.apache.pulsar.common.util.collections.GrowableArrayBlockingQueue; -public class MessagePriorityGrowableArrayBlockingQueue extends GrowableArrayBlockingQueue { +@Slf4j +public class MessagePriorityGrowableArrayBlockingQueue + extends GrowableArrayBlockingQueue> { + private final PriorityBlockingQueue> queue; + private final AtomicBoolean terminated = new AtomicBoolean(false); + + private volatile Consumer> itemAfterTerminatedHandler; + private final AtomicInteger[] numberMessagesByPriority = new AtomicInteger[10]; - static int getPriority(Message m) { - Integer priority = PulsarMessage.readJMSPriority(m); - return priority == null ? PulsarMessage.DEFAULT_PRIORITY : priority; + @AllArgsConstructor + private static final class MessageWithPriority { + final int priority; + final Message message; } - private final PriorityBlockingQueue queue; - private final AtomicBoolean terminated = new AtomicBoolean(false); + private static final Comparator> comparator = + (o1, o2) -> { + // ORDER BY priority DESC, messageId ASC + int priority1 = o1.priority; + int priority2 = o2.priority; + if (priority1 == priority2) { + // if priorities are equal, we want to sort by messageId + return o1.message.getMessageId().compareTo(o2.message.getMessageId()); + } + return Integer.compare(priority2, priority1); + }; public MessagePriorityGrowableArrayBlockingQueue() { this(10); } public MessagePriorityGrowableArrayBlockingQueue(int initialCapacity) { - queue = - new PriorityBlockingQueue<>( - initialCapacity, - new Comparator() { - @Override - public int compare(Message o1, Message o2) { - int priority1 = getPriority(o1); - int priority2 = getPriority(o2); - return Integer.compare(priority2, priority1); - } - }); + queue = new PriorityBlockingQueue<>(initialCapacity, comparator); + for (int i = 0; i < 10; i++) { + numberMessagesByPriority[i] = new AtomicInteger(); + } } @Override - public Message remove() { - return queue.remove(); + public Message remove() { + throw new UnsupportedOperationException(); } @Override - public Message poll() { - return queue.poll(); + public Message poll() { + MessageWithPriority pair = queue.poll(); + if (pair == null) { + return null; + } + Message result = pair.message; + int prio = pair.priority; + if (log.isDebugEnabled()) { + log.debug( + "polled message prio {} {} stats {}", + prio, + result.getMessageId(), + Arrays.toString(numberMessagesByPriority)); + } + numberMessagesByPriority[prio].decrementAndGet(); + return result; } @Override - public Message element() { - return queue.element(); + public Message element() { + throw new UnsupportedOperationException(); } @Override - public Message peek() { - return queue.peek(); + public Message peek() { + MessageWithPriority pair = queue.peek(); + if (pair == null) { + return null; + } + Message result = pair.message; + if (log.isDebugEnabled()) { + log.debug( + "peeking message: {} prio {}", + result.getMessageId(), + PulsarMessage.readJMSPriority(result)); + } + return result; } @Override - public boolean offer(Message e) { - return queue.offer(e); + public boolean offer(Message e) { + boolean result; + if (!this.terminated.get()) { + int prio = PulsarMessage.readJMSPriority(e); + numberMessagesByPriority[prio].incrementAndGet(); + result = queue.offer(new MessageWithPriority(prio, e)); + if (log.isDebugEnabled()) { + log.debug( + "offered message: {} prio {} stats {}", + e.getMessageId(), + prio, + Arrays.toString(numberMessagesByPriority)); + } + } else { + if (log.isDebugEnabled()) { + log.debug("queue is terminated, not offering message: {}", e.getMessageId()); + } + if (itemAfterTerminatedHandler != null) { + itemAfterTerminatedHandler.accept(e); + } + result = false; + } + return result; } @Override - public void put(Message e) { - queue.put(e); + public void put(Message e) { + throw new UnsupportedOperationException(); } @Override - public boolean add(Message e) { - return queue.add(e); + public boolean add(Message e) { + throw new UnsupportedOperationException(); } @Override public boolean offer(Message e, long timeout, TimeUnit unit) { - return queue.offer(e, timeout, unit); + throw new UnsupportedOperationException(); } @Override - public Message take() throws InterruptedException { - return queue.take(); + public Message take() throws InterruptedException { + throw new UnsupportedOperationException(); } @Override - public Message poll(long timeout, TimeUnit unit) throws InterruptedException { - return queue.poll(timeout, unit); + public Message poll(long timeout, TimeUnit unit) throws InterruptedException { + MessageWithPriority pair = queue.poll(timeout, unit); + if (pair == null) { + return null; + } + Message result = pair.message; + int prio = pair.priority; + if (log.isDebugEnabled()) { + log.debug( + "polled message (tm {} {}):prio {} {} stats {}", + timeout, + unit, + prio, + result.getMessageId(), + Arrays.toString(numberMessagesByPriority)); + } + numberMessagesByPriority[prio].decrementAndGet(); + return result; } @Override - public int remainingCapacity() { - return queue.remainingCapacity(); + public void clear() { + queue.clear(); } @Override - public int drainTo(Collection c) { - return queue.drainTo(c); + public int size() { + return queue.size(); } @Override - public int drainTo(Collection c, int maxElements) { - return queue.drainTo(c, maxElements); + public void forEach(Consumer> action) { + queue.stream().sorted(comparator).forEach(x -> action.accept(x.message)); } @Override - public void clear() { - queue.clear(); + public String toString() { + return "queue:" + queue + ", stats:" + getPriorityStats() + ", terminated:" + terminated.get(); } @Override - public boolean remove(Object o) { - return queue.remove(o); + public void terminate(Consumer> itemAfterTerminatedHandler) { + this.itemAfterTerminatedHandler = itemAfterTerminatedHandler; + terminated.set(true); } @Override - public int size() { - return queue.size(); + public boolean isTerminated() { + return terminated.get(); } @Override - public Iterator iterator() { - return queue.iterator(); + public boolean remove(Object o) { + throw new UnsupportedOperationException(); } @Override - public List toList() { - List list = new ArrayList<>(size()); - forEach(list::add); - return list; + public int remainingCapacity() { + throw new UnsupportedOperationException(); } @Override - public void forEach(Consumer action) { - queue.forEach(action); + public int drainTo(Collection> c) { + throw new UnsupportedOperationException(); } @Override - public String toString() { - return queue.toString(); + public int drainTo(Collection> c, int maxElements) { + throw new UnsupportedOperationException(); } @Override - public void terminate(Consumer itemAfterTerminatedHandler) { - terminated.set(true); + public Iterator> iterator() { + throw new UnsupportedOperationException(); } @Override - public boolean isTerminated() { - return terminated.get(); + public List> toList() { + throw new UnsupportedOperationException(); + } + + public String getPriorityStats() { + return Arrays.toString(numberMessagesByPriority); } } diff --git a/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarConnectionFactory.java b/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarConnectionFactory.java index f620931e..31d6aab9 100644 --- a/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarConnectionFactory.java +++ b/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarConnectionFactory.java @@ -91,9 +91,11 @@ import org.apache.pulsar.client.api.SubscriptionType; import org.apache.pulsar.client.api.TopicMetadata; import org.apache.pulsar.client.impl.ConsumerBase; +import org.apache.pulsar.client.impl.ConsumerImpl; import org.apache.pulsar.client.impl.MultiTopicsConsumerImpl; import org.apache.pulsar.client.impl.PulsarClientImpl; import org.apache.pulsar.client.impl.auth.AuthenticationToken; +import org.apache.pulsar.common.naming.TopicName; import org.apache.pulsar.common.partition.PartitionedTopicMetadata; @Slf4j @@ -1088,9 +1090,7 @@ Producer getProducerForDestination(Destination defaultDestination, boole @Override public int choosePartition(Message msg, TopicMetadata metadata) { - Integer priority = PulsarMessage.readJMSPriority(msg); - int key = - priority == null ? PulsarMessage.DEFAULT_PRIORITY : priority; + int key = PulsarMessage.readJMSPriority(msg); return Utils.mapPriorityToPartition( key, metadata.numPartitions(), @@ -1346,8 +1346,8 @@ private static void replaceIncomingMessageList(Consumer c) { new Comparator() { @Override public int compare(Message o1, Message o2) { - int priority1 = MessagePriorityGrowableArrayBlockingQueue.getPriority(o1); - int priority2 = MessagePriorityGrowableArrayBlockingQueue.getPriority(o2); + int priority1 = PulsarMessage.readJMSPriority(o1); + int priority2 = PulsarMessage.readJMSPriority(o2); return Integer.compare(priority2, priority1); } }); @@ -1370,11 +1370,46 @@ public int compare(Message o1, Message o2) { ((BlockingQueue) oldQueue).drainTo(newQueue); incomingMessages.set(c, newQueue); + + if (consumerBase instanceof MultiTopicsConsumerImpl) { + setReceiverQueueSizeForJMSPriority(consumerBase); + } } catch (Exception err) { throw new RuntimeException(err); } } + private static void setReceiverQueueSizeForJMSPriority(ConsumerBase consumerBase) throws Exception { + Field consumersField = MultiTopicsConsumerImpl.class.getDeclaredField("consumers"); + + consumersField.setAccessible(true); + + ConcurrentHashMap> consumers = + (ConcurrentHashMap) consumersField.get(consumerBase); + Method setCurrentReceiverQueueSizeMethod = + ConsumerImpl.class.getDeclaredMethod("setCurrentReceiverQueueSize", int.class); + setCurrentReceiverQueueSizeMethod.setAccessible(true); + + // set the queue size for each consumer based on the partition index + // we set a higher number to the consumers for the higher priority partitions + // this way the backlog is drained more quickly for the higher priority partitions + int numConsumers = consumers.size(); + int sumPriorities = + (numConsumers * (numConsumers + 1)) / 2; // 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + 10 + int receiverQueueSize = consumerBase.getCurrentReceiverQueueSize(); + + for (ConsumerImpl consumer : consumers.values()) { + String topic = consumer.getTopic(); + int partitionIndex = TopicName.get(topic).getPartitionIndex(); + // no need to map exactly the partition index to the priority + int prio = Math.max(partitionIndex, 0); + // the size is proportional to the priority (partition index) + int size = Math.max(1, (prio + 1) * receiverQueueSize / sumPriorities); + log.info("Setting receiverQueueSize={} for {} (to handle JMSPriority)", size, topic); + setCurrentReceiverQueueSizeMethod.invoke(consumer, size); + } + } + public String downloadServerSideFilter( String fullQualifiedTopicName, String subscriptionName, SubscriptionMode subscriptionMode) throws JMSException { diff --git a/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarMessage.java b/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarMessage.java index 99141dad..324cd578 100644 --- a/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarMessage.java +++ b/pulsar-jms/src/main/java/com/datastax/oss/pulsar/jms/PulsarMessage.java @@ -1377,9 +1377,8 @@ protected PulsarMessage applyMessage( if (msg.hasProperty("JMSCorrelationID")) { this.correlationId = Base64.getDecoder().decode(msg.getProperty("JMSCorrelationID")); } - Integer jmsPriorityValue = readJMSPriority(msg); - if (jmsPriorityValue != null) { - this.jmsPriority = jmsPriorityValue; + if (msg.hasProperty("JMSPriority")) { + this.jmsPriority = readJMSPriority(msg); } if (msg.hasProperty("JMSDeliveryMode")) { try { @@ -1479,14 +1478,18 @@ public org.apache.pulsar.client.api.Message getReceivedPulsarMessage() { return receivedPulsarMessage; } - public static Integer readJMSPriority(org.apache.pulsar.client.api.Message msg) { + public static int readJMSPriority(org.apache.pulsar.client.api.Message msg) { if (msg.hasProperty("JMSPriority")) { try { - return Integer.parseInt(msg.getProperty("JMSPriority")); + int value = Integer.parseInt(msg.getProperty("JMSPriority")); + if (value < 0 || value >= 10) { // impossible values according to JMS Specs + return PulsarMessage.DEFAULT_PRIORITY; + } + return value; } catch (NumberFormatException err) { // cannot decode priority, not a big deal as it is not supported in Pulsar } } - return null; + return PulsarMessage.DEFAULT_PRIORITY; } } diff --git a/pulsar-jms/src/test/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueueTest.java b/pulsar-jms/src/test/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueueTest.java new file mode 100644 index 00000000..6d786f2e --- /dev/null +++ b/pulsar-jms/src/test/java/com/datastax/oss/pulsar/jms/MessagePriorityGrowableArrayBlockingQueueTest.java @@ -0,0 +1,77 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.pulsar.jms; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import org.apache.pulsar.client.api.Message; +import org.apache.pulsar.client.impl.MessageIdImpl; +import org.junit.jupiter.api.Test; + +class MessagePriorityGrowableArrayBlockingQueueTest { + @Test + public void basicTest() { + test(1, 2, 9); + test(2, 1, 9); + test(2, 9, 1); + } + + private static void test(int... priorities) { + List prio = new ArrayList<>(); + for (int i : priorities) { + prio.add(i); + } + + List sorted = prio; + sorted.sort(Comparator.reverseOrder()); + + MessagePriorityGrowableArrayBlockingQueue queue = + new MessagePriorityGrowableArrayBlockingQueue<>(); + int position = 0; + for (int i : priorities) { + queue.offer(messageWithPriority(i, position++)); + } + + List prioritiesForEach = new ArrayList<>(); + queue.forEach( + m -> { + System.out.println("prio: " + m.getProperty("JMSPriority")); + prioritiesForEach.add(PulsarMessage.readJMSPriority(m)); + }); + assertEquals(prioritiesForEach, sorted); + + List polledPriorities = new ArrayList<>(); + while (queue.peek() != null) { + Message message = queue.poll(); + polledPriorities.add(Integer.parseInt(message.getProperty("JMSPriority"))); + } + assertEquals(polledPriorities, sorted); + } + + private static Message messageWithPriority(int priority, int position) { + Message message = mock(Message.class); + when(message.hasProperty(eq("JMSPriority"))).thenReturn(true); + when(message.getProperty("JMSPriority")).thenReturn(priority + ""); + when(message.getMessageId()).thenReturn(new MessageIdImpl(1, position, 1)); + return message; + } +}