Skip to content

Commit

Permalink
Improve the Client's performance for non-stream bodies (#2919)
Browse files Browse the repository at this point in the history
* Fix memory leak in netty connection pool

* Use a forked effect to monitor the close future for client requests

* Use Netty's future listener again

* fmt

* Remove suspendSucceed

* Improve non-streaming performance of Client

* Cleanups in AsyncBodyReader

* One more improvement

* Reimplement using buffering within `AsyncBodyReader`

* Add benchmarks

* Fix url interpolator macro

* Re-generate GH workflow

* fmt
  • Loading branch information
kyri-petrou authored Jun 20, 2024
1 parent 8de4a41 commit fd5eb22
Show file tree
Hide file tree
Showing 7 changed files with 247 additions and 89 deletions.
43 changes: 42 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,40 @@ jobs:
name: Jmh_Main_CachedDateHeaderBenchmark
path: Main_CachedDateHeaderBenchmark.txt

Jmh_ClientBenchmark:
name: Jmh ClientBenchmark
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
strategy:
matrix:
os: [ubuntu-latest]
scala: [2.13.14]
java: [temurin@8]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
with:
path: zio-http

- uses: actions/setup-java@v2
with:
distribution: temurin
java-version: 11

- name: Benchmark_Main
id: Benchmark_Main
env:
GITHUB_TOKEN: ${{secrets.ACTIONS_PAT}}
run: |
cd zio-http
sed -i -e '$aaddSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.7")' project/plugins.sbt
cat > Main_ClientBenchmark.txt
sbt -no-colors -v "zioHttpBenchmarks/jmh:run -i 3 -wi 3 -f1 -t1 ClientBenchmark" | grep -e "thrpt" -e "avgt" >> ../Main_ClientBenchmark.txt
- uses: actions/upload-artifact@v3
with:
name: Jmh_Main_ClientBenchmark
path: Main_ClientBenchmark.txt

Jmh_CookieDecodeBenchmark:
name: Jmh CookieDecodeBenchmark
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
Expand Down Expand Up @@ -627,7 +661,7 @@ jobs:

Jmh_cache:
name: Cache Jmh benchmarks
needs: [Jmh_CachedDateHeaderBenchmark, Jmh_CookieDecodeBenchmark, Jmh_EndpointBenchmark, Jmh_HttpCollectEval, Jmh_HttpCombineEval, Jmh_HttpNestedFlatMapEval, Jmh_HttpRouteTextPerf, Jmh_ProbeContentTypeBenchmark, Jmh_SchemeDecodeBenchmark, Jmh_ServerInboundHandlerBenchmark, Jmh_UtilBenchmark]
needs: [Jmh_CachedDateHeaderBenchmark, Jmh_ClientBenchmark, Jmh_CookieDecodeBenchmark, Jmh_EndpointBenchmark, Jmh_HttpCollectEval, Jmh_HttpCombineEval, Jmh_HttpNestedFlatMapEval, Jmh_HttpRouteTextPerf, Jmh_ProbeContentTypeBenchmark, Jmh_SchemeDecodeBenchmark, Jmh_ServerInboundHandlerBenchmark, Jmh_UtilBenchmark]
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
strategy:
matrix:
Expand All @@ -643,6 +677,13 @@ jobs:
- name: Format_Main_CachedDateHeaderBenchmark
run: cat Main_CachedDateHeaderBenchmark.txt >> Main_benchmarks.txt

- uses: actions/download-artifact@v3
with:
name: Jmh_Main_ClientBenchmark

- name: Format_Main_ClientBenchmark
run: cat Main_ClientBenchmark.txt >> Main_benchmarks.txt

- uses: actions/download-artifact@v3
with:
name: Jmh_Main_CookieDecodeBenchmark
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package zhttp.benchmarks

import java.util.concurrent.TimeUnit

import scala.annotation.nowarn

import zio._

import zio.http._

import org.openjdk.jmh.annotations._

@nowarn
@State(org.openjdk.jmh.annotations.Scope.Benchmark)
@BenchmarkMode(Array(Mode.Throughput))
@OutputTimeUnit(TimeUnit.SECONDS)
@Warmup(iterations = 3, time = 3)
@Measurement(iterations = 3, time = 3)
@Fork(1)
class ClientBenchmark {
private val random = scala.util.Random
random.setSeed(42)

private implicit val unsafe: Unsafe = Unsafe.unsafe(identity)

@Param(Array("small", "large"))
var path: String = _

private val smallString = "Hello!".getBytes
private val largeString = random.alphanumeric.take(10000).mkString.getBytes

private val smallRequest = Request(url = url"http://0.0.0.0:8080/small")
private val largeRequest = Request(url = url"http://0.0.0.0:8080/large")

private val smallResponse = Response(status = Status.Ok, body = Body.fromArray(smallString))
private val largeResponse = Response(status = Status.Ok, body = Body.fromArray(largeString))

private val smallRoute = Route.route(Method.GET / "small")(handler(smallResponse))
private val largeRoute = Route.route(Method.GET / "large")(handler(largeResponse))

private val shutdownResponse = Response.text("shutting down")

private def shutdownRoute(shutdownSignal: Promise[Nothing, Unit]) =
Route.route(Method.GET / "shutdown")(handler(shutdownSignal.succeed(()).as(shutdownResponse)))

private def http(shutdownSignal: Promise[Nothing, Unit]) =
Routes(smallRoute, largeRoute, shutdownRoute(shutdownSignal))

private val rtm = Runtime.unsafe.fromLayer(ZClient.default)
private val runtime = rtm.unsafe

private def run(f: RIO[Client, Any]): Any = runtime.run(f).getOrThrow()

@Setup(Level.Trial)
def setup(): Unit = {
val startServer: Task[Unit] = (for {
shutdownSignal <- Promise.make[Nothing, Unit]
fiber <- Server.serve(http(shutdownSignal)).fork
_ <- shutdownSignal.await *> fiber.interrupt
} yield ()).provideLayer(Server.default)

val waitForServerStarted: Task[Unit] = (for {
client <- ZIO.service[Client]
_ <- client.request(smallRequest)
} yield ()).provide(ZClient.default, zio.Scope.default)

run(startServer.forkDaemon *> waitForServerStarted.retry(Schedule.fixed(1.second)))
}

@TearDown(Level.Trial)
def tearDown(): Unit = {
val stopServer = (for {
client <- ZIO.service[Client]
_ <- client.request(Request(url = url"http://localhost:8080/shutdown"))
} yield ()).provide(ZClient.default, zio.Scope.default)
run(stopServer)
rtm.shutdown0()
}

@Benchmark
@OperationsPerInvocation(100)
def zhttpChunkBenchmark(): Any = run {
val req = if (path == "small") smallRequest else largeRequest
ZIO.serviceWithZIO[Client] { client =>
ZIO.scoped(client.request(req).flatMap(_.body.asChunk)).repeatN(100)
}
}

@Benchmark
@OperationsPerInvocation(100)
def zhttpStreamToChunkBenchmark(): Any = run {
val req = if (path == "small") smallRequest else largeRequest
ZIO.serviceWithZIO[Client] { client =>
ZIO.scoped(client.request(req).flatMap(_.body.asStream.runCollect)).repeatN(100)
}
}
}
85 changes: 55 additions & 30 deletions zio-http/jvm/src/main/scala/zio/http/netty/AsyncBodyReader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,50 @@ package zio.http.netty

import java.io.IOException

import scala.collection.mutable

import zio.Chunk
import zio.stacktracer.TracingImplicits.disableAutoTrace
import zio.{Chunk, ChunkBuilder}

import zio.http.netty.AsyncBodyReader.State
import zio.http.netty.NettyBody.UnsafeAsync

import io.netty.buffer.ByteBufUtil
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.handler.codec.http.{HttpContent, LastHttpContent}

abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](true) {
import zio.http.netty.AsyncBodyReader._

private var state: State = State.Buffering
private val buffer = new mutable.ArrayBuilder.ofByte()
private var previousAutoRead: Boolean = false
private var readingDone: Boolean = false
private var ctx: ChannelHandlerContext = _

private var state: State = State.Buffering
private val buffer: ChunkBuilder[(Chunk[Byte], Boolean)] = ChunkBuilder.make[(Chunk[Byte], Boolean)]()
private var previousAutoRead: Boolean = false
private var ctx: ChannelHandlerContext = _
private def result(buffer: mutable.ArrayBuilder.ofByte): Chunk[Byte] = {
val arr = buffer.result()
Chunk.ByteArray(arr, 0, arr.length)
}

private[zio] def connect(callback: UnsafeAsync): Unit = {
val buffer0 = buffer // Avoid reading it from the heap in the synchronized block
this.synchronized {
state match {
case State.Buffering =>
val result: Chunk[(Chunk[Byte], Boolean)] = buffer.result()
val readingDone: Boolean = result.lastOption match {
case None => false
case Some((_, isLast)) => isLast
}
buffer.clear() // GC

if (ctx.channel.isOpen || readingDone) {
state = State.Direct(callback)
result.foreach { case (chunk, isLast) =>
callback(chunk, isLast)
state = State.Direct(callback)

if (readingDone) {
callback(result(buffer0), isLast = true)
} else if (ctx.channel().isOpen) {
callback match {
case UnsafeAsync.Aggregating(bufSize) => buffer.sizeHint(bufSize)
case cb => cb(result(buffer0), isLast = false)
}
ctx.read(): Unit
} else {
throw new IllegalStateException("Attempting to read from a closed channel, which will never finish")
}

case State.Direct(_) =>
case _ =>
throw new IllegalStateException("Cannot connect twice")
}
}
Expand All @@ -76,22 +81,36 @@ abstract class AsyncBodyReader extends SimpleChannelInboundHandler[HttpContent](
ctx: ChannelHandlerContext,
msg: HttpContent,
): Unit = {
val isLast = msg.isInstanceOf[LastHttpContent]
val chunk = Chunk.fromArray(ByteBufUtil.getBytes(msg.content()))
val buffer0 = buffer // Avoid reading it from the heap in the synchronized block

this.synchronized {
val isLast = msg.isInstanceOf[LastHttpContent]
val content = ByteBufUtil.getBytes(msg.content())

state match {
case State.Buffering =>
buffer += ((chunk, isLast))
case State.Direct(callback) =>
callback(chunk, isLast)
ctx.read()
case State.Buffering =>
// `connect` method hasn't been called yet, add all incoming content to the buffer
buffer0.addAll(content)
case State.Direct(callback) if isLast && buffer0.knownSize == 0 =>
// Buffer is empty, we can just use the array directly
callback(Chunk.fromArray(content), isLast = true)
case State.Direct(callback: UnsafeAsync.Aggregating) =>
// We're aggregating the full response, only call the callback on the last message
buffer0.addAll(content)
if (isLast) callback(result(buffer0), isLast = true)
case State.Direct(callback) =>
// We're streaming, emit chunks as they come
callback(Chunk.fromArray(content), isLast)
}
}

if (isLast) {
ctx.channel().pipeline().remove(this)
}: Unit
if (isLast) {
readingDone = true
ctx.channel().pipeline().remove(this)
} else {
ctx.read()
}
()
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
Expand Down Expand Up @@ -125,4 +144,10 @@ object AsyncBodyReader {

final case class Direct(callback: UnsafeAsync) extends State
}

// For Scala 2.12. In Scala 2.13+, the methods directly implemented on ArrayBuilder[Byte] are selected over syntax.
private implicit class ByteArrayBuilderOps[A](private val self: mutable.ArrayBuilder[Byte]) extends AnyVal {
def addAll(as: Array[Byte]): Unit = self ++= as
def knownSize: Int = -1
}
}
Loading

0 comments on commit fd5eb22

Please sign in to comment.