diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/ColumnDefinition41Packet.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/ColumnDefinition41Packet.scala index 0dca5644e..4eab2d159 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/ColumnDefinition41Packet.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/ColumnDefinition41Packet.scala @@ -10,8 +10,10 @@ package response import java.nio.charset.StandardCharsets.UTF_8 import scodec.* -import scodec.bits.{BitVector, ByteOrdering} +import scodec.bits.{ BitVector, ByteOrdering } + import cats.syntax.all.* + import ldbc.connector.data.* /** @@ -89,36 +91,36 @@ case class ColumnDefinition41Packet( object ColumnDefinition41Packet: private def decodeToString(bits: BitVector): (BitVector, String) = - val (sizeBits, remainder) = bits.splitAt(8) - val size = sizeBits.toLong(signed = false) + val (sizeBits, remainder) = bits.splitAt(8) + val size = sizeBits.toLong(signed = false) val (valueBits, postValue) = remainder.splitAt(size * 8L) (postValue, new String(valueBits.toByteArray, UTF_8)) val decoder: Decoder[ColumnDefinition41Packet] = (bits: BitVector) => - val (catalogBits, catalog) = decodeToString(bits) - val (schemaBits, schema) = decodeToString(catalogBits) - val (tableBits, table) = decodeToString(schemaBits) - val (orgTableBits, orgTable) = decodeToString(tableBits) - val (nameBits, name) = decodeToString(orgTableBits) - val (orgNameBits, orgName) = decodeToString(nameBits) - val (length, lengthBits) = orgNameBits.splitAt(8) + val (catalogBits, catalog) = decodeToString(bits) + val (schemaBits, schema) = decodeToString(catalogBits) + val (tableBits, table) = decodeToString(schemaBits) + val (orgTableBits, orgTable) = decodeToString(tableBits) + val (nameBits, name) = decodeToString(orgTableBits) + val (orgNameBits, orgName) = decodeToString(nameBits) + val (length, lengthBits) = orgNameBits.splitAt(8) val (characterSet, characterSetBits) = lengthBits.splitAt(16) val (columnLength, columnLengthBits) = characterSetBits.splitAt(32) - val (columnType, columnTypeBits) = columnLengthBits.splitAt(8) - val (flags, decimals) = columnTypeBits.splitAt(16) + val (columnType, columnTypeBits) = columnLengthBits.splitAt(8) + val (flags, decimals) = columnTypeBits.splitAt(16) val packet = ColumnDefinition41Packet( - catalog = catalog, - schema = schema, - table = table, - orgTable = orgTable, - name = name, - orgName = orgName, - length = length.toInt(signed = false), + catalog = catalog, + schema = schema, + table = table, + orgTable = orgTable, + name = name, + orgName = orgName, + length = length.toInt(signed = false), characterSet = characterSet.toInt(signed = false), columnLength = columnLength.toLong(signed = false), - columnType = ColumnDataType(columnType.toInt(signed = false)), - flags = ColumnDefinitionFlags(flags.toInt(signed = false, ordering = ByteOrdering.LittleEndian)), - decimals = decimals.toInt() + columnType = ColumnDataType(columnType.toInt(signed = false)), + flags = ColumnDefinitionFlags(flags.toInt(signed = false, ordering = ByteOrdering.LittleEndian)), + decimals = decimals.toInt() ) Attempt.successful(DecodeResult(packet, bits)) diff --git a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/InitialPacket.scala b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/InitialPacket.scala index 53a6a6875..59bb1696b 100644 --- a/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/InitialPacket.scala +++ b/module/ldbc-connector/shared/src/main/scala/ldbc/connector/net/packet/response/InitialPacket.scala @@ -8,10 +8,14 @@ package ldbc.connector.net.packet package response import java.nio.charset.StandardCharsets.UTF_8 + import scodec.* import scodec.bits.* + import cats.syntax.all.* + import org.typelevel.otel4s.Attribute + import ldbc.connector.data.* import ldbc.connector.util.Version @@ -53,22 +57,23 @@ object InitialPacket: val decoder: Decoder[InitialPacket] = (bits: BitVector) => val (protocolVersion, reminder0) = bits.splitAt(8) - val bytes = reminder0.bytes.takeWhile(_ != 0) - val serverVersion = new String(bytes.toArray, UTF_8) - val remainder1 = reminder0.drop((bytes.size + 1) * 8) // +1 is a null character, so *8 is a byte to bit + val bytes = reminder0.bytes.takeWhile(_ != 0) + val serverVersion = new String(bytes.toArray, UTF_8) + val remainder1 = reminder0.drop((bytes.size + 1) * 8) // +1 is a null character, so *8 is a byte to bit val (threadId, reminder2) = remainder1.splitAt(32) - val (authPluginDataPart1, reminder3) = reminder2.splitAt(64) - val reminder4 = reminder3.drop(8) // Skip filter [0x00] + val (authPluginDataPart1, reminder3) = reminder2.splitAt(64) + val reminder4 = reminder3.drop(8) // Skip filter [0x00] val (capabilityFlagsLower, reminder5) = reminder4.splitAt(16) - val (characterSet, reminder6) = reminder5.splitAt(8) - val (statusFlag, reminder7) = reminder6.splitAt(16) + val (characterSet, reminder6) = reminder5.splitAt(8) + val (statusFlag, reminder7) = reminder6.splitAt(16) val (capabilityFlagsUpper, reminder8) = reminder7.splitAt(16) - val capabilityFlags = (capabilityFlagsUpper.toInt(false, ByteOrdering.LittleEndian) << 16) | capabilityFlagsLower.toInt(false, ByteOrdering.LittleEndian) + val capabilityFlags = (capabilityFlagsUpper.toInt(false, ByteOrdering.LittleEndian) << 16) | capabilityFlagsLower + .toInt(false, ByteOrdering.LittleEndian) val (authPluginDataPart2Length, reminder9) = if (capabilityFlags & (1 << 19)) != 0 then val (v1, v2) = reminder8.splitAt(8) (v1.toInt(false), v2) else (0, reminder8) - val reminder10 = reminder9.drop(10 * 8) // Skip reserved bytes (10 bytes) + val reminder10 = reminder9.drop(10 * 8) // Skip reserved bytes (10 bytes) val (authPluginDataPart2, reminder11) = reminder10.splitAt(math.max(13, authPluginDataPart2Length - 8) * 8) val authPluginName = if (capabilityFlags & (1 << 19)) != 0 then val bytes = reminder11.bytes.takeWhile(_ != 0) @@ -77,17 +82,17 @@ object InitialPacket: val version = Version(serverVersion) match case Some(v) => v - case None => Version(0, 0, 0) + case None => Version(0, 0, 0) val packet = InitialPacket( protocolVersion = protocolVersion.toInt(false), - serverVersion = version, - threadId = threadId.toInt(true, ByteOrdering.LittleEndian), + serverVersion = version, + threadId = threadId.toInt(true, ByteOrdering.LittleEndian), capabilityFlags = CapabilitiesFlags(capabilityFlags), - characterSet = characterSet.toInt(false, ByteOrdering.LittleEndian), - statusFlags = ServerStatusFlags(statusFlag.toInt(false, ByteOrdering.LittleEndian)), - scrambleBuff = authPluginDataPart1.toByteArray ++ authPluginDataPart2.toByteArray.dropRight(1), - authPlugin = authPluginName + characterSet = characterSet.toInt(false, ByteOrdering.LittleEndian), + statusFlags = ServerStatusFlags(statusFlag.toInt(false, ByteOrdering.LittleEndian)), + scrambleBuff = authPluginDataPart1.toByteArray ++ authPluginDataPart2.toByteArray.dropRight(1), + authPlugin = authPluginName ) Attempt.successful(DecodeResult(packet, bits))