diff --git a/delta/sourcing-psql/src/main/scala/ch/epfl/bluebrain/nexus/delta/sourcing/stream/Projection.scala b/delta/sourcing-psql/src/main/scala/ch/epfl/bluebrain/nexus/delta/sourcing/stream/Projection.scala index a53e75bd59..38ac0bb631 100644 --- a/delta/sourcing-psql/src/main/scala/ch/epfl/bluebrain/nexus/delta/sourcing/stream/Projection.scala +++ b/delta/sourcing-psql/src/main/scala/ch/epfl/bluebrain/nexus/delta/sourcing/stream/Projection.scala @@ -1,17 +1,16 @@ package ch.epfl.bluebrain.nexus.delta.sourcing.stream -import cats.effect.{ContextShift, ExitCase, Fiber, IO, Timer} import cats.effect.concurrent.Ref -import cats.implicits.catsSyntaxFlatMapOps +import cats.effect.{ContextShift, ExitCase, Fiber, IO, Timer} +import cats.implicits._ +import cats.implicits.{catsSyntaxFlatMapIdOps, catsSyntaxFlatMapOps, catsSyntaxMonad} +import ch.epfl.bluebrain.nexus.delta.kernel.Logger import ch.epfl.bluebrain.nexus.delta.sourcing.config.BatchConfig import ch.epfl.bluebrain.nexus.delta.sourcing.model.ElemPipe import ch.epfl.bluebrain.nexus.delta.sourcing.stream.Elem.FailedElem -import fs2.concurrent.SignallingRef -import cats.implicits._ -import ch.epfl.bluebrain.nexus.delta.kernel.Logger import ch.epfl.bluebrain.nexus.delta.sourcing.stream.Projection.logger +import fs2.concurrent.SignallingRef -import java.util.concurrent.TimeoutException import scala.concurrent.duration.FiniteDuration /** @@ -60,18 +59,32 @@ final class Projection private[stream] ( * the maximum time expected for the projection to complete * @return */ + def waitForCompletion(timeout: FiniteDuration)(implicit timer: Timer[IO], cs: ContextShift[IO]): IO[ExecutionStatus] = - executionStatus - .iterateUntil { - case ExecutionStatus.Completed => true - case ExecutionStatus.Failed(_) => true - case ExecutionStatus.Stopped => true - case _ => false - } - .timeout(timeout) - .recoverWith { case _: TimeoutException => - logger.error(s"Timeout waiting for completion on projection $name") >> executionStatus + iterateUntilCompletion + .timeoutTo(timeout, logger.error(s"Timeout waiting for completion on projection $name") >> executionStatus) + + private def statusMeansStopped(executionStatus: ExecutionStatus): Boolean = { + executionStatus match { + case ExecutionStatus.Completed => true + case ExecutionStatus.Failed(_) => true + case ExecutionStatus.Stopped => true + case _ => false + } + } + + private def iterateUntilCompletion(implicit cs: ContextShift[IO]): IO[ExecutionStatus] = { + (for { + status <- executionStatus + _ <- cs.shift + } yield status).flatMap { status => + if (statusMeansStopped(status)) { + IO.pure(status) + } else { + iterateUntilCompletion } + } + } /** * Stops the projection. Has no effect if the projection is already stopped.