diff --git a/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/Mdc.scala b/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/Mdc.scala index 4b939335..52265f25 100644 --- a/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/Mdc.scala +++ b/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/Mdc.scala @@ -28,11 +28,19 @@ object Mdc { def withMdc[A](block: => Future[A], mdcData: Map[String, String])(implicit ec: ExecutionContext): Future[A] = block.map { a => - mdcData.foreach { - case (k, v) => MDC.put(k, v) - } + putMdc(mdcData) a - }(ec) + }.recover { + case t => + putMdc(mdcData) + throw t + } + + private def putMdc(mdc: Map[String, String]): Unit = { + mdc.foreach { + case (k, v) => MDC.put(k, v) + } + } /** Restores MDC data to the continuation of a block, which may be discarding MDC data (e.g. uses a different execution context) */ diff --git a/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/MdcLoggingExecutionContext.scala b/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/MdcLoggingExecutionContext.scala index fd42ebc3..d92141e6 100644 --- a/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/MdcLoggingExecutionContext.scala +++ b/http-verbs-common/src/main/scala/uk/gov/hmrc/play/http/logging/MdcLoggingExecutionContext.scala @@ -20,7 +20,7 @@ import org.slf4j.MDC import scala.concurrent.ExecutionContext -class MdcLoggingExecutionContext(wrapped: ExecutionContext, mdcData: Map[String, String]) +class MdcLoggingExecutionContext(wrapped: ExecutionContext, mdcData: Map[String, String]) extends ExecutionContext { def execute(runnable: Runnable) { diff --git a/http-verbs-common/src/test/scala/uk/gov/hmrc/http/logging/MdcSpec.scala b/http-verbs-common/src/test/scala/uk/gov/hmrc/http/logging/MdcSpec.scala new file mode 100644 index 00000000..e178ba44 --- /dev/null +++ b/http-verbs-common/src/test/scala/uk/gov/hmrc/http/logging/MdcSpec.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2021 HM Revenue & Customs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package uk.gov.hmrc.http.logging + +import akka.dispatch.ExecutorServiceDelegate +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.{IntegrationPatience, ScalaFutures} +import org.scalatest.matchers.should.Matchers +import org.scalatest.wordspec.AnyWordSpecLike +import org.scalatestplus.mockito.MockitoSugar +import org.slf4j.MDC +import uk.gov.hmrc.play.http.logging.Mdc + +import java.util.concurrent.{ExecutorService, Executors} +import scala.concurrent.duration.DurationInt +import scala.concurrent.{ExecutionContext, Future} + +class MdcSpec extends AnyWordSpecLike with Matchers with MockitoSugar with ScalaFutures with BeforeAndAfter with IntegrationPatience { + + before { + MDC.clear() + } + + "mdcData" should { + "return a Scala Map" in { + MDC.put("something1", "something2") + Mdc.mdcData shouldBe Map("something1" -> "something2") + } + } + + "Preserving MDC" should { + "show that MDC is lost when switching contexts" in { + implicit val mdcEc: ExecutionContext = mdcPropagatingExecutionContext() + + (for { + _ <- Future.successful(org.slf4j.MDC.put("k", "v")) + _ <- runActionWhichLosesMdc() + } yield + Option(MDC.get("k")) + ).futureValue shouldBe None + } + + "restore MDC" in { + implicit val mdcEc: ExecutionContext = mdcPropagatingExecutionContext() + + (for { + _ <- Future.successful(org.slf4j.MDC.put("k", "v")) + _ <- Mdc.preservingMdc(runActionWhichLosesMdc()) + } yield + Option(MDC.get("k")) + ).futureValue shouldBe Some("v") + } + + "restore MDC when exception is thrown" in { + implicit val mdcEc: ExecutionContext = mdcPropagatingExecutionContext() + + (for { + _ <- Future.successful(org.slf4j.MDC.put("k", "v")) + _ <- Mdc.preservingMdc(runActionWhichLosesMdc(fail = true)) + } yield () + ) + .recover { case _ => + Option(MDC.get("k")) + }.futureValue shouldBe Some("v") + } + } + + private def runActionWhichLosesMdc(fail: Boolean = false): Future[Any] = { + val as = akka.actor.ActorSystem("different-as") + akka.pattern.after(1.second, as.scheduler)(Future(())(as.dispatcher))(as.dispatcher) + .map(a => if (fail) sys.error("expected test exception") else a)(as.dispatcher) + } + + private def mdcPropagatingExecutionContext() = + ExecutionContext.fromExecutor(new MDCPropagatingExecutorService(Executors.newFixedThreadPool(2))) + +} + +// This class is copied from bootstrap-play. +// There is a ticket in the backlog to consider extracting it neatly. For now, it is needed for this test. +class MDCPropagatingExecutorService(val executor: ExecutorService) extends ExecutorServiceDelegate { + + override def execute(command: Runnable): Unit = { + + val mdcData = MDC.getCopyOfContextMap + + executor.execute(new Runnable { + override def run(): Unit = { + val oldMdcData = MDC.getCopyOfContextMap + setMDC(mdcData) + try { + command.run() + } finally { + setMDC(oldMdcData) + } + } + }) + } + + private def setMDC(context: java.util.Map[String, String]): Unit = { + if (context == null) { + MDC.clear() + } else { + MDC.setContextMap(context) + } + } +}