Skip to content

Commit

Permalink
Postpone removing records until assignment is made.
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcin Kustra committed Aug 18, 2022
1 parent cf9bff2 commit d61250d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 34 deletions.
3 changes: 1 addition & 2 deletions modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -574,9 +574,8 @@ object KafkaConsumer {
partitions.toList.asJava
)
} >> actor.ref
.updateAndGet(_.asSubscribed)
.updateAndGet(_.asSubscribed.withAssignments(partitions.toSortedSet))
.log(LogEntry.ManuallyAssignedPartitions(partitions, _))

}

override def assign(topic: String): F[Unit] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,36 @@ private[kafka] final class KafkaConsumerActor[F[_]](
res.start.void
}

private[this] def assigned(assigned: SortedSet[TopicPartition]): F[Unit] =
ref
.updateAndGet(_.withRebalancing(false))
.flatMap { state =>
log(AssignedPartitions(assigned, state)) >>
state.onRebalances.foldLeft(F.unit)(_ >> _.onAssigned(assigned))
private[this] def assigned(assigned: SortedSet[TopicPartition]): F[Unit] = {
def withState[A] = StateT.apply[Id, State[F], A](_)

val removeRevokedRecords = withState { st =>
val assignments = st.assignments
val revokedRecords = st.records.filterKeysStrict(!assignments(_))

if (revokedRecords.nonEmpty) {
val newState = st.withoutRecords(revokedRecords.keySet)

val action = logging.log(RemovedRevokedRecords(revokedRecords, newState))

(newState, action)
} else (st, F.unit)
}

for {
action <- ref.modify { state =>
removeRevokedRecords.run(
state
.withRebalancing(false)
.withAssignments(assigned)
)
}
updatedState <- ref.get
_ <- action >>
log(AssignedPartitions(assigned, updatedState)) >>
updatedState.onRebalances.foldLeft(F.unit)(_ >> _.onAssigned(assigned))
} yield ()
}

private[this] def revoked(revoked: SortedSet[TopicPartition]): F[Unit] = {
def withState[A] = StateT.apply[Id, State[F], A](_)
Expand Down Expand Up @@ -170,44 +193,29 @@ private[kafka] final class KafkaConsumerActor[F[_]](
} else (st, F.unit)
}

def removeRevokedRecords(revokedNonFetches: SortedSet[TopicPartition]) = withState { st =>
if (revokedNonFetches.nonEmpty) {
val revokedRecords = st.records.filterKeysStrict(revokedNonFetches)

if (revokedRecords.nonEmpty) {
val newState = st.withoutRecords(revokedRecords.keySet)

val action = logging.log(RemovedRevokedRecords(revokedRecords, newState))

(newState, action)
} else (st, F.unit)
} else (st, F.unit)
}

ref
.modify { state =>
val withRebalancing = state.withRebalancing(true)
val updatedState = state
.withRebalancing(true)
.withoutAssignments(revoked)

val fetches = withRebalancing.fetches.keySetStrict
val records = withRebalancing.records.keySetStrict
val fetches = updatedState.fetches.keySetStrict
val records = updatedState.records.keySetStrict

val revokedFetches = revoked intersect fetches
val revokedNonFetches = revoked diff revokedFetches

val withRecords = records intersect revokedFetches
val withoutRecords = revokedFetches diff records

(for {
completeWithRecords <- completeWithRecords(withRecords)
completeWithoutRecords <- completeWithoutRecords(withoutRecords)
removeRevokedRecords <- removeRevokedRecords(revokedNonFetches)
} yield RevokedResult(
logRevoked = logging.log(RevokedPartitions(revoked, withRebalancing)),
logRevoked = logging.log(RevokedPartitions(revoked, updatedState)),
completeWithRecords = completeWithRecords,
completeWithoutRecords = completeWithoutRecords,
removeRevokedRecords = removeRevokedRecords,
onRebalances = withRebalancing.onRebalances
)).run(withRebalancing)
onRebalances = updatedState.onRebalances
)).run(updatedState)
}
.flatMap { res =>
val onRevoked =
Expand All @@ -216,7 +224,6 @@ private[kafka] final class KafkaConsumerActor[F[_]](
res.logRevoked >>
res.completeWithRecords >>
res.completeWithoutRecords >>
res.removeRevokedRecords >>
onRevoked
}
}
Expand Down Expand Up @@ -246,7 +253,7 @@ private[kafka] final class KafkaConsumerActor[F[_]](
def pollConsumer(state: State[F]): F[ConsumerRecords] =
withConsumer
.blocking { consumer =>
val assigned = consumer.assignment.toSet
val assigned = state.assignments
val requested = state.fetches.keySetStrict
val available = state.records.keySetStrict

Expand Down Expand Up @@ -396,7 +403,6 @@ private[kafka] final class KafkaConsumerActor[F[_]](
logRevoked: F[Unit],
completeWithRecords: F[Unit],
completeWithoutRecords: F[Unit],
removeRevokedRecords: F[Unit],
onRebalances: Chain[OnRebalance[F]]
)

Expand Down Expand Up @@ -457,6 +463,7 @@ private[kafka] object KafkaConsumerActor {
type StreamId = Int

final case class State[F[_]](
assignments: Set[TopicPartition],
fetches: Map[TopicPartition, Map[StreamId, FetchRequest[F]]],
records: Map[TopicPartition, NonEmptyVector[KafkaByteConsumerRecord]],
pendingCommits: Chain[Request.Commit[F]],
Expand All @@ -465,6 +472,23 @@ private[kafka] object KafkaConsumerActor {
subscribed: Boolean,
streaming: Boolean
) {

/** Add new assignments to state.
*
* @param assignments assignments to add
* @return updated state with assignments added
*/
def withAssignments(assignments: Set[TopicPartition]): State[F] =
copy(assignments = this.assignments ++ assignments)

/** Remove assignments to state.
*
* @param assignments assignments to remove
* @return updated state with assignments removed
*/
def withoutAssignments(assignments: Set[TopicPartition]): State[F] =
copy(assignments = this.assignments diff assignments)

def withOnRebalance(onRebalance: OnRebalance[F]): State[F] =
copy(onRebalances = onRebalances append onRebalance)

Expand Down Expand Up @@ -549,6 +573,7 @@ private[kafka] object KafkaConsumerActor {
object State {
def empty[F[_]]: State[F] =
State(
assignments = Set.empty,
fetches = Map.empty,
records = Map.empty,
pendingCommits = Chain.empty,
Expand Down

0 comments on commit d61250d

Please sign in to comment.