From 21d5324ad58248e9e9ffd2e5e48bdc3ad1671b84 Mon Sep 17 00:00:00 2001 From: kyri-petrou <67301607+kyri-petrou@users.noreply.github.com> Date: Tue, 14 May 2024 11:29:39 +1000 Subject: [PATCH] [ZClient] Unsafely fulfil promises to avoid race conditions (#2843) Unsafely fulfill Client promises to avoid race conditions --- .../scala/zio/http/netty/NettyResponse.scala | 16 +++--- .../netty/client/ClientFailureHandler.scala | 13 ++--- .../netty/client/ClientInboundHandler.scala | 14 ++---- .../client/ClientResponseStreamHandler.scala | 28 +++++------ .../http/netty/client/NettyClientDriver.scala | 49 +++++++++---------- .../WebSocketClientInboundHandler.scala | 20 +++----- 6 files changed, 55 insertions(+), 85 deletions(-) diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala index 3390125a5b..c8e7dffd53 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala @@ -17,7 +17,7 @@ package zio.http.netty import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Promise, Trace, Unsafe, ZIO} +import zio.{Exit, Promise, Trace, Unsafe, ZIO} import zio.http.internal.ChannelState import zio.http.netty.client.ClientResponseStreamHandler @@ -42,25 +42,21 @@ object NettyResponse { def make( ctx: ChannelHandlerContext, jRes: HttpResponse, - zExec: NettyRuntime, onComplete: Promise[Throwable, ChannelState], keepAlive: Boolean, )(implicit unsafe: Unsafe, trace: Trace, - ): ZIO[Any, Nothing, Response] = { + ): Response = { val status = Conversions.statusFromNetty(jRes.status()) val headers = Conversions.headersFromNetty(jRes.headers()) val knownContentLength = headers.get(Header.ContentLength).map(_.length) if (knownContentLength.contains(0L)) { - onComplete - .succeed(ChannelState.forStatus(status)) - .as( - Response(status, headers, Body.empty), - ) + onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status))) + Response(status, headers, Body.empty) } else { - val responseHandler = new ClientResponseStreamHandler(zExec, onComplete, keepAlive, status) + val responseHandler = new ClientResponseStreamHandler(onComplete, keepAlive, status) ctx .pipeline() .addAfter( @@ -70,7 +66,7 @@ object NettyResponse { ): Unit val data = NettyBody.fromAsync(callback => responseHandler.connect(callback), knownContentLength) - ZIO.succeed(Response(status, headers, data)) + Response(status, headers, data) } } } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientFailureHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientFailureHandler.scala index b9f7de9c75..b11a0cdf55 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientFailureHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientFailureHandler.scala @@ -1,25 +1,22 @@ package zio.http.netty.client -import zio.{Promise, Trace, Unsafe} +import zio.{Exit, Promise, Unsafe} import zio.http.Response import zio.http.internal.ChannelState -import zio.http.netty.NettyRuntime import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} /** Handles failures happening in ClientInboundHandler */ final class ClientFailureHandler( - rtm: NettyRuntime, onResponse: Promise[Throwable, Response], onComplete: Promise[Throwable, ChannelState], -)(implicit trace: Trace) - extends ChannelInboundHandlerAdapter { +) extends ChannelInboundHandlerAdapter { implicit private val unsafeClass: Unsafe = Unsafe.unsafe override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)( - onResponse.fail(cause) *> onComplete.fail(cause), - )(unsafeClass, trace) + val exit = Exit.fail(cause) + onResponse.unsafe.done(exit) + onComplete.unsafe.done(exit) } } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala index 8ad072d644..52424cc925 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientInboundHandler.scala @@ -65,17 +65,9 @@ final class ClientInboundHandler( override def channelRead0(ctx: ChannelHandlerContext, msg: HttpObject): Unit = { msg match { case response: HttpResponse => - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring) { - NettyResponse - .make( - ctx, - response, - rtm, - onComplete, - enableKeepAlive && HttpUtil.isKeepAlive(response), - ) - .flatMap(onResponse.succeed) - }(unsafeClass, trace) + val keepAlive = enableKeepAlive && HttpUtil.isKeepAlive(response) + val resp = NettyResponse.make(ctx, response, onComplete, keepAlive) + onResponse.unsafe.done(Exit.succeed(resp)) case content: HttpContent => ctx.fireChannelRead(content): Unit diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala index f483bfe608..75a71cfb15 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/ClientResponseStreamHandler.scala @@ -17,44 +17,38 @@ package zio.http.netty.client import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Promise, Trace} +import zio.{Exit, Promise, Trace, Unsafe} import zio.http.Status import zio.http.internal.ChannelState -import zio.http.netty.{AsyncBodyReader, NettyFutureExecutor, NettyRuntime} +import zio.http.netty.AsyncBodyReader import io.netty.channel._ import io.netty.handler.codec.http.{HttpContent, LastHttpContent} + final class ClientResponseStreamHandler( - rtm: NettyRuntime, onComplete: Promise[Throwable, ChannelState], keepAlive: Boolean, status: Status, )(implicit trace: Trace) extends AsyncBodyReader { self => - override def channelRead0( - ctx: ChannelHandlerContext, - msg: HttpContent, - ): Unit = { + private implicit val unsafe: Unsafe = Unsafe.unsafe + + override def channelRead0(ctx: ChannelHandlerContext, msg: HttpContent): Unit = { val isLast = msg.isInstanceOf[LastHttpContent] super.channelRead0(ctx, msg) if (isLast) { if (keepAlive) - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(onComplete.succeed(ChannelState.forStatus(status)))( - unsafeClass, - trace, - ) + onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status))) else { - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)( - onComplete.succeed(ChannelState.Invalid) *> NettyFutureExecutor.executed(ctx.close()), - )(unsafeClass, trace) + onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid)) + ctx.close() } } } - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(onComplete.fail(cause))(unsafeClass, trace) - } + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = + onComplete.unsafe.done(Exit.fail(cause)) } diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala index 0c6c53998a..7755f05232 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/NettyClientDriver.scala @@ -28,7 +28,7 @@ import zio.http.netty._ import zio.http.netty.model.Conversions import zio.http.netty.socket.NettySocketProtocol -import io.netty.channel.{Channel, ChannelFactory, ChannelHandler, EventLoopGroup} +import io.netty.channel.{Channel, ChannelFactory, ChannelFuture, ChannelHandler, EventLoopGroup} import io.netty.handler.codec.PrematureChannelClosureException import io.netty.handler.codec.http.websocketx.{WebSocketClientProtocolHandler, WebSocketFrame => JWebSocketFrame} import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator} @@ -75,7 +75,7 @@ final case class NettyClientDriver private[netty] ( if (location.scheme.isWebSocket) { val httpObjectAggregator = new HttpObjectAggregator(Int.MaxValue) - val inboundHandler = new WebSocketClientInboundHandler(nettyRuntime, onResponse, onComplete) + val inboundHandler = new WebSocketClientInboundHandler(onResponse, onComplete) pipeline.addLast(Names.HttpObjectAggregator, httpObjectAggregator) pipeline.addLast(Names.ClientInboundHandler, inboundHandler) @@ -127,12 +127,7 @@ final case class NettyClientDriver private[netty] ( pipeline.addLast(Names.ClientInboundHandler, clientInbound) toRemove.add(clientInbound) - val clientFailureHandler = - new ClientFailureHandler( - nettyRuntime, - onResponse, - onComplete, - ) + val clientFailureHandler = new ClientFailureHandler(onResponse, onComplete) pipeline.addLast(Names.ClientFailureHandler, clientFailureHandler) toRemove.add(clientFailureHandler) @@ -155,26 +150,28 @@ final case class NettyClientDriver private[netty] ( } } - f.ensuring( - ZIO - .unless(location.scheme.isWebSocket) { - // If the channel was closed and the promises were not completed, this will lead to the request hanging so we need - // to listen to the close future and complete the promises - NettyFutureExecutor - .executed(channel.closeFuture()) - .interruptible - .zipRight( - // If onComplete was already set, it means another fiber is already in the process of fulfilling the promises - // so we don't need to fulfill `onResponse` - onComplete.interrupt && onResponse.fail( - new PrematureChannelClosureException( - "Channel closed while executing the request. This is likely caused due to a client connection misconfiguration", + f.ensuring { + // If the channel was closed and the promises were not completed, this will lead to the request hanging so we need + // to listen to the close future and complete the promises + ZIO.unless(location.scheme.isWebSocket) { + ZIO.succeedUnsafe { implicit u => + channel.closeFuture().addListener { (_: ChannelFuture) => + // If onComplete was already set, it means another fiber is already in the process of fulfilling the promises + // so we don't need to fulfill `onResponse` + nettyRuntime.unsafeRunSync { + ZIO.whenZIO(onComplete.interrupt)( + onResponse.fail( + new PrematureChannelClosureException( + "Channel closed while executing the request. This is likely caused due to a client connection misconfiguration", + ), ), - ), - ) + ) + } + } } - .forkScoped, - ) + } + } + } override def createConnectionPool(dnsResolver: DnsResolver, config: ConnectionPoolConfig)(implicit diff --git a/zio-http/jvm/src/main/scala/zio/http/netty/client/WebSocketClientInboundHandler.scala b/zio-http/jvm/src/main/scala/zio/http/netty/client/WebSocketClientInboundHandler.scala index e1e45f731d..f674b4934a 100644 --- a/zio-http/jvm/src/main/scala/zio/http/netty/client/WebSocketClientInboundHandler.scala +++ b/zio-http/jvm/src/main/scala/zio/http/netty/client/WebSocketClientInboundHandler.scala @@ -17,7 +17,7 @@ package zio.http.netty.client import zio.stacktracer.TracingImplicits.disableAutoTrace -import zio.{Promise, Trace, Unsafe} +import zio.{Exit, Promise, Trace, Unsafe} import zio.http.Response import zio.http.internal.ChannelState @@ -27,29 +27,23 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.handler.codec.http.FullHttpResponse final class WebSocketClientInboundHandler( - rtm: NettyRuntime, onResponse: Promise[Throwable, Response], onComplete: Promise[Throwable, ChannelState], -)(implicit trace: Trace) - extends SimpleChannelInboundHandler[FullHttpResponse](true) { +) extends SimpleChannelInboundHandler[FullHttpResponse](true) { implicit private val unsafeClass: Unsafe = Unsafe.unsafe - override def channelActive(ctx: ChannelHandlerContext): Unit = { + override def channelActive(ctx: ChannelHandlerContext): Unit = ctx.fireChannelActive() - } override def channelRead0(ctx: ChannelHandlerContext, msg: FullHttpResponse): Unit = { - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring) { - onResponse.succeed(NettyResponse(msg)) - }(unsafeClass, trace) - + onResponse.unsafe.done(Exit.succeed(NettyResponse(msg))) ctx.fireChannelRead(msg.retain()) ctx.pipeline().remove(ctx.name()): Unit } override def exceptionCaught(ctx: ChannelHandlerContext, error: Throwable): Unit = { - rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)( - onResponse.fail(error) *> onComplete.fail(error), - )(unsafeClass, trace) + val exit = Exit.fail(error) + onResponse.unsafe.done(exit) + onComplete.unsafe.done(exit) } }