diff --git a/zio-kafka-test-utils/src/main/scala/zio/kafka/KafkaTestUtils.scala b/zio-kafka-test-utils/src/main/scala/zio/kafka/KafkaTestUtils.scala index 5d227c1407..aafc0660aa 100644 --- a/zio-kafka-test-utils/src/main/scala/zio/kafka/KafkaTestUtils.scala +++ b/zio-kafka-test-utils/src/main/scala/zio/kafka/KafkaTestUtils.scala @@ -70,6 +70,7 @@ object KafkaTestUtils { allowAutoCreateTopics: Boolean = true, offsetRetrieval: OffsetRetrieval = OffsetRetrieval.Auto(), restartStreamOnRebalancing: Boolean = false, + rebalanceSafeCommits: Boolean = false, `max.poll.records`: Int = 100, // settings this higher can cause concurrency bugs to go unnoticed runloopTimeout: Duration = ConsumerSettings.defaultRunloopTimeout, properties: Map[String, String] = Map.empty @@ -92,6 +93,7 @@ object KafkaTestUtils { .withPerPartitionChunkPrefetch(16) .withOffsetRetrieval(offsetRetrieval) .withRestartStreamOnRebalancing(restartStreamOnRebalancing) + .withRebalanceSafeCommits(rebalanceSafeCommits) .withProperties(properties) val withClientInstanceId = clientInstanceId.fold(settings)(settings.withGroupInstanceId) @@ -105,6 +107,7 @@ object KafkaTestUtils { allowAutoCreateTopics: Boolean = true, offsetRetrieval: OffsetRetrieval = OffsetRetrieval.Auto(), restartStreamOnRebalancing: Boolean = false, + rebalanceSafeCommits: Boolean = false, properties: Map[String, String] = Map.empty ): URIO[Kafka, ConsumerSettings] = consumerSettings( @@ -114,6 +117,7 @@ object KafkaTestUtils { allowAutoCreateTopics = allowAutoCreateTopics, offsetRetrieval = offsetRetrieval, restartStreamOnRebalancing = restartStreamOnRebalancing, + rebalanceSafeCommits = rebalanceSafeCommits, properties = properties ) .map( @@ -135,6 +139,7 @@ object KafkaTestUtils { allowAutoCreateTopics: Boolean = true, diagnostics: Diagnostics = Diagnostics.NoOp, restartStreamOnRebalancing: Boolean = false, + rebalanceSafeCommits: Boolean = false, properties: Map[String, String] = Map.empty ): ZLayer[Kafka, Throwable, Consumer] = (ZLayer( @@ -145,6 +150,7 @@ object KafkaTestUtils { allowAutoCreateTopics = allowAutoCreateTopics, offsetRetrieval = offsetRetrieval, restartStreamOnRebalancing = restartStreamOnRebalancing, + rebalanceSafeCommits = rebalanceSafeCommits, properties = properties ) ) ++ ZLayer.succeed(diagnostics)) >>> Consumer.live @@ -157,6 +163,7 @@ object KafkaTestUtils { allowAutoCreateTopics: Boolean = true, diagnostics: Diagnostics = Diagnostics.NoOp, restartStreamOnRebalancing: Boolean = false, + rebalanceSafeCommits: Boolean = false, properties: Map[String, String] = Map.empty, rebalanceListener: RebalanceListener = RebalanceListener.noop ): ZLayer[Kafka, Throwable, Consumer] = @@ -168,6 +175,7 @@ object KafkaTestUtils { allowAutoCreateTopics = allowAutoCreateTopics, offsetRetrieval = offsetRetrieval, restartStreamOnRebalancing = restartStreamOnRebalancing, + rebalanceSafeCommits = rebalanceSafeCommits, properties = properties ).map(_.withRebalanceListener(rebalanceListener)) ) ++ ZLayer.succeed(diagnostics)) >>> Consumer.live diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala index 492f2bf9f3..6a43d803b7 100644 --- a/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala +++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/ConsumerSpec.scala @@ -478,6 +478,9 @@ object ConsumerSpec extends ZIOKafkaSpec { ) }, test("commit offsets for all consumed messages") { + // + // TODO: find out whether the test description is wrong (it 
doesn't seem to commit), or the test is wrong + // val nrMessages = 50 val messages = (1 to nrMessages).toList.map(i => (s"key$i", s"msg$i")) diff --git a/zio-kafka-test/src/test/scala/zio/kafka/consumer/RebalanceSafeCommitConsumerSpec.scala b/zio-kafka-test/src/test/scala/zio/kafka/consumer/RebalanceSafeCommitConsumerSpec.scala new file mode 100644 index 0000000000..9eb0e34f30 --- /dev/null +++ b/zio-kafka-test/src/test/scala/zio/kafka/consumer/RebalanceSafeCommitConsumerSpec.scala @@ -0,0 +1,1081 @@ +package zio.kafka.consumer + +import io.github.embeddedkafka.EmbeddedKafka +import org.apache.kafka.clients.consumer.{ + ConsumerConfig, + ConsumerPartitionAssignor, + CooperativeStickyAssignor, + RangeAssignor +} +import org.apache.kafka.clients.producer.{ ProducerRecord, RecordMetadata } +import org.apache.kafka.common.TopicPartition +import zio._ +import zio.kafka.KafkaTestUtils._ +import zio.kafka.ZIOKafkaSpec +import zio.kafka.consumer.Consumer.{ AutoOffsetStrategy, OffsetRetrieval } +import zio.kafka.consumer.diagnostics.{ DiagnosticEvent, Diagnostics } +import zio.kafka.embedded.Kafka +import zio.kafka.producer.{ Producer, TransactionalProducer } +import zio.kafka.serde.Serde +import zio.stream.{ ZSink, ZStream } +import zio.test.Assertion._ +import zio.test.TestAspect._ +import zio.test._ + +import scala.reflect.ClassTag + +object RebalanceSafeCommitConsumerSpec extends ZIOKafkaSpec { + override val kafkaPrefix: String = "commitsafeconsumespec" + + override def spec: Spec[TestEnvironment & Kafka, Throwable] = + suite("Rebalance safe commit consumer streaming")( + test("plainStream emits messages for a topic subscription") { + val kvs = (1 to 5).toList.map(i => (s"key$i", s"msg$i")) + for { + topic <- randomTopic + client <- randomClient + group <- randomGroup + + _ <- produceMany(topic, kvs) + + records <- Consumer + .plainStream(Subscription.Topics(Set(topic)), Serde.string, Serde.string) + .take(5) + .runCollect + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + kvOut = records.map(r => (r.record.key, r.record.value)).toList + } yield assert(kvOut)(equalTo(kvs)) + }, + test("chunk sizes") { + val kvs = (1 to 100).toList.map(i => (s"key$i", s"msg$i")) + for { + topic <- randomTopic + client <- randomClient + group <- randomGroup + + _ <- produceMany(topic, kvs) + + sizes <- Consumer + .plainStream(Subscription.Topics(Set(topic)), Serde.string, Serde.string) + .take(100) + .mapChunks(c => Chunk(c.size)) + .runCollect + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + } yield assert(sizes)(forall(isGreaterThan(1))) + }, + test("Manual subscription without groupId works properly") { + val kvs = (1 to 5).toList.map(i => (s"key$i", s"msg$i")) + for { + topic <- randomTopic + client <- randomClient + + _ <- produceMany(topic, kvs) + + records <- + Consumer + .plainStream( + Subscription.Manual(Set(new org.apache.kafka.common.TopicPartition(topic, 0))), + Serde.string, + Serde.string + ) + .take(5) + .runCollect + .provideSomeLayer[Kafka](consumer(clientId = client, rebalanceSafeCommits = true)) + kvOut = records.map(r => (r.record.key, r.record.value)).toList + } yield assert(kvOut)(equalTo(kvs)) + }, + test("Consuming+provideCustomLayer") { + val kvs = (1 to 100).toList.map(i => (s"key$i", s"msg$i")) + for { + topic <- randomTopic + client <- randomClient + group <- randomGroup + + _ <- produceMany(topic, kvs) + + records <- Consumer + .plainStream(Subscription.Topics(Set(topic)), Serde.string, Serde.string) + 
.take(100) + .runCollect + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + kvOut = records.map(r => (r.record.key, r.record.value)).toList + } yield assert(kvOut)(equalTo(kvs)) + }, + test("plainStream emits messages for a pattern subscription") { + val kvs = (1 to 5).toList.map(i => (s"key$i", s"msg$i")) + for { + client <- randomClient + group <- randomGroup + + _ <- produceMany("pattern150", kvs) + records <- Consumer + .plainStream(Subscription.Pattern("pattern[0-9]+".r), Serde.string, Serde.string) + .take(5) + .runCollect + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + kvOut = records.map(r => (r.record.key, r.record.value)).toList + } yield assert(kvOut)(equalTo(kvs)) + }, + test("receive only messages from the subscribed topic-partition when creating a manual subscription") { + val nrPartitions = 5 + + for { + client <- randomClient + group <- randomGroup + topic <- randomTopic + + _ <- ZIO.succeed(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO.foreachDiscard(1 to nrPartitions) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + record <- Consumer + .plainStream(Subscription.manual(topic, partition = 2), Serde.string, Serde.string) + .take(1) + .runHead + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + kvOut = record.map(r => (r.record.key, r.record.value)) + } yield assert(kvOut)(isSome(equalTo("key2" -> "msg2"))) + }, + test("receive from the right offset when creating a manual subscription with manual seeking") { + val nrPartitions = 5 + + val manualOffsetSeek = 3 + + for { + client <- randomClient + group <- randomGroup + topic <- randomTopic + + _ <- ZIO.succeed(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO.foreachDiscard(1 to nrPartitions) { i => + produceMany(topic, partition = i % nrPartitions, kvs = (0 to 9).map(j => s"key$i-$j" -> s"msg$i-$j")) + } + offsetRetrieval = OffsetRetrieval.Manual(tps => ZIO.attempt(tps.map(_ -> manualOffsetSeek.toLong).toMap)) + record <- Consumer + .plainStream(Subscription.manual(topic, partition = 2), Serde.string, Serde.string) + .take(1) + .runHead + .provideSomeLayer[Kafka]( + consumer(client, Some(group), offsetRetrieval = offsetRetrieval, rebalanceSafeCommits = true) + ) + kvOut = record.map(r => (r.record.key, r.record.value)) + } yield assert(kvOut)(isSome(equalTo("key2-3" -> "msg2-3"))) + }, + test("restart from the committed position") { + val data = (1 to 10).toList.map(i => s"key$i" -> s"msg$i") + for { + topic <- randomTopic + group <- randomGroup + first <- randomClient + second <- randomClient + + _ <- produceMany(topic, 0, data) + firstResults <- for { + results <- Consumer + .partitionedStream(Subscription.Topics(Set(topic)), Serde.string, Serde.string) + .filter(_._1 == new TopicPartition(topic, 0)) + .flatMap(_._2) + .take(5) + .transduce(ZSink.collectAllN[CommittableRecord[String, String]](5)) + .mapConcatZIO { committableRecords => + val records = committableRecords.map(_.record) + val offsetBatch = OffsetBatch(committableRecords.map(_.offset)) + + offsetBatch.commit.as(records) + } + .runCollect + .provideSomeLayer[Kafka]( + consumer(first, Some(group), rebalanceSafeCommits = true) + ) + } yield results + secondResults <- for { + results <- + Consumer + .partitionedStream(Subscription.Topics(Set(topic)), Serde.string, Serde.string) + .flatMap(_._2) + .take(5) + 
.transduce(ZSink.collectAllN[CommittableRecord[String, String]](20)) + .mapConcatZIO { committableRecords => + val records = committableRecords.map(_.record) + val offsetBatch = OffsetBatch(committableRecords.map(_.offset)) + + offsetBatch.commit.as(records) + } + .runCollect + .provideSomeLayer[Kafka]( + consumer(second, Some(group), rebalanceSafeCommits = true) + ) + } yield results + } yield assert((firstResults ++ secondResults).map(rec => rec.key() -> rec.value()).toList)(equalTo(data)) + }, + test("partitionedStream emits messages for each partition in a separate stream") { + val nrMessages = 50 + val nrPartitions = 5 + + for { + // Produce messages on several partitions + topic <- randomTopic + group <- randomGroup + client <- randomClient + + _ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO.foreachDiscard(1 to nrMessages) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + + // Consume messages + messagesReceived <- ZIO.foreach((0 until nrPartitions).toList)(i => Ref.make[Int](0).map(i -> _)).map(_.toMap) + subscription = Subscription.topics(topic) + fib <- Consumer + .partitionedStream(subscription, Serde.string, Serde.string) + .flatMapPar(nrPartitions) { case (_, partition) => + partition + .mapZIO(record => messagesReceived(record.partition).update(_ + 1).as(record)) + } + .take(nrMessages.toLong) + .runDrain + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + .fork + _ <- fib.join + messagesPerPartition <- ZIO.foreach(messagesReceived.values)(_.get) + + } yield assert(messagesPerPartition)(forall(equalTo(nrMessages / nrPartitions))) + }, + test("fail when the consuming effect produces a failure") { + val nrMessages = 10 + val messages = (1 to nrMessages).toList.map(i => (s"key$i", s"msg$i")) + + for { + topic <- randomTopic + group <- randomGroup + client <- randomClient + subscription = Subscription.Topics(Set(topic)) + _ <- produceMany(topic, messages) + consumeResult <- consumeWithStrings(client, Some(group), subscription) { _ => + ZIO.die(new IllegalArgumentException("consumeWith failure")) + }.exit + } yield consumeResult.foldExit[TestResult]( + _ => assertCompletes, + _ => assert("result")(equalTo("Expected consumeWith to fail")) + ) + } @@ timeout(10.seconds), + test("stopConsumption must end streams while still processing commits") { + for { + topic <- randomTopic + group <- randomGroup + client <- randomClient + + keepProducing <- Ref.make(true) + _ <- produceOne(topic, "key", "value").repeatWhileZIO(_ => keepProducing.get).fork + _ <- Consumer + .plainStream(Subscription.topics(topic), Serde.string, Serde.string) + .zipWithIndex + .tap { case (record, idx) => + (Consumer.stopConsumption <* ZIO.logDebug("Stopped consumption")).when(idx == 3) *> + record.offset.commit <* ZIO.logDebug(s"Committed $idx") + } + .tap { case (_, idx) => ZIO.logDebug(s"Consumed $idx") } + .runDrain + .tap(_ => ZIO.logDebug("Stream completed")) + .provideSomeLayer[Kafka]( + consumer(client, Some(group), rebalanceSafeCommits = true) + ) *> keepProducing + .set(false) + } yield assertCompletes + }, + test("process outstanding commits after a graceful shutdown") { + val kvs = (1 to 100).toList.map(i => (s"key$i", s"msg$i")) + val topic = "test-outstanding-commits" + for { + group <- randomGroup + client <- randomClient + _ <- produceMany(topic, kvs) + messagesReceived <- Ref.make[Int](0) + offset <- (Consumer + .plainStream(Subscription.topics(topic), Serde.string, 
Serde.string) + .mapConcatZIO { record => + for { + nr <- messagesReceived.updateAndGet(_ + 1) + _ <- Consumer.stopConsumption.when(nr == 10) + } yield if (nr < 10) Seq(record.offset) else Seq.empty + } + .transduce(Consumer.offsetBatches) + .mapZIO(_.commit) + .runDrain *> + Consumer.committed(Set(new TopicPartition(topic, 0))).map(_.values.head)) + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + } yield assert(offset.map(_.offset))(isSome(equalTo(9L))) + }, + test("offset batching collects the latest offset for all partitions") { + val nrMessages = 50 + val nrPartitions = 5 + + for { + // Produce messages on several partitions + topic <- randomTopic + group <- randomGroup + client <- randomClient + _ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO.foreachDiscard(1 to nrMessages) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + + // Consume messages + subscription = Subscription.topics(topic) + offsets <- (Consumer + .partitionedStream(subscription, Serde.string, Serde.string) + .flatMapPar(nrPartitions)(_._2.map(_.offset)) + .take(nrMessages.toLong) + .transduce(Consumer.offsetBatches) + .take(1) + .mapZIO(_.commit) + .runDrain *> + Consumer.committed((0 until nrPartitions).map(new TopicPartition(topic, _)).toSet)) + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + } yield assert(offsets.values.map(_.map(_.offset)))(forall(isSome(equalTo(nrMessages.toLong / nrPartitions)))) + }, + test("handle rebalancing by completing topic-partition streams") { + val nrMessages = 50 + val nrPartitions = 6 + + for { + // Produce messages on several partitions + topic <- randomTopic + group <- randomGroup + client1 <- randomClient + client2 <- randomClient + + _ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO.foreachDiscard(1 to nrMessages) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + + // Consume messages + subscription = Subscription.topics(topic) + consumer1 <- Consumer + .partitionedStream(subscription, Serde.string, Serde.string) + .flatMapPar(nrPartitions) { case (tp, partition) => + ZStream + .fromZIO(partition.runDrain) + .as(tp) + } + .take(nrPartitions.toLong / 2) + .runDrain + .provideSomeLayer[Kafka](consumer(client1, Some(group), rebalanceSafeCommits = true)) + .fork + _ <- Live.live(ZIO.sleep(5.seconds)) + consumer2 <- Consumer + .partitionedStream(subscription, Serde.string, Serde.string) + .take(nrPartitions.toLong / 2) + .runDrain + .provideSomeLayer[Kafka](consumer(client2, Some(group), rebalanceSafeCommits = true)) + .fork + _ <- consumer1.join + _ <- consumer2.join + } yield assertCompletes + } @@ ignore /* What does this actually test? 
*/, + test("produce diagnostic events when rebalancing") { + val partitionCount = 6 + + def committingConsumer(subscription: Subscription, group: String, client: String, diagnostics: Diagnostics) + : ZIO[Kafka, Throwable, Unit] = Consumer + .partitionedStream(subscription, Serde.string, Serde.string) + .flatMapPar(Int.MaxValue) { case (_, partitionStream) => + ZStream.fromZIO { + partitionStream.mapChunksZIO { records => + val offsets = OffsetBatch(records.map(_.offset)) + ZIO.logDebug( + s"Committing offsets: ${offsets.offsets.map { case (tp, offset) => s"${tp.partition()}-$offset" }}" + ) *> + offsets.commit.as(Chunk.empty) /* <* ZIO.logDebug( + s"Committed offsets: ${offsets.offsets.map { case (tp, offset) => s"${tp.partition()}-$offset" }}" + )*/ + }.runDrain.unit + } + } + .runDrain + .provideSomeLayer[Kafka]( + consumer(client, Some(group), diagnostics = diagnostics, rebalanceSafeCommits = true) + ) + + ZIO.scoped { + for { + diagnostics <- Diagnostics.SlidingQueue.make() + // Produce messages on several partitions + topic <- randomTopic + group <- randomGroup + client1 <- randomClient + client2 <- randomClient + + // Create topic with many partitions + _ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topic, partitions = partitionCount)) + + // Continuously produce messages + producerFiber <- scheduledProducer(topic, partitionCount, Schedule.fixed(10.millis)).runDrain.fork + + // Collect diagnostics events + endDiagnostics <- Promise.make[Nothing, Unit] + diagnosticStream <- ZStream + .fromQueue(diagnostics.queue) + .haltWhen(endDiagnostics) + .collect { case rebalance: DiagnosticEvent.Rebalance => rebalance } + .runCollect + .fork + + // Consume messages + subscription = Subscription.topics(topic) + consumer1 <- committingConsumer(subscription, group, client1, diagnostics).fork + _ <- ZIO.sleep(2.seconds) + _ <- ZIO.logDebug(s"starting client 2") + consumer2 <- committingConsumer(subscription, group, client2, diagnostics).fork + _ <- ZIO.sleep(2.seconds) + _ <- ZIO.logDebug(s"interrupting fibers") + _ <- consumer1.interrupt + _ <- consumer2.interrupt + _ <- producerFiber.interrupt + _ <- endDiagnostics.succeed(()) + diagnosticEvents <- diagnosticStream.join + } yield assertTrue(diagnosticEvents.size >= 2) + } + }, + test("support manual seeking") { + val nrRecords = 10 + val data = (1 to nrRecords).toList.map(i => s"key$i" -> s"msg$i") + val manualOffsetSeek = 3 + + for { + topic <- randomTopic + group <- randomGroup + client1 <- randomClient + client2 <- randomClient + + _ <- produceMany(topic, 0, data) + // Consume 5 records to have the offset committed at 5 + _ <- Consumer + .plainStream(Subscription.topics(topic), Serde.string, Serde.string) + .take(5) + .transduce(ZSink.collectAllN[CommittableRecord[String, String]](5)) + .mapConcatZIO { committableRecords => + val records = committableRecords.map(_.record) + val offsetBatch = OffsetBatch(committableRecords.map(_.offset)) + + offsetBatch.commit.as(records) + } + .runCollect + .provideSomeLayer[Kafka](consumer(client1, Some(group), rebalanceSafeCommits = true)) + // Start a new consumer with manual offset before the committed offset + offsetRetrieval = OffsetRetrieval.Manual(tps => ZIO.attempt(tps.map(_ -> manualOffsetSeek.toLong).toMap)) + secondResults <- + Consumer + .plainStream(Subscription.topics(topic), Serde.string, Serde.string) + .take(nrRecords.toLong - manualOffsetSeek) + .map(_.record) + .runCollect + .provideSomeLayer[Kafka]( + consumer(client2, Some(group), offsetRetrieval = offsetRetrieval, rebalanceSafeCommits = 
true) + ) + // Check that we only got the records starting from the manually seek'd offset + } yield assert(secondResults.map(rec => rec.key() -> rec.value()).toList)( + equalTo(data.drop(manualOffsetSeek)) + ) + }, + test("commit offsets for all consumed messages") { + // + // TODO: find out whether the test description is wrong (it doesn't seem to commit), or the test is wrong + // + val nrMessages = 50 + val messages = (1 to nrMessages).toList.map(i => (s"key$i", s"msg$i")) + + def consumeIt( + client: String, + group: String, + subscription: Subscription, + messagesReceived: Ref[List[(String, String)]], + done: Promise[Nothing, Unit] + ) = + consumeWithStrings(client, Some(group), subscription)({ record => + for { + messagesSoFar <- messagesReceived.updateAndGet(_ :+ (record.key() -> record.value())) + _ <- ZIO.when(messagesSoFar.size == nrMessages)(done.succeed(())) + } yield () + }).fork + + for { + topic <- randomTopic + group <- randomGroup + client <- randomClient + subscription = Subscription.Topics(Set(topic)) + + done <- Promise.make[Nothing, Unit] + messagesReceived <- Ref.make(List.empty[(String, String)]) + _ <- produceMany(topic, messages) + fib <- consumeIt(client, group, subscription, messagesReceived, done) + _ <- + done.await *> Live + .live( + ZIO.sleep(3.seconds) + ) // TODO the sleep is necessary for the outstanding commits to be flushed. Maybe we can fix that another way + _ <- fib.interrupt + _ <- produceOne(topic, "key-new", "msg-new") + newMessage <- Consumer + .plainStream(subscription, Serde.string, Serde.string) + .take(1) + .map(r => (r.record.key(), r.record.value())) + .run(ZSink.collectAll[(String, String)]) + .map(_.head) + .orDie + .provideSomeLayer[Kafka](consumer(client, Some(group), rebalanceSafeCommits = true)) + consumedMessages <- messagesReceived.get + } yield assert(consumedMessages)(contains(newMessage).negate) + }, + test("partitions for topic doesn't fail if doesn't exist") { + for { + topic <- randomTopic + group <- randomGroup + client <- randomClient + partitions <- Consumer + .partitionsFor(topic) + .provideSomeLayer[Kafka]( + consumer(client, Some(group), allowAutoCreateTopics = false, rebalanceSafeCommits = true) + ) + } yield assert(partitions)(isEmpty) + }, + // Test backported from fs2-kafka: https://github.com/fd4s/fs2-kafka/blob/1bd0c1f3d46b543277fce1a3cc743154c162ef09/modules/core/src/test/scala/fs2/kafka/KafkaConsumerSpec.scala#L592 + test("should close old stream during rebalancing under load") { + val nrMessages = 50000 + val nrPartitions = 3 + val partitions = (0 until nrPartitions).toList + val waitTimeout = 15.seconds + + final case class ValidAssignmentsNotSeen(instances: Set[Int], st: String) + extends RuntimeException(s"Valid assignment not seen for instances $instances: $st") + + def run(instance: Int, topic: String, allAssignments: Ref[Map[Int, List[Int]]]) = + ZIO.logAnnotate("consumer", instance.toString) { + val subscription = Subscription.topics(topic) + Consumer + .partitionedStream(subscription, Serde.string, Serde.string) + .flatMapPar(Int.MaxValue) { case (tp, partStream) => + val registerAssignment = ZStream.logInfo(s"Registering partition ${tp.partition()}") *> + ZStream.fromZIO { + allAssignments.update { current => + current.get(instance) match { + case Some(currentList) => current.updated(instance, currentList :+ tp.partition()) + case None => current.updated(instance, List(tp.partition())) + } + } + } + val deregisterAssignment = ZStream.logInfo(s"Deregistering partition ${tp.partition()}") *> + 
ZStream.finalizer { + allAssignments.update { current => + current.get(instance) match { + case Some(currentList) => + val idx = currentList.indexOf(tp.partition()) + if (idx != -1) current.updated(instance, currentList.patch(idx, Nil, 1)) + else current + case None => current + } + } + } + + (registerAssignment *> deregisterAssignment *> partStream).drain + } + .runDrain + } + + // Check every 30 millis (for at most 15 seconds) that the following condition holds: + // - all instances are assigned to a partition, + // - each instance has a partition assigned, + // - all partitions are assigned. + // Fail when the condition is not observed. + def checkAssignments(allAssignments: Ref[Map[Int, List[Int]]])(instances: Set[Int]) = + ZStream + .repeatZIOWithSchedule(allAssignments.get, Schedule.spaced(30.millis)) + .filter { state => + state.keySet == instances && + instances.forall(instance => state.get(instance).exists(_.nonEmpty)) && + state.values.toList.flatten.sorted == partitions + } + .runHead + .timeout(waitTimeout) + .someOrElseZIO(allAssignments.get.map(as => ValidAssignmentsNotSeen(instances, as.toString)).flip) + + def createConsumer(client: String, group: String): ZLayer[Kafka, Throwable, Consumer] = + consumer( + client, + Some(group), + offsetRetrieval = OffsetRetrieval.Auto(reset = AutoOffsetStrategy.Earliest), + rebalanceSafeCommits = true + ) + + for { + // Produce messages on several partitions + topic <- randomTopic + group <- randomGroup + client1 <- randomThing("client-1") + client2 <- randomThing("client-2") + client3 <- randomThing("client-3") + _ <- ZIO.fromTry(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- produceMany(topic, kvs = (0 until nrMessages).map(n => s"key-$n" -> s"value->$n")) + allAssignments <- Ref.make(Map.empty[Int, List[Int]]) + check = checkAssignments(allAssignments)(_) + fiber0 <- run(0, topic, allAssignments) + .provideSomeLayer[Kafka](createConsumer(client1, group)) + .fork + _ <- check(Set(0)) + fiber1 <- run(1, topic, allAssignments) + .provideSomeLayer[Kafka](createConsumer(client2, group)) + .fork + _ <- check(Set(0, 1)) + fiber2 <- run(2, topic, allAssignments) + .provideSomeLayer[Kafka](createConsumer(client3, group)) + .fork + _ <- check(Set(0, 1, 2)) + _ <- fiber2.interrupt + _ <- allAssignments.update(_ - 2) + _ <- check(Set(0, 1)) + _ <- fiber1.interrupt + _ <- allAssignments.update(_ - 1) + _ <- check(Set(0)) + _ <- fiber0.interrupt + } yield assertCompletes + }, + test("restartStreamsOnRebalancing mode closes all partition streams") { + val nrPartitions = 5 + val nrMessages = 100 + + for { + // Produce messages on several partitions + _ <- ZIO.logInfo("Starting test") + topic <- randomTopic + group <- randomGroup + client1 <- randomClient + client2 <- randomClient + + _ <- ZIO.fromTry(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO.foreachDiscard(1 to nrMessages) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + + // Consume messages + messagesReceived <- + ZIO.foreach((0 until nrPartitions).toList)(i => Ref.make[Int](0).map(i -> _)).map(_.toMap) + drainCount <- Ref.make(0) + subscription = Subscription.topics(topic) + fib <- ZIO + .logAnnotate("consumer", "1") { + Consumer + .partitionedAssignmentStream(subscription, Serde.string, Serde.string) + .rechunk(1) + .mapZIO { partitions => + ZIO.logInfo(s"Got partition assignment ${partitions.map(_._1).mkString(",")}") *> + ZStream + .fromIterable(partitions) + .flatMapPar(Int.MaxValue) { 
case (tp, partitionStream) => + ZStream.finalizer(ZIO.logInfo(s"TP ${tp.toString} finalizer")) *> + partitionStream.mapChunksZIO { records => + OffsetBatch(records.map(_.offset)).commit *> messagesReceived(tp.partition) + .update(_ + records.size) + .as(records) + } + } + .runDrain + } + .mapZIO(_ => + drainCount.updateAndGet(_ + 1).flatMap { + case 2 => ZIO.logInfo("Stopping consumption") *> Consumer.stopConsumption + // 1: when consumer on fib2 starts + // 2: when consumer on fib2 stops, end of test + case _ => ZIO.unit + } + ) + .runDrain + .provideSomeLayer[Kafka]( + consumer( + client1, + Some(group), + clientInstanceId = Some("consumer1"), + restartStreamOnRebalancing = true, + rebalanceSafeCommits = true, + properties = Map(ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "10") + ) + ) + } + .fork + // fib is running, consuming all the published messages from all partitions. + // Waiting until it recorded all messages + _ <- ZIO + .foreach(messagesReceived.values)(_.get) + .map(_.sum) + .repeat(Schedule.recurUntil((n: Int) => n == nrMessages) && Schedule.fixed(100.millis)) + + // Starting a new consumer that will stop after receiving 20 messages, + // causing two rebalancing events for fib1 consumers on start and stop + fib2 <- ZIO + .logAnnotate("consumer", "2") { + Consumer + .plainStream(subscription, Serde.string, Serde.string) + .take(20) + .runDrain + .provideSomeLayer[Kafka]( + consumer( + client2, + Some(group), + clientInstanceId = Some("consumer2"), + rebalanceSafeCommits = true, + properties = Map(ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "10") + ) + ) + } + .fork + + // Waiting until fib1's partition streams got restarted because of the rebalancing + _ <- drainCount.get.repeat(Schedule.recurUntil((n: Int) => n == 1) && Schedule.fixed(100.millis)) + _ <- ZIO.logInfo("Consumer 1 finished rebalancing") + + // All messages processed, the partition streams of fib are still running. + // Saving the values and resetting the counters + messagesReceived0 <- + ZIO + .foreach((0 until nrPartitions).toList) { i => + messagesReceived(i).get.flatMap { v => + Ref.make(v).map(r => i -> r) + } <* messagesReceived(i).set(0) + } + .map(_.toMap) + + // Publishing another N messages - now they will be distributed among the two consumers until + // fib2 stops after 20 messages + _ <- ZIO.foreachDiscard((nrMessages + 1) to (2 * nrMessages)) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + _ <- fib2.join + _ <- ZIO.logInfo("Consumer 2 done") + _ <- fib.join + _ <- ZIO.logInfo("Consumer 1 done") + // fib2 terminates after 20 messages, fib terminates after fib2 because of the rebalancing (drainCount==2) + messagesPerPartition0 <- + ZIO.foreach(messagesReceived0.values)(_.get) // counts from the first N messages (single consumer) + messagesPerPartition <- + ZIO.foreach(messagesReceived.values)(_.get) // counts from fib after the second consumer joined + + // The first set must contain all the produced messages + // The second set must have at least one and maximum N-20 (because fib2 stops after consuming 20) - + // the exact count cannot be known because fib2's termination triggers fib1's rebalancing asynchronously. 
+ } yield assert(messagesPerPartition0)(forall(equalTo(nrMessages / nrPartitions))) && + assert(messagesPerPartition.view.sum)(isGreaterThan(0) && isLessThanEqualTo(nrMessages - 20)) + } @@ TestAspect.nonFlaky(3), + test("handles RebalanceInProgressExceptions transparently") { + val nrPartitions = 5 + val nrMessages = 10000 + + def customConsumer(clientId: String, groupId: Option[String]) = + ZLayer( + consumerSettings( + clientId = clientId, + groupId = groupId, + clientInstanceId = None, + rebalanceSafeCommits = true + ).map( + _.withProperties( + ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG -> classOf[CooperativeStickyAssignor].getName + ) + .withPollTimeout(500.millis) + ) + ) ++ ZLayer.succeed(Diagnostics.NoOp) >>> Consumer.live + + for { + // Produce messages on several partitions + topic <- randomTopic + group <- randomGroup + + _ <- ZIO.fromTry(EmbeddedKafka.createCustomTopic(topic, partitions = nrPartitions)) + _ <- ZIO + .foreachDiscard(1 to nrMessages) { i => + produceMany(topic, partition = i % nrPartitions, kvs = List(s"key$i" -> s"msg$i")) + } + .forkScoped + + // Consume messages + messagesReceivedConsumer1 <- Ref.make[Int](0) + messagesReceivedConsumer2 <- Ref.make[Int](0) + drainCount <- Ref.make(0) + subscription = Subscription.topics(topic) + stopConsumer1 <- Promise.make[Nothing, Unit] + fib <- + ZIO + .logAnnotate("consumer", "1") { + Consumer + .partitionedAssignmentStream(subscription, Serde.string, Serde.string) + .rechunk(1) + .mapZIOPar(16) { partitions => + ZIO.logInfo(s"Consumer 1 got new partition assignment: ${partitions.map(_._1.toString)}") *> + ZStream + .fromIterable(partitions.map(_._2)) + .flatMapPar(Int.MaxValue)(s => s) + .mapZIO(record => messagesReceivedConsumer1.update(_ + 1).as(record)) + .map(_.offset) + .aggregateAsync(Consumer.offsetBatches) + .mapZIO(offsetBatch => offsetBatch.commit) + .runDrain + } + .mapZIO(_ => drainCount.updateAndGet(_ + 1)) + .interruptWhen(stopConsumer1.await) + .runDrain + .provideSomeLayer[Kafka]( + customConsumer("consumer1", Some(group)) ++ Scope.default + ) + .tapError(e => ZIO.logErrorCause(e.getMessage, Cause.fail(e))) + } + .forkScoped + + _ <- messagesReceivedConsumer1.get + .repeat(Schedule.recurUntil((n: Int) => n >= 20) && Schedule.fixed(100.millis)) + _ <- ZIO.logInfo("Starting consumer 2") + + fib2 <- + ZIO + .logAnnotate("consumer", "2") { + Consumer + .plainStream(subscription, Serde.string, Serde.string) + .mapZIO(record => messagesReceivedConsumer2.update(_ + 1).as(record)) + .map(_.offset) + .aggregateAsync(Consumer.offsetBatches) + .mapZIO(offsetBatch => offsetBatch.commit) + .runDrain + .provideSomeLayer[Kafka]( + customConsumer("consumer2", Some(group)) + ) + .tapError(e => ZIO.logErrorCause("Error in consumer 2", Cause.fail(e))) + } + .forkScoped + + _ <- messagesReceivedConsumer2.get + .repeat(Schedule.recurUntil((n: Int) => n >= 20) && Schedule.fixed(100.millis)) + _ <- stopConsumer1.succeed(()) + _ <- fib.join + _ <- fib2.interrupt + } yield assertCompletes + }, + suite("does not process messages twice for transactional producer, even when rebalancing")({ + + /** + * Outline of this test: + * - A producer generates some messages on topic A, + * - a transactional consumer/producer pair (copier1) reads these and copies them to topic B transactionally, + * - after a few messages we start a second transactional consumer/producer pair (copier2) that does the same + * (in the same consumer group) this triggers a rebalance, + * - produce some more messages to topic A, + * - a consumer that 
empties topic B, + * - when enough messages have been received, the copiers are interrupted. + * + * We will assert that the produced messages to topic A correspond exactly with the read messages from topic B. + */ + def testForPartitionAssignmentStrategy[T <: ConsumerPartitionAssignor: ClassTag] = + test(implicitly[ClassTag[T]].runtimeClass.getName) { + val partitionCount = 6 + val messageCount = 5000 + val allMessages = (1 to messageCount).map(i => s"$i" -> f"msg$i%06d") + val (messagesBeforeRebalance, messagesAfterRebalance) = allMessages.splitAt(messageCount / 2) + + def transactionalRebalanceListener(streamCompleteOnRebalanceRef: Ref[Option[Promise[Nothing, Unit]]]) = + RebalanceListener( + onAssigned = (_, _) => ZIO.unit, + onRevoked = (_, _) => + streamCompleteOnRebalanceRef.get.flatMap { + case Some(p) => + ZIO.logWarning("onRevoked, awaiting stream completion") *> + p.await.timeoutFail(new InterruptedException("Timed out waiting stream to complete"))(1.minute) + case None => ZIO.unit + }, + onLost = (_, _) => ZIO.logWarning("Lost some partitions") + ) + + def makeCopyingTransactionalConsumer( + name: String, + consumerGroupId: String, + clientId: String, + fromTopic: String, + toTopic: String, + tProducer: TransactionalProducer, + consumerCreated: Promise[Nothing, Unit] + ): ZIO[Kafka, Throwable, Unit] = + ZIO.logAnnotate("consumer", name) { + for { + consumedMessagesCounter <- Ref.make(0) + _ <- consumedMessagesCounter.get + .flatMap(consumed => ZIO.logInfo(s"Consumed so far: $consumed")) + .repeat(Schedule.fixed(1.second)) + .fork + streamCompleteOnRebalanceRef <- Ref.make[Option[Promise[Nothing, Unit]]](None) + tConsumer <- + Consumer + .partitionedAssignmentStream(Subscription.topics(fromTopic), Serde.string, Serde.string) + .mapZIO { assignedPartitions => + for { + p <- Promise.make[Nothing, Unit] + _ <- streamCompleteOnRebalanceRef.set(Some(p)) + _ <- ZIO.logInfo(s"${assignedPartitions.size} partitions assigned") + _ <- consumerCreated.succeed(()) + partitionStreams = assignedPartitions.map(_._2) + s <- ZStream + .mergeAllUnbounded(64)(partitionStreams: _*) + .mapChunksZIO { records => + ZIO.scoped { + for { + t <- tProducer.createTransaction + _ <- t.produceChunkBatch( + records.map(r => new ProducerRecord(toTopic, r.key, r.value)), + Serde.string, + Serde.string, + OffsetBatch(records.map(_.offset)) + ) + _ <- consumedMessagesCounter.update(_ + records.size) + } yield Chunk.empty + }.uninterruptible + } + .runDrain + .ensuring { + for { + _ <- streamCompleteOnRebalanceRef.set(None) + _ <- p.succeed(()) + c <- consumedMessagesCounter.get + _ <- ZIO.logInfo(s"Consumed $c messages") + } yield () + } + } yield s + } + .runDrain + .provideSome[Kafka]( + transactionalConsumer( + clientId, + consumerGroupId, + offsetRetrieval = OffsetRetrieval.Auto(AutoOffsetStrategy.Earliest), + restartStreamOnRebalancing = true, + rebalanceSafeCommits = true, + properties = Map( + ConsumerConfig.PARTITION_ASSIGNMENT_STRATEGY_CONFIG -> + implicitly[ClassTag[T]].runtimeClass.getName, + ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "200" + ), + rebalanceListener = transactionalRebalanceListener(streamCompleteOnRebalanceRef) + ) + ) + .tapError(e => ZIO.logError(s"Error: $e")) <* ZIO.logInfo("Done") + } yield tConsumer + } + + for { + tProducerSettings <- transactionalProducerSettings + tProducer <- TransactionalProducer.make(tProducerSettings) + + topicA <- randomTopic + topicB <- randomTopic + _ <- ZIO.attempt(EmbeddedKafka.createCustomTopic(topicA, partitions = partitionCount)) + _ <- 
ZIO.attempt(EmbeddedKafka.createCustomTopic(topicB, partitions = partitionCount)) + + _ <- produceMany(topicA, messagesBeforeRebalance) + + copyingGroup <- randomGroup + + _ <- ZIO.logInfo("Starting copier 1") + copier1ClientId = copyingGroup + "-1" + copier1Created <- Promise.make[Nothing, Unit] + copier1 <- makeCopyingTransactionalConsumer( + "1", + copyingGroup, + copier1ClientId, + topicA, + topicB, + tProducer, + copier1Created + ).fork + _ <- copier1Created.await + + _ <- ZIO.logInfo("Starting copier 2") + copier2ClientId = copyingGroup + "-2" + copier2Created <- Promise.make[Nothing, Unit] + copier2 <- makeCopyingTransactionalConsumer( + "2", + copyingGroup, + copier2ClientId, + topicA, + topicB, + tProducer, + copier2Created + ).fork + _ <- ZIO.logInfo("Waiting for copier 2 to start") + _ <- copier2Created.await + + _ <- ZIO.logInfo("Producing some more messages") + _ <- produceMany(topicA, messagesAfterRebalance) + + _ <- ZIO.logInfo("Collecting messages on topic B") + groupB <- randomGroup + validatorClientId <- randomClient + messagesOnTopicB <- ZIO.logAnnotate("consumer", "validator") { + Consumer + .plainStream(Subscription.topics(topicB), Serde.string, Serde.string) + .map(_.value) + .timeout(5.seconds) + .runCollect + .provideSome[Kafka]( + transactionalConsumer( + validatorClientId, + groupB, + offsetRetrieval = OffsetRetrieval.Auto(AutoOffsetStrategy.Earliest), + properties = Map(ConsumerConfig.MAX_POLL_RECORDS_CONFIG -> "200") + ) + ) + .tapError(e => ZIO.logError(s"Error: $e")) <* ZIO.logInfo("Done") + } + _ <- copier1.interrupt + _ <- copier2.interrupt + messagesOnTopicBCount = messagesOnTopicB.size + messagesOnTopicBDistinctCount = messagesOnTopicB.distinct.size + } yield assertTrue(messageCount == messagesOnTopicBCount && messageCount == messagesOnTopicBDistinctCount) + } + + // Test for both default partition assignment strategies + Seq( + testForPartitionAssignmentStrategy[RangeAssignor], + testForPartitionAssignmentStrategy[CooperativeStickyAssignor] + ) + + }: _*) @@ TestAspect.nonFlaky(3), + test("running streams don't stall after a poll timeout") { + for { + topic <- randomTopic + clientId <- randomClient + _ <- ZIO.fromTry(EmbeddedKafka.createCustomTopic(topic)) + settings <- consumerSettings(clientId, rebalanceSafeCommits = true) + consumer <- Consumer.make(settings.withPollTimeout(50.millis)) + recordsOut <- Queue.unbounded[Unit] + _ <- produceOne(topic, "key1", "message1") + _ <- consumer + .plainStream(Subscription.manual(topic -> 0), Serde.string, Serde.string) + .debug + .foreach(_ => recordsOut.offer(())) + .forkScoped + _ <- recordsOut.take // first record consumed + _ <- ZIO.sleep(200.millis) // wait for `pollTimeout` + _ <- produceOne(topic, "key2", "message2") + _ <- recordsOut.take + } yield assertCompletes + } + ).provideSomeLayerShared[TestEnvironment & Kafka]( + producer ++ Scope.default ++ Runtime.removeDefaultLoggers ++ Runtime.addLogger(logger) + ) @@ withLiveClock @@ TestAspect.sequential @@ timeout(180.seconds) + + private def scheduledProducer[R]( + topic: String, + partitionCount: Int, + schedule: Schedule[R, Any, Long] + ): ZStream[R with Producer, Throwable, Chunk[RecordMetadata]] = + ZStream + .fromSchedule(schedule) + .mapZIO { i => + produceMany(topic, partition = i.toInt % partitionCount, kvs = List(s"key$i" -> s"msg$i")) + } + +} diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala index d15572cacd..c844715e44 100644 --- 
a/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/Consumer.scala @@ -371,6 +371,7 @@ object Consumer { offsetRetrieval = settings.offsetRetrieval, userRebalanceListener = settings.rebalanceListener, restartStreamsOnRebalancing = settings.restartStreamOnRebalancing, + rebalanceSafeCommits = settings.rebalanceSafeCommits, runloopTimeout = settings.runloopTimeout ) subscriptions <- Ref.Synchronized.make(Set.empty[Subscription]) diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala index 4ffe4d6fef..edd0a3f3d1 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/ConsumerSettings.scala @@ -16,6 +16,29 @@ import zio.kafka.security.KafkaCredentialStore * @param restartStreamOnRebalancing * When `true` _all_ streams are restarted during a rebalance, including those streams that are not revoked. The * default is `false`. + * + * @param rebalanceSafeCommits + * Whether to hold up a rebalance until all offsets of consumed messages have been committed. The default is `false`, + * but the recommended value is `true` as it prevents duplicate messages. + * + * Use `false` _only_ when your streams do not commit offsets, or when it is okay to have messages processed twice + * concurrently and you cannot afford the performance hit during a rebalance. + * + * When `true`, messages consumed from revoked partitions must be committed before we allow the rebalance to continue. + * + * When a partition is revoked, consuming the messages will be taken over by another consumer. The other consumer will + * continue from the committed offset. It is therefore important that this consumer commits offsets of all consumed + * messages. Therefore, by holding up the rebalance until these commits are done, we ensure that the new consumer will + * start from the correct offset. + * + * During a rebalance no new messages can be received _for any stream_. Therefore, _all_ streams are deprived of new + * messages until the revoked streams are done committing. + * + * When `false`, streams for revoked partitions may continue to run even though the rebalance is not held up. Any offset + * commits from these streams have a high chance of being delayed (commits are not possible during some phases of a + * rebalance). The consumer that takes over the partition will likely not see these delayed commits and will start from + * an earlier offset. The result is that some messages are processed twice and concurrently. + * * @param runloopTimeout * Internal timeout for each iteration of the command processing and polling loop, use to detect stalling. This should * be much larger than the pollTimeout and the time it takes to process chunks of records. 
If your consumer is not @@ -31,6 +54,7 @@ case class ConsumerSettings( offsetRetrieval: OffsetRetrieval = OffsetRetrieval.Auto(), rebalanceListener: RebalanceListener = RebalanceListener.noop, restartStreamOnRebalancing: Boolean = false, + rebalanceSafeCommits: Boolean = false, runloopTimeout: Duration = ConsumerSettings.defaultRunloopTimeout ) { private[this] def autoOffsetResetConfig: Map[String, String] = offsetRetrieval match { @@ -86,6 +110,9 @@ case class ConsumerSettings( def withRestartStreamOnRebalancing(value: Boolean): ConsumerSettings = copy(restartStreamOnRebalancing = value) + def withRebalanceSafeCommits(value: Boolean): ConsumerSettings = + copy(rebalanceSafeCommits = value) + def withCredentials(credentialsStore: KafkaCredentialStore): ConsumerSettings = withProperties(credentialsStore.properties) diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/Offset.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/Offset.scala index 69f8b98423..99bd158d18 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/Offset.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/Offset.scala @@ -9,7 +9,7 @@ sealed trait Offset { def partition: Int def offset: Long def commit: Task[Unit] - def batch: OffsetBatch + def asOffsetBatch: OffsetBatch def consumerGroupMetadata: Option[ConsumerGroupMetadata] /** @@ -42,6 +42,6 @@ private final case class OffsetImpl( commitHandle: Map[TopicPartition, Long] => Task[Unit], consumerGroupMetadata: Option[ConsumerGroupMetadata] ) extends Offset { - def commit: Task[Unit] = commitHandle(Map(topicPartition -> offset)) - def batch: OffsetBatch = OffsetBatchImpl(Map(topicPartition -> offset), commitHandle, consumerGroupMetadata) + def commit: Task[Unit] = commitHandle(Map(topicPartition -> offset)) + def asOffsetBatch: OffsetBatch = OffsetBatchImpl(Map(topicPartition -> offset), commitHandle, consumerGroupMetadata) } diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/OffsetBatch.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/OffsetBatch.scala index 3c0c0f6cc6..4eed55bccc 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/OffsetBatch.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/OffsetBatch.scala @@ -58,7 +58,7 @@ private final case class OffsetBatchImpl( case object EmptyOffsetBatch extends OffsetBatch { override val offsets: Map[TopicPartition, Long] = Map.empty override val commit: Task[Unit] = ZIO.unit - override def add(offset: Offset): OffsetBatch = offset.batch + override def add(offset: Offset): OffsetBatch = offset.asOffsetBatch override def merge(offset: Offset): OffsetBatch = add(offset) override def merge(offsets: OffsetBatch): OffsetBatch = offsets override def consumerGroupMetadata: Option[ConsumerGroupMetadata] = None diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/RebalanceListener.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/RebalanceListener.scala index d302fb076f..3a00c04424 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/RebalanceListener.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/RebalanceListener.scala @@ -7,6 +7,9 @@ import scala.jdk.CollectionConverters._ /** * ZIO wrapper around Kafka's `ConsumerRebalanceListener` to work with Scala collection types and ZIO effects. + * + * Note that the given ZIO effects are executed directly on the Kafka poll thread. Fork and shift to another executor + * when this is not desired. 
*/ final case class RebalanceListener( onAssigned: (Set[TopicPartition], RebalanceConsumer) => Task[Unit], diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/ConsumerAccess.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/ConsumerAccess.scala index 4b9b01b1f5..897c206c3a 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/ConsumerAccess.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/ConsumerAccess.scala @@ -19,7 +19,7 @@ private[consumer] final class ConsumerAccess( def withConsumerZIO[R, A](f: ByteArrayKafkaConsumer => RIO[R, A]): RIO[R, A] = access.withPermit(withConsumerNoPermit(f)) - private[consumer] def withConsumerNoPermit[R, A]( + private def withConsumerNoPermit[R, A]( f: ByteArrayKafkaConsumer => RIO[R, A] ): RIO[R, A] = ZIO @@ -31,10 +31,17 @@ private[consumer] final class ConsumerAccess( .flatMap(fib => fib.join.onInterrupt(ZIO.succeed(consumer.wakeup()) *> fib.interrupt)) /** - * Do not use this method outside of the Runloop + * Use this method only from Runloop. */ private[internal] def runloopAccess[R, E, A](f: ByteArrayKafkaConsumer => ZIO[R, E, A]): ZIO[R, E, A] = access.withPermit(f(consumer)) + + /** + * Use this method ONLY from the rebalance listener. + */ + private[internal] def rebalanceListenerAccess[R, A](f: ByteArrayKafkaConsumer => RIO[R, A]): RIO[R, A] = + withConsumerNoPermit(f) + } private[consumer] object ConsumerAccess { diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/PartitionStreamControl.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/PartitionStreamControl.scala index fdfa8d1107..342b542a0f 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/PartitionStreamControl.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/PartitionStreamControl.scala @@ -1,18 +1,22 @@ package zio.kafka.consumer.internal import org.apache.kafka.common.TopicPartition +import zio.kafka.consumer.Offset import zio.kafka.consumer.diagnostics.{ DiagnosticEvent, Diagnostics } import zio.kafka.consumer.internal.Runloop.Command.Request import zio.kafka.consumer.internal.Runloop.{ ByteArrayCommittableRecord, Command } import zio.stream.{ Take, ZStream } -import zio.{ Chunk, LogAnnotation, Promise, Queue, UIO, ZIO } +import zio.{ Chunk, LogAnnotation, Promise, Queue, Ref, UIO, ZIO } private[internal] final class PartitionStreamControl private ( val tp: TopicPartition, stream: ZStream[Any, Throwable, ByteArrayCommittableRecord], + val lastOffset: Ref[Option[Offset]], dataQueue: Queue[Take[Throwable, ByteArrayCommittableRecord]], - interruptPromise: Promise[Throwable, Unit], - completedPromise: Promise[Nothing, Unit] + startedPromise: Promise[Nothing, Unit], + endedPromise: Promise[Nothing, Unit], + completedPromise: Promise[Nothing, Unit], + interruptPromise: Promise[Throwable, Unit] ) { private val logAnnotate = ZIO.logAnnotate( @@ -22,7 +26,8 @@ private[internal] final class PartitionStreamControl private ( /** Offer new data for the stream to process. */ def offerRecords(data: Chunk[ByteArrayCommittableRecord]): ZIO[Any, Nothing, Unit] = - dataQueue.offer(Take.chunk(data)).unit + data.lastOption.fold(ZIO.unit)(last => lastOffset.set(Some(last.offset))) *> + dataQueue.offer(Take.chunk(data)).unit /** To be invoked when the partition was lost. 
*/ def lost(): UIO[Boolean] = @@ -32,16 +37,40 @@ private[internal] final class PartitionStreamControl private ( def end(): ZIO[Any, Nothing, Unit] = logAnnotate { ZIO.logTrace(s"Partition ${tp.toString} ending") *> - dataQueue.offer(Take.end).unit + ZIO + .whenZIO(endedPromise.succeed(())) { + dataQueue.offer(Take.end) + } + .unit } - /** Returns true when the stream is done. */ - def isCompleted: ZIO[Any, Nothing, Boolean] = - completedPromise.isDone + /** Returns true when the stream accepts new data. */ + def acceptsData: ZIO[Any, Nothing, Boolean] = + for { + ended <- endedPromise.isDone + completed <- completedPromise.isDone + interrupted <- interruptPromise.isDone + } yield !(ended || completed || interrupted) - /** Returns true when the stream is running. */ - def isRunning: ZIO[Any, Nothing, Boolean] = - isCompleted.negate + /** Returns true when the stream is done (or when it didn't even start). */ + def isCompletedAfterStart: ZIO[Any, Nothing, Boolean] = + for { + started <- startedPromise.isDone + completed <- completedPromise.isDone + } yield !started || completed + + def lasOffsetIsIn(committedOffsets: Map[TopicPartition, Long]): ZIO[Any, Nothing, Boolean] = + lastOffset.get.map(_.forall(offset => committedOffsets.get(offset.topicPartition).exists(_ >= offset.offset))).tap { + result => + for { + lo <- lastOffset.get + _ <- ZIO.logDebug( + s"${tp.partition()} lastOffset: ${lo.map(_.offset.toString).getOrElse("-")} " + + s"in committedOffsets: ${committedOffsets.get(tp).map(_.toString).getOrElse("-")} " + + s"==> $result" + ) + } yield () + } val tpStream: (TopicPartition, ZStream[Any, Throwable, ByteArrayCommittableRecord]) = (tp, stream) @@ -56,8 +85,11 @@ private[internal] object PartitionStreamControl { ): ZIO[Any, Nothing, PartitionStreamControl] = for { _ <- ZIO.logTrace(s"Creating partition stream ${tp.toString}") - interruptionPromise <- Promise.make[Throwable, Unit] + startedPromise <- Promise.make[Nothing, Unit] + endedPromise <- Promise.make[Nothing, Unit] completedPromise <- Promise.make[Nothing, Unit] + interruptionPromise <- Promise.make[Throwable, Unit] + lastOffset <- Ref.make[Option[Offset]](None) dataQueue <- Queue.unbounded[Take[Throwable, ByteArrayCommittableRecord]] requestAndAwaitData = for { @@ -74,12 +106,22 @@ private[internal] object PartitionStreamControl { completedPromise.succeed(()) <* ZIO.logDebug(s"Partition stream ${tp.toString} has ended") ) *> + ZStream.fromZIO(startedPromise.succeed(())) *> ZStream.repeatZIOChunk { // First try to take all records that are available right now. // When no data is available, request more data and await its arrival. 
dataQueue.takeAll.flatMap(data => if (data.isEmpty) requestAndAwaitData else ZIO.succeed(data)) }.flattenTake .interruptWhen(interruptionPromise) - } yield new PartitionStreamControl(tp, stream, dataQueue, interruptionPromise, completedPromise) + } yield new PartitionStreamControl( + tp, + stream, + lastOffset, + dataQueue, + startedPromise, + endedPromise, + completedPromise, + interruptionPromise + ) } diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala index 398239b55d..79f106e354 100644 --- a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/Runloop.scala @@ -7,27 +7,81 @@ import zio._ import zio.kafka.consumer.Consumer.{ OffsetRetrieval, RunloopTimeout } import zio.kafka.consumer.diagnostics.{ DiagnosticEvent, Diagnostics } import zio.kafka.consumer.internal.ConsumerAccess.ByteArrayKafkaConsumer -import zio.kafka.consumer.internal.Runloop.Command.{ Commit, Request, StopAllStreams, StopRunloop } +import zio.kafka.consumer.internal.Runloop.Command.{ CommitAvailable, Request, StopAllStreams, StopRunloop } import zio.kafka.consumer.internal.Runloop._ import zio.kafka.consumer.{ CommittableRecord, RebalanceConsumer, RebalanceListener, Subscription } import zio.stream._ import java.util +import java.util.{ Map => JavaMap } +import scala.collection.mutable import scala.jdk.CollectionConverters._ +/** + * Runloop is the heart of the zio-kafka consumer. + * + * ## Stream management + * + * - When a partition gets assigned manually or by the broker, a new stream is started. + * - When a partition is revoked by the broker, the stream is ended. + * - When a partition is reported as lost, the stream is interrupted. + * + * ## Fetching data + * + * - Streams that need data request it via a [[Request]] command to the command-queue. + * - Partitions for which no data is needed are paused. This backpressure prevents unnecessary buffering of data. + * + * ## Poll-loop + * + * The poll-loop continuously polls the broker for new data. Since polling is also needed for learning about partition + * assignment changes, or for completing commits, polling continues even when no partitions are assigned, or when there + * are pending commits. + * + * When all streams stop processing, polling stops so that the broker can detect that this Kafka client is stalled. + * + * ## Rebalance listener + * + * The rebalance listener runs during a poll to the broker. It is used to track changes to partition assignments. + * Partitions can be assigned, revoked or lost. + * + * When a partition is revoked, the stream that handles it will be ended (signal the stream that no more data will be + * available). Even though there is no more data, the stream still needs to complete processing the messages it already + * got. + * + * ### Rebalance listener - Commit-loop + * + * When `rebalanceSafeCommits` is `true`, we wait for a revoked stream to commit offsets from within the rebalance + * listener callback. This gives the program a chance to commit offsets before its partition is given to another + * consumer. + * + * While the rebalance listener is waiting, new commits must still be sent to the broker. In addition we need to + * continue polling the broker so that we know that earlier commits completed. For both we use commitAsync (in the + * second case with an empty map of offsets). This forms the commit-loop. 
+ * + * The commit-loop ends when the offsets have been committed or a timeout occurred. + * + * ## The command-queue and the commit-queue + * + * TODO: Move this document to a central place. + */ +// Disable zio-intellij's inspection `SimplifyWhenInspection` because its suggestion is not +// equivalent performance-wise. +//noinspection SimplifyWhenInspection private[consumer] final class Runloop private ( - runtime: Runtime[Any], + sameThreadRuntime: Runtime[Any], hasGroupId: Boolean, consumer: ConsumerAccess, pollTimeout: Duration, runloopTimeout: Duration, commandQueue: Queue[Command], - lastRebalanceEvent: Ref.Synchronized[Option[Runloop.RebalanceEvent]], + commitQueue: Queue[Commit], + rebalanceListenerEvent: Ref[RebalanceEvent], val partitions: Queue[Take[Throwable, (TopicPartition, Stream[Throwable, ByteArrayCommittableRecord])]], diagnostics: Diagnostics, offsetRetrieval: OffsetRetrieval, userRebalanceListener: RebalanceListener, restartStreamsOnRebalancing: Boolean, + rebalanceSafeCommits: Boolean, currentState: Ref[State] ) { @@ -50,114 +104,226 @@ private[consumer] final class Runloop private ( .unit .uninterruptible - val rebalanceListener: RebalanceListener = { + private val rebalanceListener: RebalanceListener = { val emitDiagnostics = RebalanceListener( (assigned, _) => diagnostics.emitIfEnabled(DiagnosticEvent.Rebalance.Assigned(assigned)), (revoked, _) => diagnostics.emitIfEnabled(DiagnosticEvent.Rebalance.Revoked(revoked)), (lost, _) => diagnostics.emitIfEnabled(DiagnosticEvent.Rebalance.Lost(lost)) ) - def restartStreamsRebalancingListener = RebalanceListener( + val endRevokedStreamsRebalancingListener = RebalanceListener( onAssigned = (assigned, _) => - ZIO.logDebug("Rebalancing completed") *> - lastRebalanceEvent.updateZIO { - case None => - ZIO.some(Runloop.RebalanceEvent.Assigned(assigned)) - case Some(Runloop.RebalanceEvent.Revoked(revokeResult)) => - ZIO.some(Runloop.RebalanceEvent.RevokedAndAssigned(revokeResult, assigned)) - case Some(_) => - ZIO.fail(new IllegalStateException(s"Multiple onAssigned calls on rebalance listener")) - }, - onRevoked = (_, _) => - ZIO.logDebug("Rebalancing started") *> - currentState.get.flatMap { state => - // End all streams - endRevokedPartitions( - state.pendingRequests, - state.assignedStreams, - isRevoked = _ => true - ).flatMap { result => - lastRebalanceEvent.updateZIO { - case None => - ZIO.some(Runloop.RebalanceEvent.Revoked(result)) - case _ => - ZIO.fail( - new IllegalStateException( - s"onRevoked called on rebalance listener with pending assigned event" - ) - ) - } - } - } + for { + _ <- ZIO.logDebug(s"${assigned.size} partitions are assigned") + rebalanceEvent <- rebalanceListenerEvent.get + state <- currentState.get + streamsToEnd = if (restartStreamsOnRebalancing && !rebalanceEvent.wasInvoked) state.assignedStreams + else Chunk.empty + pendingCommits <- consumer.rebalanceListenerAccess { consumer => + endStreams(state, consumer, streamsToEnd, rebalanceSafeCommits) + } + _ <- rebalanceListenerEvent.set(rebalanceEvent.onAssigned(assigned, pendingCommits, streamsToEnd)) + _ <- ZIO.logTrace("onAssigned done") + } yield (), + onRevoked = (revokedTps, _) => + for { + _ <- ZIO.logDebug(s"${revokedTps.size} partitions are revoked") + rebalanceEvent <- rebalanceListenerEvent.get + state <- currentState.get + streamsToEnd = if (restartStreamsOnRebalancing && !rebalanceEvent.wasInvoked) state.assignedStreams + else state.assignedStreams.filter(control => revokedTps.contains(control.tp)) + pendingCommits <-
consumer.rebalanceListenerAccess { consumer => + endStreams(state, consumer, streamsToEnd, rebalanceSafeCommits) + } + _ <- rebalanceListenerEvent.set(rebalanceEvent.onRevokedOrLost(pendingCommits, streamsToEnd)) + _ <- ZIO.logTrace("onRevoked done") + } yield (), + onLost = (lostTps, _) => + for { + _ <- ZIO.logDebug(s"${lostTps.size} partitions are lost") + rebalanceEvent <- rebalanceListenerEvent.get + state <- currentState.get + (lostStreams, remainingStreams) = state.assignedStreams.partition(control => lostTps.contains(control.tp)) + _ <- ZIO.foreachDiscard(lostStreams)(_.lost()) + streamsToEnd = if (restartStreamsOnRebalancing && !rebalanceEvent.wasInvoked) remainingStreams + else Chunk.empty + pendingCommits <- consumer.rebalanceListenerAccess { consumer => + endStreams(state, consumer, streamsToEnd, rebalanceSafeCommits) + } + _ <- rebalanceListenerEvent.update(_.onRevokedOrLost(pendingCommits, streamsToEnd)) + _ <- ZIO.logTrace(s"onLost done") + } yield () ) - if (restartStreamsOnRebalancing) { - emitDiagnostics ++ restartStreamsRebalancingListener ++ userRebalanceListener + emitDiagnostics ++ endRevokedStreamsRebalancingListener ++ userRebalanceListener + } + + /** + * Ends streams, optionally waiting for consumed offsets to be committed. + * + * @return + * all commits that were created while waiting + */ + private def endStreams( + state: State, + consumer: ByteArrayKafkaConsumer, + streamsToEnd: Chunk[PartitionStreamControl], + awaitStreamCommits: Boolean + ): Task[Chunk[Commit]] = + if (streamsToEnd.nonEmpty) { + for { + _ <- ZIO.foreachDiscard(streamsToEnd)(_.end()) + pendingCommits <- if (awaitStreamCommits) doAwaitStreamCommits(state, consumer, streamsToEnd) + else ZIO.succeed(Chunk.empty) + } yield pendingCommits } else { - emitDiagnostics ++ userRebalanceListener + ZIO.succeed(Chunk.empty) } + + private def doAwaitStreamCommits( + state: State, + consumer: ByteArrayKafkaConsumer, + streamsToEnd: Chunk[PartitionStreamControl] + ): Task[Chunk[Commit]] = + // When the queue is empty we still need to call commit (with 0 offsets) so that we poll + // the broker and earlier commits can complete. + // We cannot use ZStream.fromQueue because that will emit nothing when the queue is empty. + ZStream + .fromZIO(commitQueue.takeAll) + .tap(handleCommitsDuringWait(consumer)) + .forever + // TODO: remove following delay (perhaps replace `forever` with a repeat) + // .tap(_ => ZIO.sleep(100.millis)) + .scan(Chunk.empty[Commit])(_ ++ _) + .tap(_ => ZIO.logTrace(s"Waiting for ${streamsToEnd.size} streams to end")) + .takeUntilZIO { commits => + val sentCommits = state.pendingCommits ++ commits + val sentOffsets = state.committedOffsets ++ sentCommits + .map(_.offsets) + .foldLeft(Map.empty[TopicPartition, Long])(_ ++ _) + for { + _ <- + ZIO.logDebug( + s"state.committedOffsets: ${state.committedOffsets}\nstate.pendingCommits: ${state.pendingCommits}\ncommits: $commits" + ) + allCommitted <- ZIO.forall(streamsToEnd)(_.lastOffsetIsIn(sentOffsets)) + } yield allCommitted + } + .tap(_ => ZIO.attempt(consumer.commitSync(Map.empty[TopicPartition, OffsetAndMetadata].asJava, runloopTimeout))) + .run(ZSink.last) + .map(_.getOrElse(Chunk.empty)) + // .timeoutFail(new RuntimeException("Timeout waiting for streams to end"))(runloopTimeout) + .ensuring { + ZIO.logInfo(s"Done waiting for ${streamsToEnd.size} streams to end") + } + .interruptible + + /** + * Handle commits while waiting for streams to end. + * + * We need to ensure the streams can end.
This is only possible if the commits these streams started complete. The + * commits complete when the callback is invoked. The callback is invoked when the underlying consumer polls the + * broker. This can be achieved by invoking `commitAsync` or `commitSync`. Even when we pass no offsets, the broker + * will be polled and callbacks will be called. + */ + private def handleCommitsDuringWait( + consumer: ByteArrayKafkaConsumer + )(commits: Chunk[Commit]): UIO[Unit] = { + val (offsets, callback, onFailure) = asyncCommitParameters(commits) + // Note: when there are no commits we skip the commitAsync call; the empty commitSync in the surrounding + // commit-loop (doAwaitStreamCommits) keeps the broker polled so that earlier commits can still complete. + ZIO.logDebug(s"Async commit of ${offsets.size} offsets for ${commits.size} commits") *> + ZIO.attempt { + if (commits.nonEmpty) consumer.commitAsync(offsets, callback) + else () + } + .catchAll(onFailure) + } - private val commit: Map[TopicPartition, Long] => Task[Unit] = - offsets => - for { - p <- Promise.make[Throwable, Unit] - _ <- commandQueue.offer(Commit(offsets, p)).unit - _ <- diagnostics.emitIfEnabled(DiagnosticEvent.Commit.Started(offsets)) - _ <- p.await - } yield () + private def handleCommits(state: State, commits: Chunk[Commit]): UIO[State] = + if (commits.isEmpty) { + ZIO.succeed(state) + } else { + val (offsets, callback, onFailure) = asyncCommitParameters(commits) + val newState = state.addCommits(commits) + consumer.runloopAccess { c => + // We don't wait for the completion of the commit here, because it + // will only complete once we poll again.
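+ // The callback completes each commit's promise; `updateCommits` (run from `handlePoll`) then moves the
+ // completed commits into `committedOffsets`.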
+ ZIO.attempt(c.commitAsync(offsets, callback)) + } + .catchAll(onFailure) + .as(newState) + } - private def doCommit(cmd: Commit): UIO[Unit] = { - val offsets = cmd.offsets.map { case (tp, offset) => tp -> new OffsetAndMetadata(offset + 1) } - val cont = (e: Exit[Throwable, Unit]) => cmd.cont.done(e).asInstanceOf[UIO[Unit]] - val onSuccess = cont(Exit.unit) <* diagnostics.emitIfEnabled(DiagnosticEvent.Commit.Success(offsets)) + private def asyncCommitParameters( + commits: Chunk[Commit] + ): (JavaMap[TopicPartition, OffsetAndMetadata], OffsetCommitCallback, Throwable => UIO[Unit]) = { + val offsets = commits + .foldLeft(mutable.Map.empty[TopicPartition, Long]) { case (acc, commit) => + commit.offsets.foreach { case (tp, offset) => + acc += (tp -> acc.get(tp).map(_ max offset).getOrElse(offset)) + } + acc + } + .toMap + val offsetsWithMetaData = offsets.map { case (tp, offset) => tp -> new OffsetAndMetadata(offset + 1) } + val cont = (e: Exit[Throwable, Unit]) => ZIO.foreachDiscard(commits)(_.cont.done(e)) + val onSuccess = cont(Exit.unit) <* diagnostics.emitIfEnabled(DiagnosticEvent.Commit.Success(offsetsWithMetaData)) val onFailure: Throwable => UIO[Unit] = { case _: RebalanceInProgressException => - ZIO.logDebug(s"Rebalance in progress, retrying commit for offsets $offsets") *> - commandQueue.offer(cmd).unit - case err => - cont(Exit.fail(err)) <* diagnostics.emitIfEnabled(DiagnosticEvent.Commit.Failure(offsets, err)) + for { + _ <- ZIO.logDebug(s"Rebalance in progress, commit for offsets $offsets will be retried") + _ <- commitQueue.offerAll(commits) + _ <- commandQueue.offer(CommitAvailable) + } yield () + case err: Throwable => + cont(Exit.fail(err)) <* diagnostics.emitIfEnabled(DiagnosticEvent.Commit.Failure(offsetsWithMetaData, err)) } val callback = new OffsetCommitCallback { override def onComplete(offsets: util.Map[TopicPartition, OffsetAndMetadata], exception: Exception): Unit = Unsafe.unsafe { implicit u => - runtime.unsafe.run(if (exception eq null) onSuccess else onFailure(exception)).getOrThrowFiberFailure() + sameThreadRuntime.unsafe + .run(if (exception eq null) onSuccess else onFailure(exception)) + .getOrThrowFiberFailure() } } - - // We don't wait for the completion of the commit here, because it - // will only complete once we poll again. - consumer.runloopAccess { c => - ZIO - .attempt(c.commitAsync(offsets.asJava, callback)) - .catchAll(onFailure) - } + (offsetsWithMetaData.asJava, callback, onFailure) } - /** - * Does all needed to end revoked partitions: - * 1. Complete the revoked assigned streams 2.
Remove from the list of pending requests - * @return - * New pending requests, new active assigned streams - */ - private def endRevokedPartitions( - pendingRequests: Chunk[Request], - assignedStreams: Chunk[PartitionStreamControl], - isRevoked: TopicPartition => Boolean - ): UIO[Runloop.RevokeResult] = { - val (revokedStreams, newAssignedStreams) = - assignedStreams.partition(control => isRevoked(control.tp)) - - ZIO - .foreachDiscard(revokedStreams)(_.end()) - .as( - Runloop.RevokeResult( - pendingRequests = pendingRequests.filter(req => !isRevoked(req.tp)), - assignedStreams = newAssignedStreams - ) - ) - } + /** This is the implementation behind the user facing api `Offset.commit`. */ + private val commit: Map[TopicPartition, Long] => Task[Unit] = + offsets => + for { + p <- Promise.make[Throwable, Unit] + _ <- commitQueue.offer(Commit(offsets, p)) + _ <- commandQueue.offer(CommitAvailable) + _ <- diagnostics.emitIfEnabled(DiagnosticEvent.Commit.Started(offsets)) + _ <- p.await + } yield () /** * Offer records retrieved from poll() call to the streams. @@ -259,10 +425,10 @@ private[consumer] final class Runloop private ( s"Starting poll with ${state.pendingRequests.size} pending requests and ${state.pendingCommits.size} pending commits" ) _ <- currentState.set(state) + _ <- rebalanceListenerEvent.set(RebalanceEvent.None) pollResult <- consumer.runloopAccess { c => ZIO.suspend { - val prevAssigned = c.assignment().asScala.toSet val requestedPartitions = state.pendingRequests.map(_.tp).toSet @@ -270,70 +436,39 @@ private[consumer] final class Runloop private ( val records = doPoll(c) - val currentAssigned = c.assignment().asScala.toSet - val newlyAssigned = currentAssigned -- prevAssigned - - for { - ignoreRecordsForTps <- doSeekForNewPartitions(c, newlyAssigned) - - rebalanceEvent <- lastRebalanceEvent.getAndSet(None) - - revokeResult <- rebalanceEvent match { - case Some(Runloop.RebalanceEvent.Revoked(result)) => - // If we get here, `restartStreamsOnRebalancing == true` - // Use revoke result from endRevokedPartitions that was called previously in the rebalance listener - ZIO.succeed(result) - case Some(Runloop.RebalanceEvent.RevokedAndAssigned(result, _)) => - // If we get here, `restartStreamsOnRebalancing == true` - // Use revoke result from endRevokedPartitions that was called previously in the rebalance listener - ZIO.succeed(result) - case Some(Runloop.RebalanceEvent.Assigned(_)) => - // If we get here, `restartStreamsOnRebalancing == true` - // endRevokedPartitions was not called yet in the rebalance listener, - // and all partitions should be revoked - endRevokedPartitions( - state.pendingRequests, - state.assignedStreams, - isRevoked = _ => true - ) - case None => - // End streams for partitions that are no longer assigned - endRevokedPartitions( - state.pendingRequests, - state.assignedStreams, - isRevoked = (tp: TopicPartition) => !currentAssigned.contains(tp) - ) - } - - startingTps = rebalanceEvent match { - case Some(_) => - // If we get here, `restartStreamsOnRebalancing == true`, - // some partitions were revoked and/or assigned and - // all already assigned streams were ended. - // Therefore, all currently assigned tps are starting, - // either because they are restarting, or because they - // are new. 
- currentAssigned - case None => - newlyAssigned - } - - _ <- diagnostics.emitIfEnabled { - val providedTps = records.partitions().asScala.toSet - DiagnosticEvent.Poll( - tpRequested = requestedPartitions, - tpWithData = providedTps, - tpWithoutData = requestedPartitions -- providedTps - ) - } - - } yield Runloop.PollResult( - startingTps = startingTps, - pendingRequests = revokeResult.pendingRequests, - assignedStreams = revokeResult.assignedStreams, - records = records, - ignoreRecordsForTps = ignoreRecordsForTps - ) + rebalanceListenerEvent.get.flatMap { + case RebalanceEvent(false, _, _, _) => + // The fast track: rebalance listener was not invoked, no changes, only new records. + ZIO.succeed(Runloop.PollResult(records)) + + case RebalanceEvent(true, newlyAssigned, pendingCommits, endedStreams) => + // Some partitions were revoked, lost or assigned, + // some new commits might have been initiated, + // some streams might have been ended. + + // When `restartStreamsOnRebalancing == true`, + // all already assigned streams were ended. + // Therefore, _all_ currently assigned tps are starting, + // either because they are restarting, or because they + // are new. + val startingTps = + if (restartStreamsOnRebalancing) c.assignment().asScala.toSet + else newlyAssigned + + for { + ignoreRecordsForTps <- doSeekForNewPartitions(c, newlyAssigned) + + _ <- diagnostics.emitIfEnabled { + val providedTps = records.partitions().asScala.toSet + DiagnosticEvent.Poll( + tpRequested = requestedPartitions, + tpWithData = providedTps, + tpWithoutData = requestedPartitions -- providedTps + ) + } + + } yield Runloop.PollResult(pendingCommits, startingTps, records, ignoreRecordsForTps, endedStreams) + } } } startingStreams <- @@ -347,60 +482,78 @@ private[consumer] final class Runloop private ( partitions.offer(Take.chunk(Chunk.fromIterable(newStreams.map(_.tpStream)))) } } - runningStreams <- ZIO.filter(pollResult.assignedStreams)(_.isRunning) + runningStreams <- ZIO.filter(state.assignedStreams diff pollResult.endedStreams)(_.acceptsData) updatedStreams = runningStreams ++ startingStreams + updatedPendingRequests = { + val streamTps = updatedStreams.map(_.tp).toSet + state.pendingRequests.filter(req => streamTps.contains(req.tp)) + } fulfillResult <- offerRecordsToStreams( updatedStreams, - pollResult.pendingRequests, + updatedPendingRequests, pollResult.ignoreRecordsForTps, pollResult.records ) - updatedPendingCommits <- ZIO.filter(state.pendingCommits)(_.isPending) + (updatedPendingCommits, updatedCommittedOffsets) <- updateCommits( + state.pendingCommits ++ pollResult.newCommits, + state.committedOffsets + ) } yield State( pendingRequests = fulfillResult.pendingRequests, pendingCommits = updatedPendingCommits, + committedOffsets = updatedCommittedOffsets -- pollResult.endedStreams.map(_.tp), assignedStreams = updatedStreams, subscription = state.subscription ) + private def updateCommits( + pendingCommits: Chunk[Commit], + committedOffsets: Map[TopicPartition, Long] + ): ZIO[Any, Nothing, (Chunk[Commit], Map[TopicPartition, Long])] = + ZIO.foreach(pendingCommits)(commit => commit.isDone.map(commit -> _)).map { commitsWithDone => + val (doneCommits, updatedPendingCommits) = commitsWithDone.partitionMap { case (c, done) => + if (done) Left(c) else Right(c) + } + val updatedCommittedOffsets = committedOffsets ++ doneCommits.flatMap(_.offsets) + (updatedPendingCommits, updatedCommittedOffsets) + } + private def handleCommand(state: State, cmd: Command): Task[State] = cmd match { case req: Request => 
ZIO.succeed(state.addRequest(req)) - case cmd @ Command.Commit(_, _) => - doCommit(cmd).as(state.addCommit(cmd)) case cmd @ Command.ChangeSubscription(subscription, _) => - handleChangeSubscription(cmd).flatMap { newAssignedStreams => - val newState = state.copy( - assignedStreams = state.assignedStreams ++ newAssignedStreams, - subscription = subscription - ) - if (subscription.isDefined) ZIO.succeed(newState) - else { - // End all streams and pending requests - endRevokedPartitions( - newState.pendingRequests, - newState.assignedStreams, - isRevoked = _ => true - ).map { revokeResult => - newState.copy( - pendingRequests = revokeResult.pendingRequests, - assignedStreams = revokeResult.assignedStreams - ) - } + handleChangeSubscription(state, cmd).map { newAssignedStreams => + if (subscription.isDefined) { + state.copy( + assignedStreams = state.assignedStreams ++ newAssignedStreams, + subscription = subscription + ) + } else { + state.copy( + pendingRequests = Chunk.empty, + pendingCommits = Chunk.empty, + assignedStreams = Chunk.empty, + subscription = None + ) } } .tapBoth(e => cmd.fail(e), _ => cmd.succeed) .uninterruptible case Command.StopAllStreams => - { - for { - _ <- ZIO.logDebug("Graceful shutdown") - _ <- ZIO.foreachDiscard(state.assignedStreams)(_.end()) - _ <- partitions.offer(Take.end) - _ <- ZIO.logTrace("Graceful shutdown initiated") - } yield () - }.as(state.copy(pendingRequests = Chunk.empty)) + // End all streams. Since we're waiting for the stream to end, there should be no pending commits. + for { + _ <- ZIO.logDebug("Graceful shutdown") + _ <- consumer.runloopAccess { c => + endStreams(state, c, state.assignedStreams, awaitStreamCommits = false) + } + _ <- partitions.offer(Take.end) + _ <- ZIO.logTrace("Graceful shutdown done") + } yield state.copy( + pendingRequests = Chunk.empty, + pendingCommits = Chunk.empty, + assignedStreams = Chunk.empty + ) case _: Command.Control => ZIO.succeed(state) } @@ -410,25 +563,31 @@ private[consumer] final class Runloop private ( * any created streams */ private def handleChangeSubscription( + state: State, command: Command.ChangeSubscription ): Task[Chunk[PartitionStreamControl]] = consumer.runloopAccess { c => command.subscription match { case None => - ZIO - .attempt(c.unsubscribe()) - .as(Chunk.empty) + // We assume that the invoker of this method will clear the state. This allows us to + // ignore whatever happens to the state while in unsubscribe (callback will be called). 
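+ // During unsubscribe the broker revokes all assigned partitions, so the rebalance listener above ends the
+ // revoked streams (and, when `rebalanceSafeCommits` is enabled, first waits for their commits).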
+ for { + _ <- ZIO.logDebug(s"Unsubscribing, storing state: $state") + _ <- currentState.set(state) + _ <- ZIO.attempt(c.unsubscribe()) + _ <- ZIO.logTrace("Unsubscribing done") + } yield Chunk.empty case Some(subscription) => subscription match { case Subscription.Pattern(pattern) => val rc = RebalanceConsumer.Live(c) ZIO - .attempt(c.subscribe(pattern.pattern, rebalanceListener.toKafka(runtime, rc))) + .attempt(c.subscribe(pattern.pattern, rebalanceListener.toKafka(sameThreadRuntime, rc))) .as(Chunk.empty) case Subscription.Topics(topics) => val rc = RebalanceConsumer.Live(c) ZIO - .attempt(c.subscribe(topics.asJava, rebalanceListener.toKafka(runtime, rc))) + .attempt(c.subscribe(topics.asJava, rebalanceListener.toKafka(sameThreadRuntime, rc))) .as(Chunk.empty) case Subscription.Manual(topicPartitions) => // For manual subscriptions we have to do some manual work before starting the run loop @@ -455,6 +614,7 @@ private[consumer] final class Runloop private ( * - Process all currently queued commands before polling instead of one by one * - Immediately after polling, if there are available commands, process them instead of waiting until some periodic * trigger + * - Process the commitQueue (otherwise it would fill up when there are no rebalances) * - Poll only when subscribed (leads to exceptions from the Apache Kafka Consumer if not) * - Poll continuously when there are (still) unfulfilled requests or pending commits * - Poll periodically when we are subscribed but do not have assigned streams yet. This happens after @@ -467,8 +627,10 @@ private[consumer] final class Runloop private ( .takeWhile(_ != StopRunloop) .runFoldChunksDiscardZIO(State.initial) { (state, commands) => for { - _ <- ZIO.logTrace(s"Processing ${commands.size} commands: ${commands.mkString(",")}") - stateAfterCommands <- ZIO.foldLeft(commands)(state)(handleCommand) + commits <- commitQueue.takeAll + _ <- ZIO.logTrace(s"Processing ${commits.size} commits, ${commands.size} commands: ${commands.mkString(",")}") + stateAfterCommits <- handleCommits(state, commits) + stateAfterCommands <- ZIO.foldLeft(commands)(stateAfterCommits)(handleCommand) updatedStateAfterPoll <- if (stateAfterCommands.shouldPoll) handlePoll(stateAfterCommands) else ZIO.succeed(stateAfterCommands) @@ -504,31 +666,61 @@ private[consumer] object Runloop { type ByteArrayCommittableRecord = CommittableRecord[Array[Byte], Array[Byte]] // Internal parameters, should not be necessary to tune - private val commandQueueSize = 1024 + private val CommandQueueSize = 1024 + private val CommitQueueSize = 1024 private final case class PollResult( + newCommits: Chunk[Commit], startingTps: Set[TopicPartition], - pendingRequests: Chunk[Request], - assignedStreams: Chunk[PartitionStreamControl], records: ConsumerRecords[Array[Byte], Array[Byte]], - ignoreRecordsForTps: Set[TopicPartition] - ) - private final case class RevokeResult( - pendingRequests: Chunk[Request], - assignedStreams: Chunk[PartitionStreamControl] + ignoreRecordsForTps: Set[TopicPartition], + endedStreams: Chunk[PartitionStreamControl] ) + private object PollResult { + def apply(records: ConsumerRecords[Array[Byte], Array[Byte]]): PollResult = + PollResult( + newCommits = Chunk.empty, + startingTps = Set.empty, + records = records, + ignoreRecordsForTps = Set.empty, + endedStreams = Chunk.empty + ) + } + private final case class FulfillResult( pendingRequests: Chunk[Request] ) - private sealed trait RebalanceEvent + private final case class RebalanceEvent( + wasInvoked: Boolean, + newlyAssigned: 
Set[TopicPartition], + pendingCommits: Chunk[Commit], + endedStreams: Chunk[PartitionStreamControl] + ) { + def onAssigned( + assigned: Set[TopicPartition], + commits: Chunk[Commit], + streamsToEnd: Chunk[PartitionStreamControl] + ): RebalanceEvent = + RebalanceEvent( + wasInvoked = true, + newlyAssigned = newlyAssigned ++ assigned, + pendingCommits = pendingCommits ++ commits, + endedStreams = endedStreams ++ streamsToEnd + ) + def onRevokedOrLost( + commits: Chunk[Commit], + streamsToEnd: Chunk[PartitionStreamControl] + ): RebalanceEvent = + copy( + wasInvoked = true, + pendingCommits = pendingCommits ++ commits, + endedStreams = endedStreams ++ streamsToEnd + ) + } + private object RebalanceEvent { - final case class Revoked(revokeResult: Runloop.RevokeResult) extends RebalanceEvent - final case class Assigned(newlyAssigned: Set[TopicPartition]) extends RebalanceEvent - final case class RevokedAndAssigned( - revokeResult: Runloop.RevokeResult, - newlyAssigned: Set[TopicPartition] - ) extends RebalanceEvent + val None: RebalanceEvent = RebalanceEvent(wasInvoked = false, Set.empty, Chunk.empty, Chunk.empty) } sealed trait Command @@ -540,14 +732,12 @@ private[consumer] object Runloop { /** Used as a signal that another poll is needed. */ case object Poll extends Control + /** Used as a signal to the poll-loop that commits are available in the commit-queue. */ + case object CommitAvailable extends Control + case object StopRunloop extends Control case object StopAllStreams extends Control - final case class Commit(offsets: Map[TopicPartition, Long], cont: Promise[Throwable, Unit]) extends Command { - @inline def isDone: UIO[Boolean] = cont.isDone - @inline def isPending: UIO[Boolean] = isDone.negate - } - /** Used by a stream to request more records. 
*/ final case class Request(tp: TopicPartition) extends Command @@ -560,6 +750,14 @@ private[consumer] object Runloop { } } + final case class Commit( + offsets: Map[TopicPartition, Long], + cont: Promise[Throwable, Unit] + ) { + @inline def isDone: UIO[Boolean] = cont.isDone + @inline def isPending: UIO[Boolean] = isDone.negate + } + def apply( hasGroupId: Boolean, consumer: ConsumerAccess, @@ -568,32 +766,36 @@ private[consumer] object Runloop { offsetRetrieval: OffsetRetrieval, userRebalanceListener: RebalanceListener, restartStreamsOnRebalancing: Boolean, + rebalanceSafeCommits: Boolean, runloopTimeout: Duration ): ZIO[Scope, Throwable, Runloop] = for { - commandQueue <- ZIO.acquireRelease(Queue.bounded[Runloop.Command](commandQueueSize))(_.shutdown) - lastRebalanceEvent <- Ref.Synchronized.make[Option[Runloop.RebalanceEvent]](None) + commandQueue <- ZIO.acquireRelease(Queue.bounded[Runloop.Command](CommandQueueSize))(_.shutdown) + commitQueue <- ZIO.acquireRelease(Queue.bounded[Runloop.Commit](CommitQueueSize))(_.shutdown) + rebalanceListenerEvent <- Ref.make[RebalanceEvent](RebalanceEvent.None) partitions <- ZIO.acquireRelease( Queue .unbounded[ Take[Throwable, (TopicPartition, Stream[Throwable, ByteArrayCommittableRecord])] ] )(_.shutdown) - currentStateRef <- Ref.make(State.initial) - runtime <- ZIO.runtime[Any] + currentStateRef <- Ref.make(State.initial) + sameThreadRuntime <- ZIO.runtime[Any].provideLayer(SameThreadRuntimeLayer) runloop = new Runloop( - runtime, + sameThreadRuntime, hasGroupId, consumer, pollTimeout, runloopTimeout, commandQueue, - lastRebalanceEvent, + commitQueue, + rebalanceListenerEvent, partitions, diagnostics, offsetRetrieval, userRebalanceListener, restartStreamsOnRebalancing, + rebalanceSafeCommits, currentStateRef ) _ <- ZIO.logDebug("Starting Runloop") @@ -615,11 +817,12 @@ private[consumer] object Runloop { private[internal] final case class State( pendingRequests: Chunk[Request], pendingCommits: Chunk[Commit], + committedOffsets: Map[TopicPartition, Long], assignedStreams: Chunk[PartitionStreamControl], subscription: Option[Subscription] ) { - def addCommit(c: Commit): State = copy(pendingCommits = pendingCommits :+ c) - def addRequest(r: Request): State = copy(pendingRequests = pendingRequests :+ r) + def addCommits(c: Chunk[Commit]): State = copy(pendingCommits = pendingCommits ++ c) + def addRequest(r: Request): State = copy(pendingRequests = pendingRequests :+ r) def isSubscribed: Boolean = subscription.isDefined @@ -631,6 +834,7 @@ object State { val initial: State = State( pendingRequests = Chunk.empty, pendingCommits = Chunk.empty, + committedOffsets = Map.empty, assignedStreams = Chunk.empty, subscription = None ) diff --git a/zio-kafka/src/main/scala/zio/kafka/consumer/internal/package.scala b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/package.scala new file mode 100644 index 0000000000..aec2309db7 --- /dev/null +++ b/zio-kafka/src/main/scala/zio/kafka/consumer/internal/package.scala @@ -0,0 +1,28 @@ +package zio.kafka.consumer + +import zio._ +import zio.internal.ExecutionMetrics + +package object internal { + + /** + * A runtime layer that can be used to run everything on the thread of the caller. + * + * Provided by Adam Fraser in Discord: + * https://discord.com/channels/629491597070827530/630498701860929559/1094279123880386590 but with cooperative + * yielding enabled. 
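+ *
+ * Intended use (a sketch of how `Runloop.apply` derives its runtime from this layer):
+ * {{{
+ *   sameThreadRuntime <- ZIO.runtime[Any].provideLayer(SameThreadRuntimeLayer)
+ * }}}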
+ */ + private[internal] val SameThreadRuntimeLayer: ZLayer[Any, Nothing, Unit] = { + val sameThreadExecutor = new Executor() { + override def metrics(implicit unsafe: Unsafe): Option[ExecutionMetrics] = None + + override def submit(runnable: Runnable)(implicit unsafe: Unsafe): Boolean = { + runnable.run() + true + } + } + + Runtime.setExecutor(sameThreadExecutor) ++ Runtime.setBlockingExecutor(sameThreadExecutor) + } + +}
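For completeness, switching the new behaviour on from user code is a single extra call on the settings builder. An illustrative sketch (broker address and group id are placeholders; `withRebalanceSafeCommits` is the setter added in this change):

  val settings: ConsumerSettings =
    ConsumerSettings(List("localhost:9092"))
      .withGroupId("example-group")
      .withRebalanceSafeCommits(true) // wait for commits of revoked partitions during a rebalance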