Skip to content

Commit

Permalink
[ZClient] Unsafely fulfil promises to avoid race conditions (#2843)
Browse files Browse the repository at this point in the history
Unsafely fulfill Client promises to avoid race conditions
  • Loading branch information
kyri-petrou authored May 14, 2024
1 parent 1822f19 commit 21d5324
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 85 deletions.
16 changes: 6 additions & 10 deletions zio-http/jvm/src/main/scala/zio/http/netty/NettyResponse.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package zio.http.netty

import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Promise, Trace, Unsafe, ZIO}
import zio.{Exit, Promise, Trace, Unsafe, ZIO}

import zio.http.internal.ChannelState
import zio.http.netty.client.ClientResponseStreamHandler
Expand All @@ -42,25 +42,21 @@ object NettyResponse {
def make(
ctx: ChannelHandlerContext,
jRes: HttpResponse,
zExec: NettyRuntime,
onComplete: Promise[Throwable, ChannelState],
keepAlive: Boolean,
)(implicit
unsafe: Unsafe,
trace: Trace,
): ZIO[Any, Nothing, Response] = {
): Response = {
val status = Conversions.statusFromNetty(jRes.status())
val headers = Conversions.headersFromNetty(jRes.headers())
val knownContentLength = headers.get(Header.ContentLength).map(_.length)

if (knownContentLength.contains(0L)) {
onComplete
.succeed(ChannelState.forStatus(status))
.as(
Response(status, headers, Body.empty),
)
onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status)))
Response(status, headers, Body.empty)
} else {
val responseHandler = new ClientResponseStreamHandler(zExec, onComplete, keepAlive, status)
val responseHandler = new ClientResponseStreamHandler(onComplete, keepAlive, status)
ctx
.pipeline()
.addAfter(
Expand All @@ -70,7 +66,7 @@ object NettyResponse {
): Unit

val data = NettyBody.fromAsync(callback => responseHandler.connect(callback), knownContentLength)
ZIO.succeed(Response(status, headers, data))
Response(status, headers, data)
}
}
}
Original file line number Diff line number Diff line change
@@ -1,25 +1,22 @@
package zio.http.netty.client

import zio.{Promise, Trace, Unsafe}
import zio.{Exit, Promise, Unsafe}

import zio.http.Response
import zio.http.internal.ChannelState
import zio.http.netty.NettyRuntime

import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter}

/** Handles failures happening in ClientInboundHandler */
final class ClientFailureHandler(
rtm: NettyRuntime,
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
)(implicit trace: Trace)
extends ChannelInboundHandlerAdapter {
) extends ChannelInboundHandlerAdapter {
implicit private val unsafeClass: Unsafe = Unsafe.unsafe

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(
onResponse.fail(cause) *> onComplete.fail(cause),
)(unsafeClass, trace)
val exit = Exit.fail(cause)
onResponse.unsafe.done(exit)
onComplete.unsafe.done(exit)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,9 @@ final class ClientInboundHandler(
override def channelRead0(ctx: ChannelHandlerContext, msg: HttpObject): Unit = {
msg match {
case response: HttpResponse =>
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring) {
NettyResponse
.make(
ctx,
response,
rtm,
onComplete,
enableKeepAlive && HttpUtil.isKeepAlive(response),
)
.flatMap(onResponse.succeed)
}(unsafeClass, trace)
val keepAlive = enableKeepAlive && HttpUtil.isKeepAlive(response)
val resp = NettyResponse.make(ctx, response, onComplete, keepAlive)
onResponse.unsafe.done(Exit.succeed(resp))
case content: HttpContent =>
ctx.fireChannelRead(content): Unit

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,44 +17,38 @@
package zio.http.netty.client

import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Promise, Trace}
import zio.{Exit, Promise, Trace, Unsafe}

import zio.http.Status
import zio.http.internal.ChannelState
import zio.http.netty.{AsyncBodyReader, NettyFutureExecutor, NettyRuntime}
import zio.http.netty.AsyncBodyReader

import io.netty.channel._
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}

final class ClientResponseStreamHandler(
rtm: NettyRuntime,
onComplete: Promise[Throwable, ChannelState],
keepAlive: Boolean,
status: Status,
)(implicit trace: Trace)
extends AsyncBodyReader { self =>

override def channelRead0(
ctx: ChannelHandlerContext,
msg: HttpContent,
): Unit = {
private implicit val unsafe: Unsafe = Unsafe.unsafe

override def channelRead0(ctx: ChannelHandlerContext, msg: HttpContent): Unit = {
val isLast = msg.isInstanceOf[LastHttpContent]
super.channelRead0(ctx, msg)

if (isLast) {
if (keepAlive)
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(onComplete.succeed(ChannelState.forStatus(status)))(
unsafeClass,
trace,
)
onComplete.unsafe.done(Exit.succeed(ChannelState.forStatus(status)))
else {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(
onComplete.succeed(ChannelState.Invalid) *> NettyFutureExecutor.executed(ctx.close()),
)(unsafeClass, trace)
onComplete.unsafe.done(Exit.succeed(ChannelState.Invalid))
ctx.close()
}
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(onComplete.fail(cause))(unsafeClass, trace)
}
override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit =
onComplete.unsafe.done(Exit.fail(cause))
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import zio.http.netty._
import zio.http.netty.model.Conversions
import zio.http.netty.socket.NettySocketProtocol

import io.netty.channel.{Channel, ChannelFactory, ChannelHandler, EventLoopGroup}
import io.netty.channel.{Channel, ChannelFactory, ChannelFuture, ChannelHandler, EventLoopGroup}
import io.netty.handler.codec.PrematureChannelClosureException
import io.netty.handler.codec.http.websocketx.{WebSocketClientProtocolHandler, WebSocketFrame => JWebSocketFrame}
import io.netty.handler.codec.http.{FullHttpRequest, HttpObjectAggregator}
Expand Down Expand Up @@ -75,7 +75,7 @@ final case class NettyClientDriver private[netty] (

if (location.scheme.isWebSocket) {
val httpObjectAggregator = new HttpObjectAggregator(Int.MaxValue)
val inboundHandler = new WebSocketClientInboundHandler(nettyRuntime, onResponse, onComplete)
val inboundHandler = new WebSocketClientInboundHandler(onResponse, onComplete)

pipeline.addLast(Names.HttpObjectAggregator, httpObjectAggregator)
pipeline.addLast(Names.ClientInboundHandler, inboundHandler)
Expand Down Expand Up @@ -127,12 +127,7 @@ final case class NettyClientDriver private[netty] (
pipeline.addLast(Names.ClientInboundHandler, clientInbound)
toRemove.add(clientInbound)

val clientFailureHandler =
new ClientFailureHandler(
nettyRuntime,
onResponse,
onComplete,
)
val clientFailureHandler = new ClientFailureHandler(onResponse, onComplete)
pipeline.addLast(Names.ClientFailureHandler, clientFailureHandler)
toRemove.add(clientFailureHandler)

Expand All @@ -155,26 +150,28 @@ final case class NettyClientDriver private[netty] (
}
}

f.ensuring(
ZIO
.unless(location.scheme.isWebSocket) {
// If the channel was closed and the promises were not completed, this will lead to the request hanging so we need
// to listen to the close future and complete the promises
NettyFutureExecutor
.executed(channel.closeFuture())
.interruptible
.zipRight(
// If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
// so we don't need to fulfill `onResponse`
onComplete.interrupt && onResponse.fail(
new PrematureChannelClosureException(
"Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
f.ensuring {
// If the channel was closed and the promises were not completed, this will lead to the request hanging so we need
// to listen to the close future and complete the promises
ZIO.unless(location.scheme.isWebSocket) {
ZIO.succeedUnsafe { implicit u =>
channel.closeFuture().addListener { (_: ChannelFuture) =>
// If onComplete was already set, it means another fiber is already in the process of fulfilling the promises
// so we don't need to fulfill `onResponse`
nettyRuntime.unsafeRunSync {
ZIO.whenZIO(onComplete.interrupt)(
onResponse.fail(
new PrematureChannelClosureException(
"Channel closed while executing the request. This is likely caused due to a client connection misconfiguration",
),
),
),
)
)
}
}
}
.forkScoped,
)
}
}

}

override def createConnectionPool(dnsResolver: DnsResolver, config: ConnectionPoolConfig)(implicit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package zio.http.netty.client

import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Promise, Trace, Unsafe}
import zio.{Exit, Promise, Trace, Unsafe}

import zio.http.Response
import zio.http.internal.ChannelState
Expand All @@ -27,29 +27,23 @@ import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http.FullHttpResponse

final class WebSocketClientInboundHandler(
rtm: NettyRuntime,
onResponse: Promise[Throwable, Response],
onComplete: Promise[Throwable, ChannelState],
)(implicit trace: Trace)
extends SimpleChannelInboundHandler[FullHttpResponse](true) {
) extends SimpleChannelInboundHandler[FullHttpResponse](true) {
implicit private val unsafeClass: Unsafe = Unsafe.unsafe

override def channelActive(ctx: ChannelHandlerContext): Unit = {
override def channelActive(ctx: ChannelHandlerContext): Unit =
ctx.fireChannelActive()
}

override def channelRead0(ctx: ChannelHandlerContext, msg: FullHttpResponse): Unit = {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring) {
onResponse.succeed(NettyResponse(msg))
}(unsafeClass, trace)

onResponse.unsafe.done(Exit.succeed(NettyResponse(msg)))
ctx.fireChannelRead(msg.retain())
ctx.pipeline().remove(ctx.name()): Unit
}

override def exceptionCaught(ctx: ChannelHandlerContext, error: Throwable): Unit = {
rtm.runUninterruptible(ctx, NettyRuntime.noopEnsuring)(
onResponse.fail(error) *> onComplete.fail(error),
)(unsafeClass, trace)
val exit = Exit.fail(error)
onResponse.unsafe.done(exit)
onComplete.unsafe.done(exit)
}
}

0 comments on commit 21d5324

Please sign in to comment.