Skip to content

Commit

Permalink
InitialPacket decoder performance improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
takapi327 committed Jan 30, 2025
1 parent 0764a27 commit 6f4c5e5
Showing 1 changed file with 39 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@
package ldbc.connector.net.packet
package response

import java.nio.charset.StandardCharsets.UTF_8
import scodec.*
import scodec.codecs.*

import scodec.bits.*
import cats.syntax.all.*

import org.typelevel.otel4s.Attribute

import ldbc.connector.data.*
import ldbc.connector.util.Version

Expand All @@ -40,7 +38,7 @@ case class InitialPacket(
authPlugin: String
) extends ResponsePacket:

val attributes: List[Attribute[String]] = List(
def attributes: List[Attribute[String]] = List(
Attribute("protocol.version", protocolVersion.toString),
Attribute("server.version", serverVersion.toString),
Attribute("thread.id", threadId.toString),
Expand All @@ -52,45 +50,44 @@ case class InitialPacket(

object InitialPacket:

private val protocolVersionCodec: Codec[Int] = uint8
private val threadIdCodec: Codec[Int] = int32
private val authPluginDataPart1Codec: Codec[(Byte, Byte, Byte, Byte, Byte, Byte, Byte, Byte)] =
byte :: byte :: byte :: byte :: byte :: byte :: byte :: byte
private val capabilityFlagsLowerCodec: Codec[Int] = uint16L
private val capabilityFlagsUpperCodec: Codec[Int] = uint16L

val decoder: Decoder[InitialPacket] =
for
protocolVersion <- protocolVersionCodec.asDecoder
serverVersion <- nullTerminatedStringCodec.asDecoder
threadId <- threadIdCodec.asDecoder
authPluginDataPart1 <- authPluginDataPart1Codec.map {
case (a, b, c, d, e, f, g, h) => Array(a, b, c, d, e, f, g, h)
}
_ <- ignore(8) // Skip filter [0x00]
capabilityFlagsLower <- capabilityFlagsLowerCodec.asDecoder
characterSet <- uint8L.asDecoder
statusFlag <- uint16L.asDecoder
capabilityFlagsUpper <- capabilityFlagsUpperCodec.asDecoder
capabilityFlags = (capabilityFlagsUpper << 16) | capabilityFlagsLower
authPluginDataPart2Length <- if (capabilityFlags & (1 << 19)) != 0 then uint8.asDecoder else Decoder.pure(0)
_ <- ignore(10 * 8) // Skip reserved bytes (10 bytes)
authPluginDataPart2 <- bytes(math.max(13, authPluginDataPart2Length - 8)).asDecoder
authPluginName <-
if (capabilityFlags & (1 << 19)) != 0 then nullTerminatedStringCodec.asDecoder else Decoder.pure("")
yield
val capabilityFlags = (capabilityFlagsUpper << 16) | capabilityFlagsLower
(bits: BitVector) =>
val (protocolVersion, reminder0) = bits.splitAt(8)
val bytes = reminder0.bytes.takeWhile(_ != 0)
val serverVersion = new String(bytes.toArray, UTF_8)
val remainder1 = reminder0.drop((bytes.size + 1) * 8) // +1 is a null character, so *8 is a byte to bit
val (threadId, reminder2) = remainder1.splitAt(32)
val (authPluginDataPart1, reminder3) = reminder2.splitAt(64)
val reminder4 = reminder3.drop(8) // Skip filter [0x00]
val (capabilityFlagsLower, reminder5) = reminder4.splitAt(16)
val (characterSet, reminder6) = reminder5.splitAt(8)
val (statusFlag, reminder7) = reminder6.splitAt(16)
val (capabilityFlagsUpper, reminder8) = reminder7.splitAt(16)
val capabilityFlags = (capabilityFlagsUpper.toInt(false, ByteOrdering.LittleEndian) << 16) | capabilityFlagsLower.toInt(false, ByteOrdering.LittleEndian)
val (authPluginDataPart2Length, reminder9) = if (capabilityFlags & (1 << 19)) != 0 then
val (v1, v2) = reminder8.splitAt(8)
(v1.toInt(false), v2)
else (0, reminder8)
val reminder10 = reminder9.drop(10 * 8) // Skip reserved bytes (10 bytes)
val (authPluginDataPart2, reminder11) = reminder10.splitAt(math.max(13, authPluginDataPart2Length - 8) * 8)
val authPluginName = if (capabilityFlags & (1 << 19)) != 0 then
val bytes = reminder11.bytes.takeWhile(_ != 0)
new String(bytes.toArray, UTF_8)
else ""

val version = Version(serverVersion) match
case Some(v) => v
case None => Version(0, 0, 0)
case None => Version(0, 0, 0)

InitialPacket(
protocolVersion,
version,
threadId,
CapabilitiesFlags(capabilityFlags),
characterSet,
ServerStatusFlags(statusFlag),
authPluginDataPart1 ++ authPluginDataPart2.toArray.dropRight(1),
authPluginName
val packet = InitialPacket(
protocolVersion = protocolVersion.toInt(false),
serverVersion = version,
threadId = threadId.toInt(true, ByteOrdering.LittleEndian),
capabilityFlags = CapabilitiesFlags(capabilityFlags),
characterSet = characterSet.toInt(false, ByteOrdering.LittleEndian),
statusFlags = ServerStatusFlags(statusFlag.toInt(false, ByteOrdering.LittleEndian)),
scrambleBuff = authPluginDataPart1.toByteArray ++ authPluginDataPart2.toByteArray.dropRight(1),
authPlugin = authPluginName
)

Attempt.successful(DecodeResult(packet, bits))

0 comments on commit 6f4c5e5

Please sign in to comment.