diff --git a/MODULE.bazel b/MODULE.bazel index ead497f92..f61a3d109 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -246,6 +246,9 @@ maven.install( "org.apache.commons:commons-compress:1.22", "org.brotli:dec:0.1.2", "com.github.luben:zstd-jni:1.5.2-5", + + # RabbitMQ + "com.rabbitmq:amqp-client:5.22.0" ], # Note that order matters here, where earlier BOMs "win". boms = [ diff --git a/imports/java/rabbitmq/BUILD.bazel b/imports/java/rabbitmq/BUILD.bazel new file mode 100644 index 000000000..3b56bddef --- /dev/null +++ b/imports/java/rabbitmq/BUILD.bazel @@ -0,0 +1,6 @@ +package(default_visibility = ["//visibility:public"]) + +alias( + name = "client", + actual = "@maven//:com_rabbitmq_amqp_client", +) \ No newline at end of file diff --git a/maven_install.json b/maven_install.json index ba44115c0..9dd7a4d63 100755 --- a/maven_install.json +++ b/maven_install.json @@ -1,7 +1,7 @@ { "__AUTOGENERATED_FILE_DO_NOT_MODIFY_THIS_FILE_MANUALLY": "THERE_IS_NO_DATA_ONLY_ZUUL", - "__INPUT_ARTIFACTS_HASH": -1541156068, - "__RESOLVED_ARTIFACTS_HASH": 1589895329, + "__INPUT_ARTIFACTS_HASH": -1222882013, + "__RESOLVED_ARTIFACTS_HASH": -1382991214, "artifacts": { "com.adobe.testing:s3mock": { "shasums": { @@ -773,6 +773,13 @@ }, "version": "5.6" }, + "com.rabbitmq:amqp-client": { + "shasums": { + "jar": "f5df000445bbab12f650d4cd9224c57c2e416eb75a7b80f5734c670fbf932b1e", + "sources": "de7f700b6922dba8fdd84f84d43531b061085521329686009b37818ec13e4bd6" + }, + "version": "5.22.0" + }, "com.sun.istack:istack-commons-runtime": { "shasums": { "jar": "7fd6792361f4dd00f8c56af4a20cecc0066deea4a8f3dec38348af23fc2296ee", @@ -3150,6 +3157,9 @@ "org.apache.commons:commons-lang3", "org.apache.commons:commons-text" ], + "com.rabbitmq:amqp-client": [ + "org.slf4j:slf4j-api" + ], "commons-beanutils:commons-beanutils": [ "commons-collections:commons-collections", "commons-logging:commons-logging" @@ -4926,6 +4936,18 @@ "com.opencsv.stream.reader", "com.opencsv.validators" ], + "com.rabbitmq:amqp-client": [ + "com.rabbitmq.client", + "com.rabbitmq.client.impl", + "com.rabbitmq.client.impl.nio", + "com.rabbitmq.client.impl.recovery", + "com.rabbitmq.client.observation", + "com.rabbitmq.client.observation.micrometer", + "com.rabbitmq.tools", + "com.rabbitmq.tools.json", + "com.rabbitmq.tools.jsonrpc", + "com.rabbitmq.utility" + ], "com.sun.istack:istack-commons-runtime": [ "com.sun.istack", "com.sun.istack.localization", @@ -8743,6 +8765,8 @@ "com.ongres.stringprep:stringprep:jar:sources", "com.opencsv:opencsv", "com.opencsv:opencsv:jar:sources", + "com.rabbitmq:amqp-client", + "com.rabbitmq:amqp-client:jar:sources", "com.sun.istack:istack-commons-runtime", "com.sun.istack:istack-commons-runtime:jar:sources", "commons-beanutils:commons-beanutils", diff --git a/src/main/kotlin/org/wfanet/measurement/common/rabbitmq/BUILD.bazel b/src/main/kotlin/org/wfanet/measurement/common/rabbitmq/BUILD.bazel new file mode 100644 index 000000000..122c4a60e --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/rabbitmq/BUILD.bazel @@ -0,0 +1,14 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_library") + +package(default_visibility = ["//visibility:public"]) + +kt_jvm_library( + name = "rabbit_mq_client", + srcs = ["RabbitMqClient.kt"], + deps = [ + "//imports/java/rabbitmq:client", + "//imports/java/picocli", + "//imports/kotlin/kotlinx/coroutines:core", + "//src/main/kotlin/org/wfanet/measurement/common", + ], +) \ No newline at end of file diff --git a/src/main/kotlin/org/wfanet/measurement/common/rabbitmq/RabbitMqClient.kt b/src/main/kotlin/org/wfanet/measurement/common/rabbitmq/RabbitMqClient.kt new file mode 100644 index 000000000..406aeaf19 --- /dev/null +++ b/src/main/kotlin/org/wfanet/measurement/common/rabbitmq/RabbitMqClient.kt @@ -0,0 +1,223 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.common.rabbitmq + +import com.rabbitmq.client.AMQP +import com.rabbitmq.client.ConnectionFactory +import com.rabbitmq.client.Channel +import com.rabbitmq.client.DefaultConsumer +import com.rabbitmq.client.Envelope +import com.rabbitmq.client.Connection +import kotlin.coroutines.cancellation.CancellationException +import kotlin.system.exitProcess +import kotlinx.coroutines.runBlocking +import picocli.CommandLine + + +/** + * An abstract class that provides RabbitMQ messaging capabilities through command line configuration. + * + * This class handles connection management and message consumption from a RabbitMQ queue. + * Subclasses must implement [runWork] to define message processing logic. + * + * The client can be configured to: + * - Process a single message and exit using [RabbitMqFlags.consumeSingleMessage] + * - Process messages continuously until terminated + * + * @property rabbitMqFlags Configuration parameters for RabbitMQ connection and behavior + */ +@CommandLine.Command( + name = "RabbitMQClient", + description = ["Worker implementation for processing RabbitMQ messages"], + mixinStandardHelpOptions = true, + showDefaultValues = true +) +abstract class RabbitMQClient : Runnable { + + @CommandLine.Mixin + lateinit var rabbitMqFlags: RabbitMqFlags + private set + + private lateinit var connection: Connection + private lateinit var channel: Channel + + /** + * Abstract method to be implemented by subclasses to process RabbitMQ messages. + * + * @param message The message body as a ByteArray + */ + protected abstract suspend fun runWork(message: ByteArray) + + /** + * Initializes and runs the RabbitMQ client. + * Sets up the connection and starts consuming messages from the configured queue. + * + * @throws Exception if there's an error during setup or message processing + */ + override fun run() = runBlocking { + try { + setupRabbitMQ() + subscribeAndWait() + } catch (e: Exception) { + println("Fatal error during worker execution: ${e.message}") + throw e + } + } + + /** + * Sets up the RabbitMQ connection and channel. + * If [RabbitMqFlags.consumeSingleMessage] is true, sets QoS prefetch count to 1, + * meaning RabbitMQ will only deliver one unacknowledged message at a time to this consumer. + */ + private fun setupRabbitMQ() { + val factory = ConnectionFactory().apply { + host = rabbitMqFlags.rabbitHost + port = rabbitMqFlags.rabbitPort + username = rabbitMqFlags.rabbitUsername + password = rabbitMqFlags.rabbitPassword + } + connection = factory.newConnection() + channel = connection.createChannel().apply { + if (rabbitMqFlags.consumeSingleMessage) { + basicQos(1) + } + } + } + + /** + * Subscribes to the configured queue and processes messages. + * For each message received: + * 1. Launches a coroutine to process the message using [runWork] + * 2. Acknowledges or negatively acknowledges based on processing result + * 3. If [RabbitMqFlags.consumeSingleMessage] is true, exits after processing one message + */ + private fun subscribeAndWait() { + channel.basicConsume(rabbitMqFlags.rabbitQueueName, false, object : DefaultConsumer(channel) { + override fun handleDelivery( + consumerTag: String, + envelope: Envelope, + properties: AMQP.BasicProperties, + body: ByteArray + ) { + runBlocking { + try { + runWork(body) + channel.basicAck(envelope.deliveryTag, false) + if (rabbitMqFlags.consumeSingleMessage) { + cleanup() + } + } catch (e: CancellationException) { + println("Work was cancelled: ${e.message}") + channel.basicNack(envelope.deliveryTag, false, true) + if (rabbitMqFlags.consumeSingleMessage) { + cleanup() + } + } catch (e: Exception) { + println("Error processing message: ${e.message}") + channel.basicNack(envelope.deliveryTag, false, true) + if (rabbitMqFlags.consumeSingleMessage) { + cleanup() + } + } + } + } + }) + } + + /** + * Performs cleanup by closing the channel and connection, then exits the process. + * Called automatically after processing a message if [RabbitMqFlags.consumeSingleMessage] is true. + * + * @throws Exception if there's an error during cleanup + */ + private fun cleanup() { + try { + if (::channel.isInitialized) { + try { + channel.close() + } catch (e: Exception) { + println("Error closing channel: ${e.message}") + } + } + + if (::connection.isInitialized) { + try { + connection.close() + } catch (e: Exception) { + println("Error closing connection: ${e.message}") + } + } + + exitProcess(0) + } catch (e: Exception) { + println("Error during cleanup: ${e.message}") + exitProcess(1) + } + } + + /** + * Configuration flags for RabbitMQ connection and behavior. + */ + class RabbitMqFlags { + @CommandLine.Option( + names = ["--rabbitmq-host"], + description = ["Host name of the RabbitMQ server."], + required = true + ) + lateinit var rabbitHost: String + private set + + @CommandLine.Option( + names = ["--rabbitmq-port"], + description = ["Port of the RabbitMQ server."], + required = true + ) + var rabbitPort: Int = 5672 + private set + + @CommandLine.Option( + names = ["--rabbitmq-username"], + description = ["Username to authenticate to the RabbitMQ server."], + required = true + ) + lateinit var rabbitUsername: String + private set + + @CommandLine.Option( + names = ["--rabbitmq-password"], + description = ["Password to authenticate to the RabbitMQ server."], + required = true + ) + lateinit var rabbitPassword: String + private set + + @CommandLine.Option( + names = ["--rabbitmq-queue-name"], + description = ["The queue name to subscribe."], + required = true + ) + lateinit var rabbitQueueName: String + private set + + @CommandLine.Option( + names = ["--consume-single-message"], + description = ["If set, only consumes a single message then exits."], + required = false + ) + var consumeSingleMessage: Boolean = false + private set + + } +} diff --git a/src/test/kotlin/org/wfanet/measurement/common/rabbitmq/BUILD.bazel b/src/test/kotlin/org/wfanet/measurement/common/rabbitmq/BUILD.bazel new file mode 100644 index 000000000..0e674b59d --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/common/rabbitmq/BUILD.bazel @@ -0,0 +1,17 @@ +load("@wfa_rules_kotlin_jvm//kotlin:defs.bzl", "kt_jvm_test") + +kt_jvm_test( + name = "RabbitMqClientTest", + srcs = ["RabbitMqClientTest.kt"], + test_class = "org.wfanet.measurement.common.rabbitmq.RabbitMQClientTest", + deps = [ + "//imports/java/com/google/common/truth", + "//imports/java/com/google/crypto/tink", + "//imports/java/org/junit", + "//src/main/kotlin/org/wfanet/measurement/common/rabbitmq:rabbit_mq_client", + ], + jvm_flags = [ + "-Dtest-broker.A.nodename=rabbit@localhost", + "-Drabbitmqctl.bin=/usr/sbin/rabbitmqctl", + ], +) \ No newline at end of file diff --git a/src/test/kotlin/org/wfanet/measurement/common/rabbitmq/RabbitMqClientTest.kt b/src/test/kotlin/org/wfanet/measurement/common/rabbitmq/RabbitMqClientTest.kt new file mode 100644 index 000000000..460a86cbe --- /dev/null +++ b/src/test/kotlin/org/wfanet/measurement/common/rabbitmq/RabbitMqClientTest.kt @@ -0,0 +1,202 @@ +// Copyright 2024 The Cross-Media Measurement Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package org.wfanet.measurement.common.rabbitmq + +import com.rabbitmq.client.ConnectionFactory +import java.nio.charset.Charset +import kotlinx.coroutines.runBlocking +import org.junit.After +import org.junit.Before +import org.junit.Test +import com.google.common.truth.Truth.assertThat +import java.util.* +import java.util.concurrent.CountDownLatch +import java.util.concurrent.TimeUnit +import kotlinx.coroutines.delay +import picocli.CommandLine + +private class TestRabbitMQClient : RabbitMQClient() { + private val processedMessages = Collections.synchronizedList(mutableListOf()) + private var messageProcessedLatch = CountDownLatch(1) + + override suspend fun runWork(message: ByteArray) { + processedMessages.add(message.toString(Charset.defaultCharset())) + messageProcessedLatch.countDown() + } + + fun getProcessedMessages(): List = processedMessages.toList() + + fun waitForMessageProcessing(): Boolean { + val result = messageProcessedLatch.await(5, TimeUnit.SECONDS) + if (result && !rabbitMqFlags.consumeSingleMessage) { + messageProcessedLatch = CountDownLatch(1) + } + return result + } + + override fun run() = runBlocking { + try { + super.run() + } catch (e: SecurityException) { + if (!e.message!!.contains("System.exit")) { + throw e + } + } + } +} + +class RabbitMQClientTest { + private var testClient: TestRabbitMQClient? = null + private var connectionFactory: ConnectionFactory? = null + private val queueName = "test-queue-${System.currentTimeMillis()}" + private val testMessage = "Hello, RabbitMQ!" + + @Before + fun setup() { + connectionFactory = ConnectionFactory().apply { + host = "localhost" + port = 5672 + username = "guest" + password = "guest" + } + + // Test connection and create queue + connectionFactory?.newConnection()?.use { connection -> + connection.createChannel().use { channel -> + channel.queueDeclare(queueName, false, false, true, null) + } + } + + } + + @After + fun cleanup() { + connectionFactory?.newConnection()?.use { connection -> + connection.createChannel().use { channel -> + try { + channel.queueDelete(queueName) + } catch (e: Exception) { + println("Failed to delete queue: ${e.message}") + } + } + } + } + + @Test + fun `test single message consumption`() = runBlocking { + // Publish test message + testClient = TestRabbitMQClient() + val cmd = CommandLine(testClient) + cmd.parseArgs( + "--rabbitmq-host=localhost", + "--rabbitmq-port=5672", + "--rabbitmq-username=guest", + "--rabbitmq-password=guest", + "--rabbitmq-queue-name=$queueName", + "--consume-single-message=true" + ) + connectionFactory?.newConnection()?.use { connection -> + connection.createChannel().use { channel -> + channel.basicPublish( + "", + queueName, + null, + testMessage.toByteArray() + ) + } + } + + val clientThread = Thread { + testClient?.run() + } + clientThread.start() + + assertThat(testClient?.waitForMessageProcessing()) + .isTrue() + + assertThat(testClient?.getProcessedMessages()?.size).isEqualTo(1) + assertThat(testClient?.getProcessedMessages()?.first()).isEqualTo(testMessage) + } + + @Test + fun testMultipleMessagesWhenConsumeSingleMessageIsFalse() { + val multiMessageClient = TestRabbitMQClient() + val cmd = CommandLine(multiMessageClient) + cmd.parseArgs( + "--rabbitmq-host=localhost", + "--rabbitmq-port=5672", + "--rabbitmq-username=guest", + "--rabbitmq-password=guest", + "--rabbitmq-queue-name=$queueName", + "--consume-single-message=false" + ) + + val messages = listOf("Message1", "Message2", "Message3") + + runBlocking { + connectionFactory?.newConnection()?.use { connection -> + connection.createChannel().use { channel -> + messages.forEach { msg -> + channel.basicPublish("", queueName, null, msg.toByteArray()) + } + } + } + + val clientThread = Thread { + multiMessageClient.run() + } + clientThread.start() + repeat(messages.size) { + assertThat(multiMessageClient.waitForMessageProcessing()).isTrue() + } + assertThat(multiMessageClient.getProcessedMessages()).containsExactlyElementsIn(messages) + } + } + + @Test + fun testMultipleMessagesWhenConsumeSingleMessageIsTrue() { + val singleMessageClient = TestRabbitMQClient() + val cmd = CommandLine(singleMessageClient) + cmd.parseArgs( + "--rabbitmq-host=localhost", + "--rabbitmq-port=5672", + "--rabbitmq-username=guest", + "--rabbitmq-password=guest", + "--rabbitmq-queue-name=$queueName", + "--consume-single-message=true" + ) + + val messages = listOf("Message1", "Message2", "Message3") + + runBlocking { + connectionFactory?.newConnection()?.use { connection -> + connection.createChannel().use { channel -> + messages.forEach { msg -> + channel.basicPublish("", queueName, null, msg.toByteArray()) + } + } + } + val clientThread = Thread { + singleMessageClient.run() + } + clientThread.start() + repeat(messages.size) { + assertThat(singleMessageClient.waitForMessageProcessing()).isTrue() + } + assertThat(singleMessageClient.getProcessedMessages()).containsExactly("Message1") + } + } + +}