Skip to content

Commit

Permalink
[ISSUE#46] Add support for wildcard subscription (#75)
Browse files Browse the repository at this point in the history
* [ISSUE#46] Add support for wildcard subscription
  • Loading branch information
deepanshu42 authored Nov 14, 2023
1 parent 3ca9fb1 commit eba8018
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 16 deletions.
3 changes: 3 additions & 0 deletions courier-core/api/courier-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public final class com/gojek/courier/extensions/CollectionExtensionsKt {
public static final fun toImmutableSet (Ljava/util/Set;)Ljava/util/Set;
}

public final class com/gojek/courier/extensions/StringExtensionsKt {
}

public final class com/gojek/courier/extensions/TimeUnitExtensionsKt {
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package com.gojek.courier.extensions

import androidx.annotation.RestrictTo

@RestrictTo(RestrictTo.Scope.LIBRARY)
fun String.isWildCardTopic(): Boolean {
return startsWith("+/") || contains("/+/") || endsWith("/+") || equals("+") ||
endsWith("/#") || equals("#")
}
47 changes: 45 additions & 2 deletions courier/src/main/java/com/gojek/courier/coordinator/Coordinator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ import com.gojek.mqtt.client.model.ConnectionState
import com.gojek.mqtt.client.model.MqttMessage
import com.gojek.mqtt.event.EventHandler
import com.gojek.mqtt.event.MqttEvent
import com.gojek.mqtt.event.MqttEvent.MqttSubscribeFailureEvent
import io.reactivex.BackpressureStrategy
import io.reactivex.Flowable
import io.reactivex.FlowableOnSubscribe
import io.reactivex.disposables.CompositeDisposable
import io.reactivex.schedulers.Schedulers
import io.reactivex.subjects.PublishSubject
import org.eclipse.paho.client.mqttv3.MqttException
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription

Expand All @@ -28,6 +32,18 @@ internal class Coordinator(
private val logger: ILogger
) : StubInterface.Callback {

private val eventSubject = PublishSubject.create<MqttEvent> { emitter ->
val eventHandler = object : EventHandler {
override fun onEvent(mqttEvent: MqttEvent) {
if (emitter.isDisposed.not()) {
emitter.onNext(mqttEvent)
}
}
}
client.addEventHandler(eventHandler)
emitter.setCancellable { client.removeEventHandler(eventHandler) }
}

@Synchronized
override fun send(stubMethod: StubMethod.Send, args: Array<Any>): Any {
logger.d("Coordinator", "Send method invoked")
Expand Down Expand Up @@ -106,7 +122,15 @@ internal class Coordinator(
}
}
client.addMessageListener(topic, listener)
emitter.setCancellable { client.removeMessageListener(topic, listener) }
val eventDisposable = eventSubject.filter { event ->
isInvalidSubscriptionFailureEvent(event, topic)
}.subscribe {
client.removeMessageListener(topic, listener)
}
emitter.setCancellable {
client.removeMessageListener(topic, listener)
eventDisposable.dispose()
}
},
BackpressureStrategy.BUFFER
)
Expand Down Expand Up @@ -166,9 +190,22 @@ internal class Coordinator(
}
}
}
val eventDisposable = CompositeDisposable()
for (topic in topicList) {
client.addMessageListener(topic.first, listener)
emitter.setCancellable { client.removeMessageListener(topic.first, listener) }
eventDisposable.add(
eventSubject.filter { event ->
isInvalidSubscriptionFailureEvent(event, topic.first)
}.subscribe {
client.removeMessageListener(topic.first, listener)
}
)
}
emitter.setCancellable {
for (topic in topicList) {
client.removeMessageListener(topic.first, listener)
eventDisposable.dispose()
}
}
},
BackpressureStrategy.BUFFER
Expand Down Expand Up @@ -249,4 +286,10 @@ internal class Coordinator(
null
}
}

private fun isInvalidSubscriptionFailureEvent(event: MqttEvent, topic: String): Boolean {
return event is MqttSubscribeFailureEvent &&
event.topics.containsKey(topic) &&
event.exception.reasonCode == MqttException.REASON_CODE_INVALID_SUBSCRIPTION.toInt()
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.gojek.mqtt.client

import com.gojek.courier.extensions.fromSecondsToNanos
import com.gojek.courier.extensions.isWildCardTopic
import com.gojek.courier.logging.ILogger
import com.gojek.courier.utils.Clock
import com.gojek.mqtt.client.listener.MessageListener
Expand Down Expand Up @@ -47,6 +48,8 @@ internal class IncomingMsgControllerImpl(

private val listenerMap = ConcurrentHashMap<String, List<MessageListener>>()

private val wildcardTopicListenerMap = ConcurrentHashMap<String, List<MessageListener>>()

private var cleanupFuture: ScheduledFuture<*>? = null

override fun triggerHandleMessage() {
Expand All @@ -64,40 +67,68 @@ internal class IncomingMsgControllerImpl(

@Synchronized
override fun registerListener(topic: String, listener: MessageListener) {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) + listener
if (topic.isWildCardTopic()) {
wildcardTopicListenerMap[topic] = (wildcardTopicListenerMap[topic] ?: emptyList()) + listener
} else {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) + listener
}
triggerHandleMessage()
}

@Synchronized
override fun unregisterListener(topic: String, listener: MessageListener) {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) - listener
if (listenerMap[topic]!!.isEmpty()) {
listenerMap.remove(topic)
if (topic.isWildCardTopic()) {
wildcardTopicListenerMap[topic] = (wildcardTopicListenerMap[topic] ?: emptyList()) - listener
if (wildcardTopicListenerMap[topic]!!.isEmpty()) {
wildcardTopicListenerMap.remove(topic)
}
} else {
listenerMap[topic] = (listenerMap[topic] ?: emptyList()) - listener
if (listenerMap[topic]!!.isEmpty()) {
listenerMap.remove(topic)
}
}
}

private inner class HandleMessage : Runnable {
override fun run() {
try {
if (listenerMap.keys.isEmpty()) {
if (listenerMap.keys.isEmpty() && wildcardTopicListenerMap.isEmpty()) {
logger.d(TAG, "No listeners registered")
return
}
val messages: List<MqttReceivePacket> =
mqttReceivePersistence.getAllIncomingMessagesWithTopicFilter(listenerMap.keys)
if (mqttUtils.isEmpty(messages)) {
logger.d(TAG, "No Messages in Table")
return
}
val deletedMsgIds = mutableListOf<Long>()
for (message in messages) {
logger.d(TAG, "Going to process ${message.messageId}")
val listenersNotified = notifyListeners(message)
val listenersNotified = notifyListeners(message, listenerMap[message.topic]!!)
if (listenersNotified) {
deletedMsgIds.add(message.messageId)
}
logger.d(TAG, "Successfully Processed Message ${message.messageId}")
}
// processing messages for wildcard topic subscription
for (wildCardTopic in wildcardTopicListenerMap.keys()) {
val topicForDBQuery = parseWildCardTopicForDBQuery(wildCardTopic)
val wildcardMessages: List<MqttReceivePacket> =
mqttReceivePersistence.getAllIncomingMessagesForWildCardTopic(topicForDBQuery)
for (message in wildcardMessages) {
logger.d(TAG, "Going to process ${message.messageId}")
val wildCardTopicRegex = parseWildCardTopicForRegex(wildCardTopic)
if (wildCardTopicRegex.matches(message.topic)) {
logger.d(TAG, "Wildcard topic: $wildCardTopic matches ${message.topic}")
val listenersNotified =
notifyListeners(message, wildcardTopicListenerMap[wildCardTopic]!!)
if (listenersNotified) {
deletedMsgIds.add(message.messageId)
}
} else {
logger.d(TAG, "Wildcard topic: $wildCardTopic does not match ${message.topic}")
}
logger.d(TAG, "Successfully Processed Message ${message.messageId}")
}
}
if (deletedMsgIds.isNotEmpty()) {
val deletedMessagesCount = deleteMessages(deletedMsgIds)
logger.d(TAG, "Deleted $deletedMessagesCount messages")
Expand All @@ -112,6 +143,18 @@ internal class IncomingMsgControllerImpl(
}
}

private fun parseWildCardTopicForDBQuery(topic: String): String {
var updatedTopic: String = topic.replace("+", "%")
updatedTopic = updatedTopic.replace("#", "%")
return updatedTopic
}

private fun parseWildCardTopicForRegex(topic: String): Regex {
var updatedTopic: String = topic.replace("+", "[^\\/]+")
updatedTopic = updatedTopic.replace("#", "([^\\/]+(\\/?[^\\/])*)+")
return Regex(updatedTopic)
}

private inner class CleanupExpiredMessages : Runnable {
override fun run() {
logger.d(TAG, "Deleting expired messages")
Expand All @@ -123,10 +166,10 @@ internal class IncomingMsgControllerImpl(
}
}

private fun notifyListeners(message: MqttReceivePacket): Boolean {
private fun notifyListeners(message: MqttReceivePacket, listeners: List<MessageListener>): Boolean {
var notified = false
try {
listenerMap[message.topic]!!.forEach {
listeners.forEach {
notified = true
it.onMessageReceived(message.toMqttMessage())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,7 @@ internal class MqttConnection(
),
timeTakenMillis = (clock.nanoTime() - subscribeStartTime).fromNanosToMillis()
)
subscriptionStore.getListener().onInvalidTopicsSubscribeFailure(topicMap)
}
}
}
Expand Down Expand Up @@ -546,6 +547,7 @@ internal class MqttConnection(
),
timeTakenMillis = (clock.nanoTime() - unsubscribeStartTime).fromNanosToMillis()
)
subscriptionStore.getListener().onInvalidTopicsUnsubscribeFailure(topics)
}
}
}
Expand Down Expand Up @@ -576,11 +578,12 @@ internal class MqttConnection(
connectionConfig.connectionEventHandler.onMqttSubscribeFailure(
topics = failTopicMap,
timeTakenMillis = (clock.nanoTime() - context.startTime).fromNanosToMillis(),
throwable = MqttException(MqttException.REASON_CODE_INVALID_SUBSCRIPTION.toInt())
throwable = MqttException(REASON_CODE_INVALID_SUBSCRIPTION.toInt())
)
}

subscriptionStore.getListener().onTopicsSubscribed(successTopicMap)
subscriptionStore.getListener().onInvalidTopicsSubscribeFailure(failTopicMap)
subscriptionPolicy.resetParams()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ internal interface IMqttReceivePersistence {
fun getAllIncomingMessagesWithTopicFilter(topics: Set<String>): List<MqttReceivePacket>
fun removeReceivedMessages(messageIds: List<Long>): Int
fun removeMessagesWithOlderTimestamp(timestampNanos: Long): Int
fun getAllIncomingMessagesForWildCardTopic(topic: String): List<MqttReceivePacket>
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ internal interface IncomingMessagesDao {
@Query("SELECT * from incoming_messages where topic in (:topics)")
fun getAllMessagesWithTopicFilter(topics: Set<String>): List<MqttReceivePacket>

@Query("SELECT * from incoming_messages where topic LIKE :topic")
fun getAllIncomingMessagesForWildCardTopic(topic: String): List<MqttReceivePacket>

@Query("DELETE from incoming_messages")
fun clearAllMessages()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ internal class PahoPersistence(private val context: Context) :
return incomingMessagesDao.getAllMessagesWithTopicFilter(topics)
}

override fun getAllIncomingMessagesForWildCardTopic(topic: String): List<MqttReceivePacket> {
return incomingMessagesDao.getAllIncomingMessagesForWildCardTopic(topic)
}

override fun removeReceivedMessages(messageIds: List<Long>): Int {
return incomingMessagesDao.removeMessagesById(messageIds)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ import com.gojek.courier.QoS

internal class InMemorySubscriptionStore : SubscriptionStore {
private var state = State(mapOf())
private val listener = object : SubscriptionStoreListener {}
private val listener = object : SubscriptionStoreListener {
override fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) {
state = state.copy(
subscriptionTopics = state.subscriptionTopics - topicMap.keys
)
}
}

private data class State(val subscriptionTopics: Map<String, QoS>)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,17 @@ internal class PersistableSubscriptionStore(context: Context) : SubscriptionStor
override fun onTopicsUnsubscribed(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}

override fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) {
state = state.copy(
subscriptionTopics = state.subscriptionTopics - topicMap.keys,
pendingUnsubscribeTopics = state.pendingUnsubscribeTopics
)
}

override fun onInvalidTopicsUnsubscribeFailure(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}
}

private data class State(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ internal class PersistableSubscriptionStoreV2(context: Context) : SubscriptionSt
override fun onTopicsUnsubscribed(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}

override fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) {
state = state.copy(
subscriptionTopics = state.subscriptionTopics - topicMap.keys,
pendingUnsubscribeTopics = state.pendingUnsubscribeTopics
)
}

override fun onInvalidTopicsUnsubscribeFailure(topics: Set<String>) {
onTopicsUnsubscribedInternal(topics)
}
}

private data class State(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@ internal interface SubscriptionStore {

internal interface SubscriptionStoreListener {
fun onTopicsSubscribed(topicMap: Map<String, QoS>) = Unit
fun onInvalidTopicsSubscribeFailure(topicMap: Map<String, QoS>) = Unit
fun onTopicsUnsubscribed(topics: Set<String>) = Unit
fun onInvalidTopicsUnsubscribeFailure(topics: Set<String>) = Unit
}

0 comments on commit eba8018

Please sign in to comment.