From 8927ca5a139d450abf4b13b5a79b896c36031293 Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Wed, 30 Oct 2024 16:59:06 +0100 Subject: [PATCH 01/12] ArrowTableRowsSerializer --- .../client/ArrowTableRowsSerializer.java | 1171 +++++++++++++++++ .../ArrowWriteSerializationContext.java | 24 + .../ytsaurus/client/InternalRowYTGetters.java | 8 + .../ytsaurus/client/TableWriterBaseImpl.java | 114 ++ .../scala/tech/ytsaurus/client/YTGetters.java | 177 +++ .../ytsaurus/spyt/format/YtOutputWriter.scala | 18 +- .../serializers/InternalRowSerializer.scala | 145 -- .../serializers/WriteSchemaConverter.scala | 3 +- .../spyt/serializers/YtLogicalType.scala | 951 ++++++++++++- .../serializers/YtLogicalTypeSerializer.scala | 8 +- 10 files changed, 2432 insertions(+), 187 deletions(-) create mode 100644 data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java create mode 100644 data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java create mode 100644 data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java create mode 100644 data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java create mode 100644 data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java new file mode 100644 index 00000000..61ac223b --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -0,0 +1,1171 @@ +package tech.ytsaurus.client; + +import io.netty.buffer.ByteBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.WriteChannel; 
+import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import tech.ytsaurus.rpcproxy.ERowsetFormat; +import tech.ytsaurus.rpcproxy.TRowsetDescriptor; +import tech.ytsaurus.spyt.format.batch.ArrowUtils; +import tech.ytsaurus.spyt.serialization.YsonDecoder; +import tech.ytsaurus.typeinfo.DecimalType; +import tech.ytsaurus.yson.YsonBinaryWriter; +import tech.ytsaurus.ysontree.YTreeBuilder; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** Table rows serializer that extends TableRowsSerializer and is AutoCloseable. The members visible in this part of the file turn YT getters into Arrow Field descriptors plus writers that append values to Arrow vectors. NOTE(review): the type-parameter list of this declaration looks stripped in this copy of the patch (stray '>'); Struct and List used below are presumably class type parameters - confirm against the original file. */ public class ArrowTableRowsSerializer> extends TableRowsSerializer implements AutoCloseable { + /** Pairs an Arrow Field with a factory of writers that take their value from a Struct. The Arrow type is cached from field.getType() at construction. */ private abstract class ArrowGetterFromStruct { + public final Field field; + public final ArrowType arrowType; + + ArrowGetterFromStruct(Field field) { + super(); + this.field = field; + this.arrowType = field.getType(); + } + + /** Returns the Arrow type cached from the field. */ public final ArrowType getArrowType() { + return arrowType; + } + + /** Creates a writer bound to the given vector; the implementations visible in this file cast the vector to the concrete vector class matching the field type. */ public abstract ArrowWriterFromStruct writer(ValueVector valueVector); + } + + /** Writes one value read from a struct into the vector the writer was created for. The implementations visible in this file treat a null struct as "write a null value". */ private abstract class ArrowWriterFromStruct { + abstract void setFromStruct(Struct struct); + } + + /** Counterpart of ArrowGetterFromStruct for values addressed as (list, index): an Arrow Field plus a factory of list-based writers. */ private abstract class ArrowGetterFromList { + public final Field field; + public final ArrowType arrowType; + + ArrowGetterFromList(Field field) { + this.field = field; + this.arrowType = field.getType(); + } + + /** Returns the Arrow type cached from the field. */ public final ArrowType getArrowType() { + return arrowType; + } + + /** Creates a writer bound to the given vector. */ public abstract ArrowWriterFromList writer(ValueVector valueVector); + } + + 
/** Writes element i of a list into the vector the writer was created for. The implementations visible in this file treat a null list as "write a null value", write at index vector.getValueCount(), and then increase the value count by one. */ private abstract class ArrowWriterFromList { + abstract void setFromList(List list, int i); + } + + /** Builds the Arrow getter for a list-element getter. A YTGetters.FromListToOptional getter is unwrapped through getNotEmptyGetter() and the inner getter is mapped by nonComplexArrowGetter. If a mapping exists and the getter was optional, the result reuses the mapped field's type and children but declares the field nullable. If no mapping exists (null was returned), values are stored as binary YSON in a Binary field that is nullable only when the getter was optional. */ private ArrowGetterFromList arrowGetter(String name, Getters.FromList getter) { + var optionalGetter = getter instanceof YTGetters.FromListToOptional + ? (Getters.FromListToOptional) getter + : null; + var nonEmptyGetter = optionalGetter != null ? (Getters.FromList) optionalGetter.getNotEmptyGetter() : getter; + var arrowGetter = nonComplexArrowGetter(name, nonEmptyGetter); + if (arrowGetter != null) { + return optionalGetter == null ? arrowGetter : new ArrowGetterFromList(new Field(name, new FieldType( + true, arrowGetter.field.getType(), null + ), arrowGetter.field.getChildren())) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var nonOptionalWriter = arrowGetter.writer(valueVector); + return new ArrowWriterFromList() { + @Override + public void setFromList(List list, int i) { + /* An empty optional is signalled to the inner writer by passing a null list. */ nonOptionalWriter.setFromList(optionalGetter.isEmpty(list, i) ? null : list, i); + } + }; + } + }; + } + /* Fallback: no native Arrow mapping, so the value is serialized with YsonBinaryWriter. */ return new ArrowGetterFromList(new Field(name, new FieldType( + optionalGetter != null, new ArrowType.Binary(), null + ), new ArrayList<>())) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromList() { + @Override + public void setFromList(List list, int i) { + if (optionalGetter != null && optionalGetter.isEmpty(list, i)) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteArrayOutputStream = new ByteArrayOutputStream(); + try (var ysonBinaryWriter = new YsonBinaryWriter(byteArrayOutputStream)) { + nonEmptyGetter.getYson(list, i, ysonBinaryWriter); + } + /* NOTE(review): uses set(), not setSafe(); whether the vector capacity is guaranteed elsewhere is not visible here - confirm. */ varBinaryVector.set(varBinaryVector.getValueCount(), byteArrayOutputStream.toByteArray()); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + + /** Struct-member counterpart of the list-based arrowGetter: unwraps a YTGetters.FromStructToOptional getter, maps the inner getter with nonComplexArrowGetter, wraps the mapped getter in a nullable field when the getter was optional, and otherwise falls back to a Binary field holding binary YSON. */ private ArrowGetterFromStruct arrowGetter(String name, Getters.FromStruct
getter) { + var optionalGetter = getter instanceof YTGetters.FromStructToOptional + ? (Getters.FromStructToOptional) getter + : null; + var nonEmptyGetter = optionalGetter != null ? (Getters.FromStruct) optionalGetter.getNotEmptyGetter() : getter; + var arrowGetter = nonComplexArrowGetter(name, nonEmptyGetter); + if (arrowGetter != null) { + return optionalGetter == null ? arrowGetter : new ArrowGetterFromStruct(new Field(name, new FieldType( + true, arrowGetter.field.getType(), null + ), arrowGetter.field.getChildren())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var nonOptionalWriter = arrowGetter.writer(valueVector); + return new ArrowWriterFromStruct() { + @Override + public void setFromStruct(Struct struct) { + /* An empty optional is signalled to the inner writer by passing a null struct. */ nonOptionalWriter.setFromStruct(optionalGetter.isEmpty(struct) ? null : struct); + } + }; + } + }; + } else { + /* Fallback: no native Arrow mapping, so the value is serialized with YsonBinaryWriter. */ return new ArrowGetterFromStruct(new Field(name, new FieldType( + optionalGetter != null, new ArrowType.Binary(), null + ), new ArrayList<>())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (optionalGetter != null && optionalGetter.isEmpty(struct)) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteArrayOutputStream = new ByteArrayOutputStream(); + try (var ysonBinaryWriter = new YsonBinaryWriter(byteArrayOutputStream)) { + nonEmptyGetter.getYson(struct, ysonBinaryWriter); + } + /* NOTE(review): uses set(), not setSafe(); whether the vector capacity is guaranteed elsewhere is not visible here - confirm. */ varBinaryVector.set(varBinaryVector.getValueCount(), byteArrayOutputStream.toByteArray()); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + } + + /** Creates a non-nullable Field with the given name and Arrow type and an empty child list. */ private Field field(String name, ArrowType arrowType) { + return new Field(name, new FieldType(false, arrowType, null), Collections.emptyList()); + } + + /** Maps a non-optional list-element getter to an Arrow getter by switching on getter.getTiType().getTypeName(). Scalar types get a non-nullable field of the matching Arrow type; despite the name, List, Dict and Struct are also handled here by recursing into the element, key/value and member getters. Returns null for any other type name, and for Dict when the key or value getter is null. The returned writers write a null value when they are given a null list. NOTE(review): Datetime and Datetime64 longs are written unchanged into a DateUnit.MILLISECOND vector - confirm the unit the getter returns. */ private ArrowGetterFromList nonComplexArrowGetter(String name, 
Getters.FromList getter) { + var tiType = getter.getTiType(); + switch (tiType.getTypeName()) { + case String: { + var stringGetter = (Getters.FromListToString) getter; + return new ArrowGetterFromList( + new Field(name, new FieldType(false, new ArrowType.Binary(), null), new ArrayList<>()) + ) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteBuffer = stringGetter.getString(list, i); + varBinaryVector.set( + varBinaryVector.getValueCount(), + byteBuffer, byteBuffer.position(), byteBuffer.remaining() + ); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + case Int8: { + var byteGetter = (Getters.FromListToByte) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(8, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var tinyIntVector = (TinyIntVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + tinyIntVector.setNull(tinyIntVector.getValueCount()); + } else { + tinyIntVector.set(tinyIntVector.getValueCount(), byteGetter.getByte(list, i)); + } + tinyIntVector.setValueCount(tinyIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint8: { + var byteGetter = (Getters.FromListToByte) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(8, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt1Vector = (UInt1Vector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + uInt1Vector.setNull(uInt1Vector.getValueCount()); + } else { + uInt1Vector.set(uInt1Vector.getValueCount(), 
byteGetter.getByte(list, i)); + } + uInt1Vector.setValueCount(uInt1Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int16: { + var shortGetter = (Getters.FromListToShort) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(16, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var smallIntVector = (SmallIntVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + smallIntVector.setNull(smallIntVector.getValueCount()); + } else { + smallIntVector.set(smallIntVector.getValueCount(), shortGetter.getShort(list, i)); + } + smallIntVector.setValueCount(smallIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint16: { + var shortGetter = (Getters.FromListToShort) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(16, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt2Vector = (UInt2Vector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + uInt2Vector.setNull(uInt2Vector.getValueCount()); + } else { + uInt2Vector.set(uInt2Vector.getValueCount(), shortGetter.getShort(list, i)); + } + uInt2Vector.setValueCount(uInt2Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int32: { + var intGetter = (Getters.FromListToInt) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(32, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var intVector = (IntVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + intVector.setNull(intVector.getValueCount()); + } else { + intVector.set(intVector.getValueCount(), intGetter.getInt(list, i)); + } + intVector.setValueCount(intVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint32: { + var intGetter = 
(Getters.FromListToInt) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(32, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt4Vector = (UInt4Vector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + uInt4Vector.setNull(uInt4Vector.getValueCount()); + } else { + uInt4Vector.set(uInt4Vector.getValueCount(), intGetter.getInt(list, i)); + } + uInt4Vector.setValueCount(uInt4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Interval: + case Interval64: + case Int64: { + var longGetter = (Getters.FromListToLong) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(64, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var bigIntVector = (BigIntVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + bigIntVector.setNull(bigIntVector.getValueCount()); + } else { + bigIntVector.set(bigIntVector.getValueCount(), longGetter.getLong(list, i)); + } + bigIntVector.setValueCount(bigIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint64: { + var longGetter = (Getters.FromListToLong) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Int(64, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt8Vector = (UInt8Vector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + uInt8Vector.setNull(uInt8Vector.getValueCount()); + } else { + uInt8Vector.set(uInt8Vector.getValueCount(), longGetter.getLong(list, i)); + } + uInt8Vector.setValueCount(uInt8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Bool: { + var booleanGetter = (Getters.FromListToBoolean) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Bool())) { + @Override + public 
ArrowWriterFromList writer(ValueVector valueVector) { + var bitVector = (BitVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + bitVector.setNull(bitVector.getValueCount()); + } else { + bitVector.set(bitVector.getValueCount(), booleanGetter.getBoolean(list, i) ? 1 : 0); + } + bitVector.setValueCount(bitVector.getValueCount() + 1); + } + }; + } + }; + } + case Float: { + var floatGetter = (Getters.FromListToFloat) getter; + return new ArrowGetterFromList(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var float4Vector = (Float4Vector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + float4Vector.setNull(float4Vector.getValueCount()); + } else { + float4Vector.set(float4Vector.getValueCount(), floatGetter.getFloat(list, i)); + } + float4Vector.setValueCount(float4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Double: { + var doubleGetter = (Getters.FromListToDouble) getter; + return new ArrowGetterFromList(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var float8Vector = (Float8Vector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + float8Vector.setNull(float8Vector.getValueCount()); + } else { + float8Vector.set(float8Vector.getValueCount(), doubleGetter.getDouble(list, i)); + } + float8Vector.setValueCount(float8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Decimal: { + var decimalGetter = (Getters.FromListToBigDecimal) getter; + var decimalType = (DecimalType) decimalGetter.getTiType(); + return new ArrowGetterFromList(field(name, new ArrowType.Decimal( + decimalType.getPrecision(), decimalType.getScale() 
+ ))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var decimalVector = (DecimalVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + decimalVector.setNull(decimalVector.getValueCount()); + } else { + decimalVector.set(decimalVector.getValueCount(), decimalGetter.getBigDecimal(list, i)); + } + decimalVector.setValueCount(decimalVector.getValueCount() + 1); + } + }; + } + }; + } + case Date: + case Date32: { + var intGetter = (Getters.FromListToInt) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Date(DateUnit.DAY))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var dateDayVector = (DateDayVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + dateDayVector.setNull(dateDayVector.getValueCount()); + } else { + dateDayVector.set(dateDayVector.getValueCount(), intGetter.getInt(list, i)); + } + dateDayVector.setValueCount(dateDayVector.getValueCount() + 1); + } + }; + } + }; + } + case Datetime: + case Datetime64: { + var longGetter = (Getters.FromListToLong) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var dateMilliVector = (DateMilliVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + dateMilliVector.setNull(dateMilliVector.getValueCount()); + } else { + dateMilliVector.set(dateMilliVector.getValueCount(), longGetter.getLong(list, i)); + } + dateMilliVector.setValueCount(dateMilliVector.getValueCount() + 1); + } + }; + } + }; + } + case Timestamp: + case Timestamp64: { + var longGetter = (Getters.FromListToLong) getter; + return new ArrowGetterFromList(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { + 
@Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var timeStampMicroVector = (TimeStampMicroVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + timeStampMicroVector.setNull(timeStampMicroVector.getValueCount()); + } else { + timeStampMicroVector.set(timeStampMicroVector.getValueCount(), longGetter.getLong(list, i)); + } + timeStampMicroVector.setValueCount(timeStampMicroVector.getValueCount() + 1); + } + }; + } + }; + } + case List: { + var listGetter = (Getters.FromListToList) getter; + var elementGetter = listGetter.getElementGetter(); + var itemGetter = arrowGetter("item", (Getters.FromList) elementGetter); + return new ArrowGetterFromList(new Field(name, new FieldType( + false, new ArrowType.List(), null + ), Collections.singletonList(itemGetter.field))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var listVector = (ListVector) valueVector; + var dataWriter = itemGetter.writer(listVector.getDataVector()); + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + var value = list == null ? 
null : (List) listGetter.getList(list, i); + if (value != null) { + int size = elementGetter.getSize(value); + listVector.startNewValue(listVector.getValueCount()); + for (int j = 0; j < size; j++) { + dataWriter.setFromList(value, j); + } + listVector.endValue(listVector.getValueCount(), size); + } + listVector.setValueCount(listVector.getValueCount() + 1); + } + }; + } + }; + } + case Dict: { + var dictGetter = (Getters.FromListToDict) getter; + var fromDictGetter = dictGetter.getGetter(); + var keyGetter = nonComplexArrowGetter("key", (Getters.FromList) fromDictGetter.getKeyGetter()); + var valueGetter = arrowGetter("value", (Getters.FromList) fromDictGetter.getValueGetter()); + if (keyGetter == null || valueGetter == null) { + return null; + } + return new ArrowGetterFromList(new Field( + name, new FieldType(false, new ArrowType.Map(false), null), + Collections.singletonList(new Field( + "entries", new FieldType(false, new ArrowType.Struct(), null), + Arrays.asList(keyGetter.field, valueGetter.field) + )) + )) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var mapVector = (MapVector) valueVector; + var structVector = (StructVector) mapVector.getDataVector(); + var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); + var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + var dict = list == null ? 
null : dictGetter.getDict(list, i); + if (dict != null) { + int size = fromDictGetter.getSize(dict); + var keys = fromDictGetter.getKeys(dict); + var values = fromDictGetter.getValues(dict); + mapVector.startNewValue(mapVector.getValueCount()); + for (int j = 0; j < size; j++) { + structVector.setIndexDefined(structVector.getValueCount()); + keyWriter.setFromList((List) keys, j); + valueWriter.setFromList((List) values, j); + structVector.setValueCount(structVector.getValueCount() + 1); + } + mapVector.endValue(mapVector.getValueCount(), size); + } + mapVector.setValueCount(mapVector.getValueCount() + 1); + } + }; + } + }; + } + case Struct: { + var structGetter = (Getters.FromListToStruct) getter; + var members = (java.util.List>) structGetter.getMembersGetters(); + var membersGetters = new ArrayList(members.size()); + for (Map.Entry member : members) { + membersGetters.add(arrowGetter(member.getKey(), member.getValue())); + } + return new ArrowGetterFromList(new Field( + name, new FieldType(false, new ArrowType.Struct(), null), + membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) + )) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var structVector = (StructVector) valueVector; + var membersWriters = new ArrayList(members.size()); + for (int i = 0; i < members.size(); i++) { + membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); + } + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + if (list == null) { + for (int j = 0; j < members.size(); j++) { + membersWriters.get(j).setFromStruct(null); + } + } else { + var struct = (Struct) structGetter.getStruct(list, i); + structVector.setIndexDefined(structVector.getValueCount()); + for (int j = 0; j < members.size(); j++) { + membersWriters.get(j).setFromStruct(struct); + } + } + structVector.setValueCount(structVector.getValueCount() + 1); + } + }; + } + }; + } + default: + return null; + } 
+ } + + private ArrowGetterFromStruct nonComplexArrowGetter(String name, Getters.FromStruct getter) { + var tiType = getter.getTiType(); + switch (tiType.getTypeName()) { + case String: { + var stringGetter = (Getters.FromStructToString) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Binary())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteBuffer = stringGetter.getString(struct); + varBinaryVector.set( + varBinaryVector.getValueCount(), + byteBuffer, byteBuffer.position(), byteBuffer.remaining() + ); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + case Int8: { + var byteGetter = (Getters.FromStructToByte) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(8, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var tinyIntVector = (TinyIntVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + tinyIntVector.setNull(tinyIntVector.getValueCount()); + } else { + tinyIntVector.set(tinyIntVector.getValueCount(), byteGetter.getByte(struct)); + } + tinyIntVector.setValueCount(tinyIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint8: { + var byteGetter = (Getters.FromStructToByte) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(8, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt1Vector = (UInt1Vector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt1Vector.setNull(uInt1Vector.getValueCount()); + } else { + 
uInt1Vector.set(uInt1Vector.getValueCount(), byteGetter.getByte(struct)); + } + uInt1Vector.setValueCount(uInt1Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int16: { + var shortGetter = (Getters.FromStructToShort) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(16, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var smallIntVector = (SmallIntVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + smallIntVector.setNull(smallIntVector.getValueCount()); + } else { + smallIntVector.set(smallIntVector.getValueCount(), shortGetter.getShort(struct)); + } + smallIntVector.setValueCount(smallIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint16: { + var shortGetter = (Getters.FromStructToShort) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(16, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt2Vector = (UInt2Vector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt2Vector.setNull(uInt2Vector.getValueCount()); + } else { + uInt2Vector.set(uInt2Vector.getValueCount(), shortGetter.getShort(struct)); + } + uInt2Vector.setValueCount(uInt2Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int32: { + var intGetter = (Getters.FromStructToInt) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(32, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var intVector = (IntVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + intVector.setNull(intVector.getValueCount()); + } else { + intVector.set(intVector.getValueCount(), intGetter.getInt(struct)); + } + intVector.setValueCount(intVector.getValueCount() + 1); + } + 
}; + } + }; + } + case Uint32: { + var intGetter = (Getters.FromStructToInt) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(32, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt4Vector = (UInt4Vector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt4Vector.setNull(uInt4Vector.getValueCount()); + } else { + uInt4Vector.set(uInt4Vector.getValueCount(), intGetter.getInt(struct)); + } + uInt4Vector.setValueCount(uInt4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Interval: + case Interval64: + case Int64: { + var longGetter = (Getters.FromStructToLong) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(64, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var bigIntVector = (BigIntVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + bigIntVector.setNull(bigIntVector.getValueCount()); + } else { + bigIntVector.set(bigIntVector.getValueCount(), longGetter.getLong(struct)); + } + bigIntVector.setValueCount(bigIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint64: { + var longGetter = (Getters.FromStructToLong) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Int(64, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt8Vector = (UInt8Vector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt8Vector.setNull(uInt8Vector.getValueCount()); + } else { + uInt8Vector.set(uInt8Vector.getValueCount(), longGetter.getLong(struct)); + } + uInt8Vector.setValueCount(uInt8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Bool: { + var booleanGetter = (Getters.FromStructToBoolean) getter; + return new 
ArrowGetterFromStruct(field(name, new ArrowType.Bool())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var bitVector = (BitVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + bitVector.setNull(bitVector.getValueCount()); + } else { + bitVector.set(bitVector.getValueCount(), booleanGetter.getBoolean(struct) ? 1 : 0); + } + bitVector.setValueCount(bitVector.getValueCount() + 1); + } + }; + } + }; + } + case Float: { + var floatGetter = (Getters.FromStructToFloat) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var float4Vector = (Float4Vector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + float4Vector.setNull(float4Vector.getValueCount()); + } else { + float4Vector.set(float4Vector.getValueCount(), floatGetter.getFloat(struct)); + } + float4Vector.setValueCount(float4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Double: { + var doubleGetter = (Getters.FromStructToDouble) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var float8Vector = (Float8Vector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + float8Vector.setNull(float8Vector.getValueCount()); + } else { + float8Vector.set(float8Vector.getValueCount(), doubleGetter.getDouble(struct)); + } + float8Vector.setValueCount(float8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Decimal: { + var decimalGetter = (Getters.FromStructToBigDecimal) getter; + var decimalType = (DecimalType) decimalGetter.getTiType(); + return new 
ArrowGetterFromStruct(field(name, new ArrowType.Decimal( + decimalType.getPrecision(), decimalType.getScale() + ))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var decimalVector = (DecimalVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + decimalVector.setNull(decimalVector.getValueCount()); + } else { + decimalVector.set(decimalVector.getValueCount(), decimalGetter.getBigDecimal(struct)); + } + decimalVector.setValueCount(decimalVector.getValueCount() + 1); + } + }; + } + }; + } + case Date: + case Date32: { + var intGetter = (Getters.FromStructToInt) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Date(DateUnit.DAY))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var dateDayVector = (DateDayVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + dateDayVector.setNull(dateDayVector.getValueCount()); + } else { + dateDayVector.set(dateDayVector.getValueCount(), intGetter.getInt(struct)); + } + dateDayVector.setValueCount(dateDayVector.getValueCount() + 1); + } + }; + } + }; + } + case Datetime: + case Datetime64: { + var longGetter = (Getters.FromStructToLong) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var dateMilliVector = (DateMilliVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + dateMilliVector.setNull(dateMilliVector.getValueCount()); + } else { + dateMilliVector.set(dateMilliVector.getValueCount(), longGetter.getLong(struct)); + } + dateMilliVector.setValueCount(dateMilliVector.getValueCount() + 1); + } + }; + } + }; + } + case Timestamp: + case Timestamp64: { + var longGetter = 
(Getters.FromStructToLong) getter; + return new ArrowGetterFromStruct(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var timeStampMicroVector = (TimeStampMicroVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + timeStampMicroVector.setNull(timeStampMicroVector.getValueCount()); + } else { + timeStampMicroVector.set(timeStampMicroVector.getValueCount(), longGetter.getLong(struct)); + } + timeStampMicroVector.setValueCount(timeStampMicroVector.getValueCount() + 1); + } + }; + } + }; + } + case List: { + var listGetter = (Getters.FromStructToList) getter; + var elementGetter = (Getters.FromList) listGetter.getElementGetter(); + var itemGetter = arrowGetter("item", elementGetter); + return new ArrowGetterFromStruct(new Field(name, new FieldType( + false, new ArrowType.List(), null + ), Collections.singletonList(itemGetter.field))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var listVector = (ListVector) valueVector; + var dataWriter = itemGetter.writer(listVector.getDataVector()); + return new ArrowWriterFromStruct() { + @Override + public void setFromStruct(Struct struct) { + var list = struct == null ? 
null : (List) listGetter.getList(struct); + if (list != null) { + int size = elementGetter.getSize(list); + listVector.startNewValue(listVector.getValueCount()); + for (int i = 0; i < size; i++) { + dataWriter.setFromList(list, i); + } + listVector.endValue(listVector.getValueCount(), size); + } + listVector.setValueCount(listVector.getValueCount() + 1); + } + }; + } + }; + } + case Dict: { + var dictGetter = (Getters.FromStructToDict) getter; + var fromDictGetter = (Getters.FromDict) dictGetter.getGetter(); + var keyGetter = nonComplexArrowGetter("key", (Getters.FromList) fromDictGetter.getKeyGetter()); + var valueGetter = arrowGetter("value", (Getters.FromList) fromDictGetter.getValueGetter()); + if (keyGetter == null || valueGetter == null) { + return null; + } + return new ArrowGetterFromStruct(new Field( + name, new FieldType(false, new ArrowType.Map(false), null), + Collections.singletonList(new Field( + "entries", new FieldType(false, new ArrowType.Struct(), null), + Arrays.asList(keyGetter.field, valueGetter.field) + )) + )) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var mapVector = (MapVector) valueVector; + var structVector = (StructVector) mapVector.getDataVector(); + var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); + var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); + return new ArrowWriterFromStruct() { + @Override + public void setFromStruct(Struct struct) { + var dict = struct == null ? 
null : dictGetter.getDict(struct); + if (dict != null) { + int size = fromDictGetter.getSize(dict); + var keys = (List) fromDictGetter.getKeys(dict); + var values = (List) fromDictGetter.getValues(dict); + mapVector.startNewValue(mapVector.getValueCount()); + for (int i = 0; i < size; i++) { + structVector.setIndexDefined(structVector.getValueCount()); + keyWriter.setFromList(keys, i); + valueWriter.setFromList(values, i); + structVector.setValueCount(structVector.getValueCount() + 1); + } + mapVector.endValue(mapVector.getValueCount(), size); + } + mapVector.setValueCount(mapVector.getValueCount() + 1); + } + }; + } + }; + } + case Struct: { + var structGetter = (Getters.FromStructToStruct) getter; + var members = (java.util.List>) structGetter.getMembersGetters(); + var membersGetters = new ArrayList(members.size()); + for (Map.Entry member : members) { + membersGetters.add(arrowGetter(member.getKey(), member.getValue())); + } + return new ArrowGetterFromStruct(new Field( + name, new FieldType(false, new ArrowType.Struct(), null), + membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) + )) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var structVector = (StructVector) valueVector; + var membersWriters = new ArrayList(members.size()); + for (int i = 0; i < members.size(); i++) { + membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); + } + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct row) { + if (row == null) { + for (int i = 0; i < members.size(); i++) { + membersWriters.get(i).setFromStruct(null); + } + } else { + var struct = (Struct) structGetter.getStruct(row); + structVector.setIndexDefined(structVector.getValueCount()); + for (int i = 0; i < members.size(); i++) { + membersWriters.get(i).setFromStruct(struct); + } + } + structVector.setValueCount(structVector.getValueCount() + 1); + } + }; + } + }; + } + default: + return null; + } + 
} + + private final java.util.List fieldGetters; + private final Schema schema; + private final BufferAllocator allocator = + ArrowUtils.rootAllocator().newChildAllocator("toBatchIterator", 0, Long.MAX_VALUE); + + public ArrowTableRowsSerializer(java.util.List> structsGetter) { + super(ERowsetFormat.RF_FORMAT); + fieldGetters = structsGetter.stream().map(memberGetter -> arrowGetter( + memberGetter.getKey(), memberGetter.getValue() + )).collect(Collectors.toList()); + schema = new Schema(() -> fieldGetters.stream().map(getter -> getter.field).iterator()); + } + + @Override + public void close() { + allocator.close(); + } + + private static class ByteBufWritableByteChannel implements WritableByteChannel { + private final ByteBuf buf; + + private ByteBufWritableByteChannel(ByteBuf buf) { + this.buf = buf; + } + + @Override + public int write(ByteBuffer src) { + int remaining = src.remaining(); + buf.writeBytes(src); + return remaining - src.remaining(); + } + + @Override + public boolean isOpen() { + return buf.isWritable(); + } + + @Override + public void close() { + } + } + + @Override + protected void writeMeta(ByteBuf buf, ByteBuf serializedRows, int rowsCount) { + try { + var writeChannel = new WriteChannel(new ByteBufWritableByteChannel(buf)); + MessageSerializer.serialize(writeChannel, schema); + writeChannel.write(serializedRows.nioBuffer()); + ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + protected void writeRowsWithoutCount( + ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping + ) { + writeRows(buf, descriptor, rows, idMapping); + } + + @Override + protected void writeRows(ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping) { + try { + var writeChannel = new WriteChannel(new ByteBufWritableByteChannel(buf)); + MessageSerializer.serialize(writeChannel, schema); + var root = 
VectorSchemaRoot.create(schema, allocator); + var unloader = new VectorUnloader(root); + var writers = IntStream.range(0, fieldGetters.size()).mapToObj(column -> { + var valueVector = root.getFieldVectors().get(column); + if (valueVector instanceof FixedWidthVector) { + ((FixedWidthVector) valueVector).allocateNew(rows.size()); + } else { + valueVector.allocateNew(); + } + return fieldGetters.get(column).writer(valueVector); + }).collect(Collectors.toList()); + for (var row : rows) { + for (var writer : writers) { + writer.setFromStruct(row); + } + } + root.setRowCount(rows.size()); + try (var batch = unloader.getRecordBatch()) { + MessageSerializer.serialize(writeChannel, batch); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java new file mode 100644 index 00000000..6779e849 --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java @@ -0,0 +1,24 @@ +package tech.ytsaurus.client; + +import tech.ytsaurus.client.request.Format; +import tech.ytsaurus.client.request.SerializationContext; +import tech.ytsaurus.rpcproxy.ERowsetFormat; + +import java.util.HashMap; +import java.util.Map; + +public class ArrowWriteSerializationContext> extends SerializationContext { + private final java.util.List> rowGetters; + + public ArrowWriteSerializationContext( + java.util.List> rowGetters + ) { + this.rowsetFormat = ERowsetFormat.RF_FORMAT; + this.format = new Format("arrow", new HashMap<>()); + this.rowGetters = rowGetters; + } + + public java.util.List> getRowGetters() { + return rowGetters; + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java b/data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java new file mode 100644 index 00000000..b823af88 --- /dev/null +++ 
b/data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java @@ -0,0 +1,8 @@ +package tech.ytsaurus.client; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; + +public class InternalRowYTGetters extends YTGetters { +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java b/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java new file mode 100644 index 00000000..21bac7a9 --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java @@ -0,0 +1,114 @@ +package tech.ytsaurus.client; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import javax.annotation.Nullable; + +import tech.ytsaurus.client.request.WriteTable; +import tech.ytsaurus.client.rows.UnversionedRow; +import tech.ytsaurus.client.rows.UnversionedRowSerializer; +import tech.ytsaurus.client.rpc.Compression; +import tech.ytsaurus.client.rpc.RpcUtil; +import tech.ytsaurus.core.tables.TableSchema; +import tech.ytsaurus.lang.NonNullApi; +import tech.ytsaurus.rpcproxy.TWriteTableMeta; + + +@NonNullApi +class TableWriterBaseImpl extends RawTableWriterImpl { + protected @Nullable + TableSchema schema; + protected final WriteTable req; + protected @Nullable + TableRowsSerializer tableRowsSerializer; + private final SerializationResolver serializationResolver; + @Nullable + protected ApiServiceTransaction transaction; + + TableWriterBaseImpl(WriteTable req, SerializationResolver serializationResolver) { + super(req.getWindowSize(), req.getPacketSize()); + this.req = req; + this.serializationResolver = serializationResolver; + var format = this.req.getSerializationContext().getFormat(); + if (format.isEmpty() || !"arrow".equals(format.get().getType())) { + tableRowsSerializer = TableRowsSerializer.createTableRowsSerializer( + 
this.req.getSerializationContext(), serializationResolver + ).orElse(null); + } + } + + public void setTransaction(ApiServiceTransaction transaction) { + if (this.transaction != null) { + throw new IllegalStateException("Write transaction already started"); + } + this.transaction = transaction; + } + + public CompletableFuture> startUploadImpl() { + TableWriterBaseImpl self = this; + + return startUpload.thenApply((attachments) -> { + if (attachments.size() != 1) { + throw new IllegalArgumentException("protocol error"); + } + byte[] head = attachments.get(0); + if (head == null) { + throw new IllegalArgumentException("protocol error"); + } + + TWriteTableMeta metadata = RpcUtil.parseMessageBodyWithCompression( + head, + TWriteTableMeta.parser(), + Compression.None + ); + self.schema = ApiServiceUtil.deserializeTableSchema(metadata.getSchema()); + logger.debug("schema -> {}", schema.toYTree().toString()); + + { + var format = this.req.getSerializationContext().getFormat(); + if (format.isPresent() && "arrow".equals(format.get().getType())) { + tableRowsSerializer = new ArrowTableRowsSerializer<>( + ((ArrowWriteSerializationContext) this.req.getSerializationContext()).getRowGetters() + ); + } + } + + if (this.tableRowsSerializer == null) { + if (this.req.getSerializationContext().getObjectClass().isEmpty()) { + throw new IllegalStateException("No object clazz"); + } + Class objectClazz = self.req.getSerializationContext().getObjectClass().get(); + if (UnversionedRow.class.equals(objectClazz)) { + this.tableRowsSerializer = + (TableRowsSerializer) new TableRowsWireSerializer<>(new UnversionedRowSerializer()); + } else { + this.tableRowsSerializer = new TableRowsWireSerializer<>( + serializationResolver.createWireRowSerializer( + serializationResolver.forClass(objectClazz, self.schema)) + ); + } + } + + return self; + }); + } + + public boolean write(List rows, TableSchema schema) throws IOException { + byte[] serializedRows = tableRowsSerializer.serializeRows(rows, 
schema); + return write(serializedRows); + } + + @Override + public CompletableFuture close() { + return super.close() + .thenCompose(response -> { + if (transaction != null && transaction.isActive()) { + return transaction.commit() + .thenApply(unused -> response); + } + return CompletableFuture.completedFuture(response); + }); + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java b/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java new file mode 100644 index 00000000..6c75768f --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java @@ -0,0 +1,177 @@ +package tech.ytsaurus.client; + +import tech.ytsaurus.typeinfo.*; +import tech.ytsaurus.yson.YsonConsumer; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Map; + +public class YTGetters { + public abstract class Getter { + private Getter() { + } + + public abstract TiType getTiType(); + } + + public abstract class FromStruct extends Getter { + private FromStruct() { + } + + public abstract void getYson(Struct struct, YsonConsumer ysonConsumer); + } + + public abstract class FromList extends Getter { + private FromList() { + } + + public abstract int getSize(List list); + + public abstract void getYson(List list, int i, YsonConsumer ysonConsumer); + } + + public abstract class FromStructToYson extends FromStruct { + } + + public abstract class FromListToYson extends FromList { + } + + public abstract class FromDict extends Getter { + public abstract FromList getKeyGetter(); + + public abstract FromList getValueGetter(); + + public abstract int getSize(Dict dict); + + public abstract List getKeys(Dict dict); + + public abstract List getValues(Dict dict); + } + + public abstract class FromStructToNull extends FromStruct { + } + + public abstract class FromListToNull extends FromList { + } + + public abstract class FromStructToOptional extends FromStruct { + public abstract FromStruct getNotEmptyGetter(); + + public 
abstract boolean isEmpty(Struct struct); + } + + public abstract class FromListToOptional extends FromList { + public abstract FromList getNotEmptyGetter(); + + public abstract boolean isEmpty(List list, int i); + } + + public abstract class FromStructToString extends FromStruct { + public abstract ByteBuffer getString(Struct struct); + } + + public abstract class FromListToString extends FromList { + public abstract ByteBuffer getString(List struct, int i); + } + + public abstract class FromStructToByte extends FromStruct { + public abstract byte getByte(Struct struct); + } + + public abstract class FromListToByte extends FromList { + public abstract byte getByte(List list, int i); + } + + public abstract class FromStructToShort extends FromStruct { + public abstract short getShort(Struct struct); + } + + public abstract class FromListToShort extends FromList { + public abstract short getShort(List list, int i); + } + + public abstract class FromStructToInt extends FromStruct { + public abstract int getInt(Struct struct); + } + + public abstract class FromListToInt extends FromList { + public abstract int getInt(List list, int i); + } + + public abstract class FromStructToLong extends FromStruct { + public abstract long getLong(Struct struct); + } + + public abstract class FromListToLong extends FromList { + public abstract long getLong(List list, int i); + } + + public abstract class FromStructToBoolean extends FromStruct { + public abstract boolean getBoolean(Struct struct); + } + + public abstract class FromListToBoolean extends FromList { + public abstract boolean getBoolean(List list, int i); + } + + public abstract class FromStructToFloat extends FromStruct { + public abstract float getFloat(Struct struct); + } + + public abstract class FromListToFloat extends FromList { + public abstract float getFloat(List list, int i); + } + + public abstract class FromStructToDouble extends FromStruct { + public abstract double getDouble(Struct struct); + } + + public 
abstract class FromListToDouble extends FromList { + public abstract double getDouble(List list, int i); + } + + public abstract class FromStructToStruct extends FromStruct { + public abstract java.util.List> getMembersGetters(); + + public abstract Struct getStruct(Struct struct); + } + + public abstract class FromListToStruct extends FromList { + public abstract java.util.List> getMembersGetters(); + + public abstract Struct getStruct(List list, int i); + } + + public abstract class FromStructToList extends FromStruct { + public abstract FromList getElementGetter(); + + public abstract List getList(Struct struct); + } + + public abstract class FromListToList extends FromList { + public abstract FromList getElementGetter(); + + public abstract List getList(List list, int i); + } + + public abstract class FromStructToDict extends FromStruct { + public abstract FromDict getGetter(); + + public abstract Dict getDict(Struct struct); + } + + public abstract class FromListToDict extends FromList { + public abstract FromDict getGetter(); + + public abstract Dict getDict(List list, int i); + } + + public abstract class FromStructToBigDecimal extends FromStruct { + public abstract BigDecimal getBigDecimal(Struct struct); + } + + public abstract class FromListToBigDecimal extends FromList { + public abstract BigDecimal getBigDecimal(List list, int i); + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala index 5e87d476..3ad532d0 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala @@ -5,21 +5,23 @@ import org.apache.spark.executor.TaskMetricUpdater import org.apache.spark.metrics.yt.YtMetricsRegister import org.apache.spark.metrics.yt.YtMetricsRegister.ytMetricsSource._ import org.apache.spark.sql.catalyst.InternalRow +import 
org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory +import tech.ytsaurus.client.request.{TransactionalOptions, WriteTable} +import tech.ytsaurus.client.{ArrowWriteSerializationContext, CompoundClient, InternalRowYTGetters, TableWriter} +import tech.ytsaurus.core.GUID +import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings._ import tech.ytsaurus.spyt.fs.conf._ import tech.ytsaurus.spyt.fs.path.YPathEnriched import tech.ytsaurus.spyt.serializers.{InternalRowSerializer, WriteSchemaConverter} import tech.ytsaurus.spyt.wrapper.LogLazy -import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable} -import tech.ytsaurus.client.{CompoundClient, TableWriter} -import tech.ytsaurus.core.GUID -import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration import java.util import java.util.concurrent.{CompletableFuture, TimeUnit} +import scala.collection.JavaConverters.seqAsJavaListConverter import scala.concurrent.{Await, Future} import scala.util.{Failure, Try} @@ -151,9 +153,15 @@ class YtOutputWriter(richPath: YPathEnriched, protected def initializeWriter(): TableWriter[InternalRow] = { val appendPath = richPath.withAttr("append", "true").toYPath log.debugLazy(s"Initialize new write: $appendPath, transaction: $transactionGuid") + val internalRowGetters = new InternalRowYTGetters() val request = WriteTable.builder[InternalRow]() .setPath(appendPath) - .setSerializationContext(new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options)))) + .setSerializationContext(new ArrowWriteSerializationContext[InternalRow, ArrayData, MapData, InternalRowYTGetters]( + WriteSchemaConverter(options).ytLogicalTypeStruct(schema).fields.zipWithIndex.map { + case ((name, ytLogicalType, _), i) => + 
util.Map.entry(name, ytLogicalType.ytGettersFromStruct(internalRowGetters, i)) + }.asJava + )) .setTransactionalOptions(new TransactionalOptions(GUID.valueOf(transactionGuid))) .setNeedRetries(false) .build() diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala index 71922c42..1da39bb7 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala @@ -3,144 +3,16 @@ package tech.ytsaurus.spyt.serializers import org.apache.spark.metrics.yt.YtMetricsRegister import org.apache.spark.metrics.yt.YtMetricsRegister.ytMetricsSource._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.spyt.types._ -import org.apache.spark.sql.types._ -import org.slf4j.LoggerFactory import tech.ytsaurus.client.TableWriter -import tech.ytsaurus.client.rows.{WireProtocolWriteable, WireRowSerializer} -import tech.ytsaurus.core.tables.{ColumnValueType, TableSchema} -import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings.{WriteSchemaHint, WriteTypeV3} -import tech.ytsaurus.spyt.serialization.YsonEncoder -import tech.ytsaurus.spyt.serializers.InternalRowSerializer._ -import tech.ytsaurus.spyt.serializers.SchemaConverter.{Unordered, decimalToBinary} -import tech.ytsaurus.spyt.types.YTsaurusTypes -import tech.ytsaurus.spyt.wrapper.LogLazy -import tech.ytsaurus.typeinfo.TiType import java.util.concurrent.{Executors, TimeUnit} import scala.annotation.tailrec -import scala.collection.mutable import scala.concurrent.duration.Duration import scala.concurrent.{ExecutionContext, Future} -class InternalRowSerializer(schema: StructType, writeSchemaConverter: WriteSchemaConverter) extends WireRowSerializer[InternalRow] with LogLazy { - - private val log = 
LoggerFactory.getLogger(getClass) - - private val tableSchema = writeSchemaConverter.tableSchema(schema, Unordered) - - override def getSchema: TableSchema = tableSchema - - private def getColumnType(i: Int): ColumnValueType = { - def isComposite(t: TiType): Boolean = t.isList || t.isDict || t.isStruct || t.isTuple || t.isVariant - - if (writeSchemaConverter.typeV3Format) { - val column = tableSchema.getColumnSchema(i) - val t = column.getTypeV3 - if (t.isOptional) { - val inner = t.asOptional().getItem - if (inner.isOptional || isComposite(inner)) { - ColumnValueType.COMPOSITE - } else { - column.getType - } - } else if (isComposite(t)) { - ColumnValueType.COMPOSITE - } else { - column.getType - } - } else { - tableSchema.getColumnType(i) - } - } - - override def serializeRow(row: InternalRow, - writeable: WireProtocolWriteable, - keyFieldsOnly: Boolean, - aggregate: Boolean, - idMapping: Array[Int]): Unit = { - writeable.writeValueCount(row.numFields) - for { - i <- 0 until row.numFields - } { - if (row.isNullAt(i)) { - writeable.writeValueHeader(valueId(i, idMapping), ColumnValueType.NULL, aggregate, 0) - } else { - val sparkField = schema(i) - val ytFieldHint = if (writeSchemaConverter.typeV3Format) Some(tableSchema.getColumnSchema(i).getTypeV3) else None - sparkField.dataType match { - case BinaryType => - writeBytes(writeable, idMapping, aggregate, i, row.getBinary(i), getColumnType) - case StringType => - writeBytes(writeable, idMapping, aggregate, i, row.getUTF8String(i).getBytes, getColumnType) - case d: DecimalType => - val value = row.getDecimal(i, d.precision, d.scale) - if (writeSchemaConverter.typeV3Format) { - val binary = decimalToBinary(ytFieldHint, d, value) - writeBytes(writeable, idMapping, aggregate, i, binary, getColumnType) - } else { - val targetColumnType = getColumnType(i) - targetColumnType match { - case ColumnValueType.INT64 | ColumnValueType.UINT64 | ColumnValueType.DOUBLE | ColumnValueType.STRING => - writeHeader(writeable, idMapping, 
aggregate, i, 0, _ => targetColumnType) - targetColumnType match { - case ColumnValueType.INT64 | ColumnValueType.UINT64 => - writeable.onInteger(value.toLong) - case ColumnValueType.DOUBLE => - writeable.onDouble(value.toDouble) - case ColumnValueType.STRING => - writeable.onBytes(value.toString().getBytes) - } - case _ => - throw new IllegalArgumentException("Writing decimal type without enabled type_v3 is not supported") - } - } - case t@(ArrayType(_, _) | StructType(_) | MapType(_, _, _)) => - val skipNulls = sparkField.metadata.contains("skipNulls") && sparkField.metadata.getBoolean("skipNulls") - writeBytes(writeable, idMapping, aggregate, i, - YsonEncoder.encode(row.get(i, sparkField.dataType), t, skipNulls, writeSchemaConverter.typeV3Format, ytFieldHint), - getColumnType) - case otherType => - val isExtendedType = YTsaurusTypes - .instance - .wireWriteRow(otherType, row, writeable, aggregate, idMapping, i, getColumnType) - if (!isExtendedType) { - writeHeader(writeable, idMapping, aggregate, i, 0, getColumnType) - otherType match { - case ByteType => writeable.onInteger(row.getByte(i)) - case ShortType => writeable.onInteger(row.getShort(i)) - case IntegerType => writeable.onInteger(row.getInt(i)) - case LongType => writeable.onInteger(row.getLong(i)) - case BooleanType => writeable.onBoolean(row.getBoolean(i)) - case FloatType => writeable.onDouble(row.getFloat(i)) - case DoubleType => writeable.onDouble(row.getDouble(i)) - case DateType => writeable.onInteger(row.getLong(i)) - case _: DatetimeType => writeable.onInteger(row.getLong(i)) - case TimestampType => writeable.onInteger(row.getLong(i)) - case _: Date32Type => writeable.onInteger(row.getInt(i)) - case _: Datetime64Type => writeable.onInteger(row.getLong(i)) - case _: Timestamp64Type => writeable.onInteger(row.getLong(i)) - case _: Interval64Type => writeable.onInteger(row.getLong(i)) - } - } - } - } - } - } -} - object InternalRowSerializer { - private val deserializers: 
ThreadLocal[mutable.Map[StructType, InternalRowSerializer]] = ThreadLocal.withInitial(() => mutable.ListMap.empty) private val context = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(4)) - def getOrCreate(schema: StructType, - schemaHint: Map[String, YtLogicalType], - filters: Array[Filter] = Array.empty, - typeV3Format: Boolean = false): InternalRowSerializer = { - deserializers.get().getOrElseUpdate(schema, new InternalRowSerializer(schema, new WriteSchemaConverter(schemaHint, typeV3Format))) - } - final def writeRows(writer: TableWriter[InternalRow], rows: java.util.ArrayList[InternalRow], timeout: Duration): Future[Unit] = { @@ -160,23 +32,6 @@ object InternalRowSerializer { writeRowsRecursive(writer, rows, timeout) } } - - private def valueId(id: Int, idMapping: Array[Int]): Int = { - if (idMapping != null) { - idMapping(id) - } else id - } - - def writeHeader(writeable: WireProtocolWriteable, idMapping: Array[Int], aggregate: Boolean, - i: Int, length: Int, getColumnType: Int => ColumnValueType): Unit = { - writeable.writeValueHeader(valueId(i, idMapping), getColumnType(i), aggregate, length) - } - - def writeBytes(writeable: WireProtocolWriteable, idMapping: Array[Int], aggregate: Boolean, - i: Int, bytes: Array[Byte], getColumnType: Int => ColumnValueType): Unit = { - writeHeader(writeable, idMapping, aggregate, i, bytes.length, getColumnType) - writeable.onBytes(bytes) - } } diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala index 23a0b85e..e8f992da 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala @@ -58,9 +58,10 @@ class WriteSchemaConverter( case BooleanType => YtLogicalType.Boolean case d: DecimalType => val dT = if (d.precision > 35) applyYtLimitToSparkDecimal(d) else d 
- YtLogicalType.Decimal(dT.precision, dT.scale) + YtLogicalType.Decimal(dT.precision, dT.scale, d) case aT: ArrayType => YtLogicalType.Array(wrapSparkAttributes(ytLogicalTypeV3(aT.elementType), aT.containsNull)) + case _: StructType if hint != null => hint case sT: StructType if isTuple(sT) => YtLogicalType.Tuple { sT.fields.map(tF => diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index 84d6ed0e..c2668dff 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -1,21 +1,50 @@ package tech.ytsaurus.spyt.serializers +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.spyt.types._ import org.apache.spark.sql.types._ +import tech.ytsaurus.client.{InternalRowYTGetters, YTGetters} import tech.ytsaurus.core.tables.ColumnValueType import tech.ytsaurus.spyt.serializers.SchemaConverter.MetadataFields +import tech.ytsaurus.spyt.serializers.YsonRowConverter.{isNull, serializeValue} +import tech.ytsaurus.spyt.serializers.YtLogicalType.Binary.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Boolean.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Date.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Datetime.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Double.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Float.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Int16.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Int32.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Int64.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Int8.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Interval.tiType +import 
tech.ytsaurus.spyt.serializers.YtLogicalType.Null.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.String.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Timestamp.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Uint32.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Uint64.tiType +import tech.ytsaurus.spyt.serializers.YtLogicalType.Utf8.tiType import tech.ytsaurus.typeinfo.StructType.Member import tech.ytsaurus.typeinfo.{TiType, TypeName} +import tech.ytsaurus.yson.YsonConsumer +import tech.ytsaurus.ysontree.YTreeBinarySerializer +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer import scala.annotation.tailrec sealed trait SparkType { def topLevel: DataType + def innerLevel: DataType } + case class SingleSparkType(topLevel: DataType) extends SparkType { override def innerLevel: DataType = topLevel } + case class TopInnerSparkTypes(topLevel: DataType, innerLevel: DataType) extends SparkType sealed trait YtLogicalType { @@ -48,10 +77,15 @@ sealed trait YtLogicalType { def alias: YtLogicalTypeAlias def arrowSupported: Boolean = true + + def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList + + def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct } sealed trait YtLogicalTypeAlias { def name: String = aliases.head + def aliases: Seq[String] } @@ -69,6 +103,7 @@ sealed abstract class AtomicYtLogicalType(name: String, this(name, value, columnValueType, tiType, SingleSparkType(sparkType), otherAliases, arrowSupported) override def alias: YtLogicalTypeAlias = this + override def aliases: Seq[String] = name +: otherAliases } @@ -84,60 +119,591 @@ sealed abstract class CompositeYtLogicalTypeAlias(name: String, } object YtLogicalType { + import tech.ytsaurus.spyt.types.YTsaurusTypes.instance.sparkTypeFor - case object Null extends AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) + case object Null extends 
AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToNull { + override def getTiType: TiType = tiType + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = + new ytGetter.FromStructToNull { + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } + } + + case object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), LongType) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getLong(i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getLong(ordinal)) + } + } + + case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + 
override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + + case object Float extends AtomicYtLogicalType( + "float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), + TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false, + ) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToFloat { + override def getFloat(list: ArrayData, i: Int): Float = list.getFloat(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToFloat { + override def getFloat(struct: InternalRow): Float = struct.getFloat(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(struct)) + } + } + + case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToDouble { + override def getDouble(list: ArrayData, i: Int): Double = list.getDouble(i) - case 
object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), LongType) - case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) - case object Float extends AtomicYtLogicalType("float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), - TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false) - case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) - case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDouble { + override def getDouble(struct: InternalRow): Double = struct.getDouble(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(struct)) + } + } + + case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToBoolean { + override def getBoolean(list: ArrayData, i: Int): Boolean = list.getBoolean(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onBoolean(list.getBoolean(i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): 
ytGetter.FromStruct = new ytGetter.FromStructToBoolean { + override def getBoolean(struct: InternalRow): Boolean = struct.getBoolean(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onBoolean(struct.getBoolean(ordinal)) + } + } + + private def getBytes(byteBuffer: ByteBuffer): scala.Array[Byte] = { + val bytes = new scala.Array[Byte](byteBuffer.remaining()) + byteBuffer.get(bytes) + bytes + } + + case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToString { + override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(list, i)) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(struct)) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + } - case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) case object Binary extends AtomicYtLogicalType("binary", 0x10, ColumnValueType.STRING, TiType.string(), BinaryType) { override def getName(isColumnType: Boolean): String = columnValueType.getName override def getNameV3(inner: Boolean): String = { if (inner) alias.name 
else "string" } + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToString { + override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val byteBuffer = getString(list, i) + val bytes = new scala.Array[Byte](byteBuffer.remaining()) + byteBuffer.get(bytes) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val byteBuffer = getString(struct) + val bytes = new scala.Array[Byte](byteBuffer.remaining()) + byteBuffer.get(bytes) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } } + case object Any extends AtomicYtLogicalType("any", 0x11, ColumnValueType.ANY, TiType.yson(), sparkTypeFor(TiType.yson()), Seq("yson")) { override def nullable: Boolean = true + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToYson { + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + YTreeBinarySerializer.deserialize(new ByteArrayInputStream(list.getBinary(i)), ysonConsumer) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: 
YsonConsumer): Unit =
+        YTreeBinarySerializer.deserialize(new ByteArrayInputStream(struct.getBinary(ordinal)), ysonConsumer)
+    }
+  }
+
+  // YT int8 backed by Spark ByteType. List getters read element `i`, struct getters read field `ordinal`;
+  // YSON output is a signed integer.
+  case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToByte {
+      override def getByte(list: ArrayData, i: Int): Byte = list.getByte(i)
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getByte(i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte {
+      override def getByte(struct: InternalRow): Byte = struct.getByte(ordinal)
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getByte(ordinal))
+    }
+  }
+
+  // YT uint8 backed by Spark ShortType (the next wider signed type).
+  // getByte narrows the Short with toByte, i.e. keeps the low 8 bits; values 128..255 therefore come out
+  // as negative Bytes (presumably reinterpreted as unsigned by the consumer of getByte -- confirm).
+  case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToByte {
+      override def getByte(list: ArrayData, i: Int): Byte = list.getShort(i).toByte
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      // Emit the stored Short, not getByte(): a negative Byte would be sign-extended when widened to the
+      // integer argument of onUnsignedInteger, turning e.g. 200 into a huge unsigned value.
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(list.getShort(i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte {
+      override def getByte(struct: InternalRow): Byte = struct.getShort(ordinal).toByte
+
+      override def getTiType: TiType = tiType
+
+      // Same as the list variant: widen the original Short to avoid sign extension of the narrowed Byte.
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(struct.getShort(ordinal))
+    }
   }
-  case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType)
-  case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType)
+  // YT int16 backed by Spark ShortType; YSON output is a signed integer.
+  case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToShort {
+      override def getShort(list: ArrayData, i: Int): Short = list.getShort(i)
-  case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType)
-  case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType)
+      override def getSize(list: ArrayData): Int = list.numElements()
-  case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType)
-  case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType)
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getShort(i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort {
+      override def getShort(struct: InternalRow): Short = struct.getShort(ordinal)
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getShort(ordinal))
+    }
+  }
+
+  // YT uint16 backed by Spark IntegerType.
+  // getShort narrows the Int with toShort (low 16 bits); values 32768..65535 come out as negative Shorts.
+  case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToShort {
+      override def getShort(list: ArrayData, i: Int): Short = list.getInt(i).toShort
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      // Emit the stored Int, not getShort(): the narrowed Short would be sign-extended on widening.
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(list.getInt(i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort {
+      override def getShort(struct: InternalRow): Short = struct.getInt(ordinal).toShort
+
+      override def getTiType: TiType = tiType
+
+      // Same as the list variant: widen the original Int to avoid sign extension of the narrowed Short.
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(struct.getInt(ordinal))
+    }
+  }
+
+  // YT int32 backed by Spark IntegerType; YSON output is a signed integer.
+  case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt {
+      override def getInt(list: ArrayData, i: Int): Int = list.getInt(i)
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getInt(i))
+    }
-  case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType)
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt {
+      override def getInt(struct: InternalRow): Int = struct.getInt(ordinal)
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getInt(ordinal))
+    }
+  }
+
+  // YT uint32 backed by Spark LongType.
+  // getInt narrows the Long with toInt (low 32 bits); values >= 2^31 come out as negative Ints.
+  case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt {
+      override def getInt(list: ArrayData, i: Int): Int = list.getLong(i).toInt
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      // Emit the stored Long, not getInt(): the narrowed Int would be sign-extended on widening.
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(list.getLong(i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt {
+      override def getInt(struct: InternalRow): Int = struct.getLong(ordinal).toInt
+
+      override def getTiType: TiType = tiType
+
+      // Same as the list variant: pass the original Long to avoid sign extension of the narrowed Int.
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(struct.getLong(ordinal))
+    }
+  }
+
+  // YT utf8 backed by Spark StringType. getString exposes the UTF8String's ByteBuffer;
+  // getYson copies those bytes out with getBytes and emits them as a YSON string.
+  case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToString {
+      override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = {
+        val bytes = getBytes(list.getUTF8String(i).getByteBuffer)
+        ysonConsumer.onString(bytes, 0, bytes.length)
+      }
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString {
+      override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = {
+        val bytes = getBytes(struct.getUTF8String(ordinal).getByteBuffer)
+        ysonConsumer.onString(bytes, 0, bytes.length)
+      }
+    }
+  }
   // Unsupported types are listed here: yt/yt/client/arrow/arrow_row_stream_encoder.cpp
-  case object Date extends AtomicYtLogicalType("date", 0x1008,
ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) - case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) - case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) - case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) + case object Date extends AtomicYtLogicalType("date", 0x1008, ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(struct)) + } + } - case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) //? 
+ case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) - case object Datetime64 extends AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) - case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) - case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) + override def getSize(list: ArrayData): Int = list.numElements() + override def getTiType: TiType = tiType - case class Decimal(precision: Int, scale: Int) extends CompositeYtLogicalType { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + + case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { 
+ override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + + case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(struct)) + } + } + + case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) { + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToNull { + override def getTiType: TiType = tiType + + override def 
getSize(list: ArrayData): Int = list.numElements()
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity()
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct =
+      new ytGetter.FromStructToNull {
+        override def getTiType: TiType = tiType
+
+        override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity()
+      }
+  }
+
+  // YT date32 backed by Date32Type and read as an Int.
+  // The type is declared with ColumnValueType.INT64 (signed), so YSON output must use onInteger:
+  // onUnsignedInteger would emit a negative value as a huge unsigned number.
+  case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt {
+      override def getInt(list: ArrayData, i: Int): Int = list.getInt(i)
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit =
+        ysonConsumer.onInteger(getInt(list, i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt {
+      override def getInt(struct: InternalRow): Int = struct.getInt(ordinal)
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit =
+        ysonConsumer.onInteger(getInt(struct))
+    }
+  }
+
+  // YT datetime64 backed by Datetime64Type and read as a Long.
+  // Declared with ColumnValueType.INT64 (signed), hence onInteger in getYson.
+  case object Datetime64 extends AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong {
+      override def getLong(list: ArrayData, i: Int): Long = list.getLong(i)
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit =
+        ysonConsumer.onInteger(getLong(list, i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong {
+      override def getLong(struct: InternalRow): Long = struct.getLong(ordinal)
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit =
+        ysonConsumer.onInteger(getLong(struct))
+    }
+  }
+
+  // YT timestamp64 backed by Timestamp64Type and read as a Long.
+  // Declared with ColumnValueType.INT64 (signed), hence onInteger in getYson.
+  case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong {
+      override def getLong(list: ArrayData, i: Int): Long = list.getLong(i)
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit =
+        ysonConsumer.onInteger(getLong(list, i))
+    }
+
+    override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong {
+      override def getLong(struct: InternalRow): Long = struct.getLong(ordinal)
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit =
+        ysonConsumer.onInteger(getLong(struct))
+    }
+  }
+
+  case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) {
+    override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong {
+      override def getLong(list: ArrayData, i: Int): Long = list.getLong(i)
+
+      override def getSize(list: ArrayData): Int = list.numElements()
+
+      override def getTiType: TiType = tiType
+
+      override def getYson(list: ArrayData, i: Int, ysonConsumer:
YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(struct)) + } + } + + case class Decimal(precision: Int, scale: Int, decimalType: DecimalType) extends CompositeYtLogicalType { override def sparkType: SparkType = SingleSparkType(DecimalType(precision, scale)) override def alias: CompositeYtLogicalTypeAlias = Decimal override def tiType: TiType = TiType.decimal(precision, scale) + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToBigDecimal { + override def getBigDecimal(list: ArrayData, i: Int): java.math.BigDecimal = + list.getDecimal(i, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBigDecimal(list, i).unscaledValue().toByteArray + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBigDecimal { + override def getBigDecimal(struct: InternalRow): java.math.BigDecimal = + struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBigDecimal(struct).unscaledValue().toByteArray + ysonConsumer.onString(bytes, 0, bytes.length) + } + } } case object Decimal extends 
CompositeYtLogicalTypeAlias("decimal") @@ -158,6 +724,54 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Optional override def arrowSupported: Boolean = inner.arrowSupported + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToOptional { + private val notEmptyGetter = inner.ytGettersFromList(ytGetter) + + override def getNotEmptyGetter: ytGetter.FromList = notEmptyGetter + + override def isEmpty(list: ArrayData, i: Int): Boolean = list.isNullAt(i) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + if (list.isNullAt(i)) { + ysonConsumer.onEntity() + } else if (inner.isInstanceOf[Optional]) { + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + notEmptyGetter.getYson(list, i, ysonConsumer) + ysonConsumer.onEndList() + } else { + notEmptyGetter.getYson(list, i, ysonConsumer) + } + } + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToOptional { + private val notEmptyGetter = inner.ytGettersFromStruct(ytGetter, ordinal) + + override def getNotEmptyGetter: ytGetter.FromStruct = notEmptyGetter + + override def isEmpty(struct: InternalRow): Boolean = struct.isNullAt(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + if (struct.isNullAt(ordinal)) { + ysonConsumer.onEntity() + } else if (inner.isInstanceOf[Optional]) { + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + notEmptyGetter.getYson(struct, ysonConsumer) + ysonConsumer.onEndList() + } else { + notEmptyGetter.getYson(struct, ysonConsumer) + } + } + } } case object Optional extends CompositeYtLogicalTypeAlias(TypeName.Optional.getWireName) @@ -172,19 +786,129 @@ object YtLogicalType { 
MapType(dictKey.sparkType.innerLevel, dictValue.sparkType.innerLevel, dictValue.nullable) ) + private def newGetter(ytGetter: InternalRowYTGetters): ytGetter.FromDict = new ytGetter.FromDict { + private val keyGetter = dictKey.ytGettersFromList(ytGetter) + private val valueGetter = dictValue.ytGettersFromList(ytGetter) + + override def getKeyGetter: ytGetter.FromList = keyGetter + + override def getValueGetter: ytGetter.FromList = valueGetter + + override def getSize(dict: MapData): Int = dict.numElements() + + override def getKeys(dict: MapData): ArrayData = dict.keyArray() + + override def getValues(dict: MapData): ArrayData = dict.valueArray() + + override def getTiType: TiType = tiType + } + + def newYsonSerializer(getter: InternalRowYTGetters#FromDict): (MapData, YsonConsumer) => Unit = { + val keyGetter = getter.getKeyGetter + val valueGetter = getter.getValueGetter + (dict, ysonConsumer) => { + ysonConsumer.onBeginList() + val keys = dict.keyArray() + val values = dict.valueArray() + for (i <- 0 until dict.numElements()) { + ysonConsumer.onListItem() + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + keyGetter.getYson(keys, i, ysonConsumer) + ysonConsumer.onListItem() + valueGetter.getYson(values, i, ysonConsumer) + ysonConsumer.onEndList() + } + ysonConsumer.onEndList() + } + } + override def tiType: TiType = TiType.dict(dictKey.tiType, dictValue.tiType) override def alias: CompositeYtLogicalTypeAlias = Dict + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToDict { + private val getter = newGetter(ytGetter) + private val ysonSerializer = newYsonSerializer(getter) + + override def getGetter(): ytGetter.FromDict = getter + + override def getTiType: TiType = tiType + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getDict(list: ArrayData, i: Int): MapData = list.getMap(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + 
ysonSerializer(list.getMap(i), ysonConsumer) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDict { + private val getter = newGetter(ytGetter) + private val ysonSerializer = newYsonSerializer(getter) + + override def getGetter(): ytGetter.FromDict = getter + + override def getDict(struct: InternalRow): MapData = struct.getMap(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonSerializer(struct.getMap(ordinal), ysonConsumer) + } } case object Dict extends CompositeYtLogicalTypeAlias(TypeName.Dict.getWireName) case class Array(inner: YtLogicalType) extends CompositeYtLogicalType { - override def sparkType: SparkType = SingleSparkType( ArrayType(inner.sparkType.innerLevel, inner.nullable)) + override def sparkType: SparkType = SingleSparkType(ArrayType(inner.sparkType.innerLevel, inner.nullable)) override def tiType: TiType = TiType.list(inner.tiType) override def alias: CompositeYtLogicalTypeAlias = Array + + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToList { + val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getElementGetter: ytGetter.FromList = elementGetter + + override def getList(list: ArrayData, i: Int): ArrayData = list.getArray(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val value = list.getArray(i) + ysonConsumer.onBeginList() + for (j <- 0 until value.numElements()) { + ysonConsumer.onListItem() + elementGetter.getYson(value, j, ysonConsumer) + } + ysonConsumer.onEndList() + } + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToList { + val 
elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter) + + override def getElementGetter: ytGetter.FromList = elementGetter + + override def getList(struct: InternalRow): ArrayData = struct.getArray(ordinal) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val value = struct.getArray(ordinal) + ysonConsumer.onBeginList() + for (j <- 0 until value.numElements()) { + ysonConsumer.onListItem() + elementGetter.getYson(value, j, ysonConsumer) + } + ysonConsumer.onEndList() + } + } } case object Array extends CompositeYtLogicalTypeAlias(TypeName.List.getWireName) @@ -194,25 +918,121 @@ object YtLogicalType { .map { case (name, ytType, meta) => getStructField(name, ytType, meta, topLevel = false) })) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.struct( - fields.map{ case (name, ytType, _) => new Member(name, ytType.tiType)}.asJava + fields.map { case (name, ytType, _) => new Member(name, ytType.tiType) }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Struct + + def newMembersGetters(ytGetter: InternalRowYTGetters): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + fields.zipWithIndex.map { case (field, i) => + java.util.Map.entry(field._1, field._2.ytGettersFromStruct(ytGetter, i)) + }.asJava + + def yson(ytGetter: InternalRowYTGetters)( + membersGetters: java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]], + internalRow: InternalRow, ysonConsumer: YsonConsumer, + ): Unit = { + ysonConsumer.onBeginList() + for (i <- 0 until membersGetters.size()) { + ysonConsumer.onListItem() + membersGetters.get(i).getValue.getYson(internalRow, ysonConsumer) + } + ysonConsumer.onEndList() + } + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToStruct { + private val membersGetters = newMembersGetters(ytGetter) + + override def getMembersGetters(): 
java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + membersGetters + + override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, fields.size) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + yson(ytGetter)(membersGetters, list.getStruct(i, membersGetters.size()), ysonConsumer) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { + private val membersGetters = newMembersGetters(ytGetter) + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + membersGetters + + override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, fields.size) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + yson(ytGetter)(membersGetters, struct.getStruct(ordinal, membersGetters.size()), ysonConsumer) + } } case object Struct extends CompositeYtLogicalTypeAlias(TypeName.Struct.getWireName) case class Tuple(elements: Seq[(YtLogicalType, Metadata)]) extends CompositeYtLogicalType { + private val entries = elements.zipWithIndex.map { case ((ytType, _), index) => (s"_${1 + index}", ytType) } override def sparkType: SparkType = SingleSparkType(StructType(elements.zipWithIndex .map { case ((ytType, meta), index) => getStructField(s"_${1 + index}", ytType, meta, topLevel = false) })) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.tuple( - elements.map { case (e, _) => e.tiType } .asJava + elements.map { case (e, _) => e.tiType }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Tuple + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToStruct { + private val membersGetters 
= entries.zipWithIndex.map { case ((name, logicalType), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, i)) + }.asJava + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters + + override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, elements.size) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val value = list.getStruct(i, membersGetters.size()) + ysonConsumer.onBeginList() + membersGetters.forEach { getter => + ysonConsumer.onListItem() + getter.getValue.getYson(value, ysonConsumer) + } + ysonConsumer.onEndList() + } + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { + private val membersGetters = entries.zipWithIndex.map { case ((name, logicalType), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, i)) + }.asJava + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters + + override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, elements.size) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val value = struct.getStruct(ordinal, membersGetters.size()) + ysonConsumer.onBeginList() + membersGetters.forEach { getter => + ysonConsumer.onListItem() + getter.getValue.getYson(value, ysonConsumer) + } + ysonConsumer.onEndList() + } + } } case object Tuple extends CompositeYtLogicalTypeAlias(TypeName.Tuple.getWireName) @@ -223,34 +1043,103 @@ object YtLogicalType { override def tiType: TiType = TiType.tagged(inner.tiType, tag) override def alias: CompositeYtLogicalTypeAlias = Tagged + + override def 
ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = inner.ytGettersFromList(ytGetter) + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = inner.ytGettersFromStruct(ytGetter, ordinal) } case object Tagged extends CompositeYtLogicalTypeAlias(TypeName.Tagged.getWireName) + private class VariantGetter(fields: Seq[YtLogicalType], ytGetter: InternalRowYTGetters) { + private val getters = fields.zipWithIndex.map { case (field, i) => field.ytGettersFromStruct(ytGetter, i) } + + def get(row: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val notNulls = (0 until row.numFields).filter(!row.isNullAt(_)) + if (notNulls.isEmpty) { + throw new IllegalArgumentException("All elements in variant is null") + } else if (notNulls.size > 1) { + throw new IllegalArgumentException("Not null element must be single") + } else { + val index = notNulls.head + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + ysonConsumer.onInteger(index) + ysonConsumer.onListItem() + getters(index).getYson(row, ysonConsumer) + ysonConsumer.onEndList() + } + } + } + case class VariantOverStruct(fields: Seq[(String, YtLogicalType, Metadata)]) extends CompositeYtLogicalType { override def sparkType: SparkType = SingleSparkType(StructType(fields.map { case (name, ytType, meta) => - getStructField(s"_v$name", ytType, meta, forcedNullability = Some(true), topLevel = false) })) + getStructField(s"_v$name", ytType, meta, forcedNullability = Some(true), topLevel = false) + })) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.variantOverStruct( - fields.map{ case (name, ytType, _) => new Member(name, ytType.tiType)}.asJava + fields.map { case (name, ytType, _) => new Member(name, ytType.tiType) }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Variant + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToYson { + val getter = new 
VariantGetter(fields.map(_._2), ytGetter) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + getter.get(list.getStruct(i, fields.size), ysonConsumer) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + val getter = new VariantGetter(fields.map(_._2), ytGetter) + + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) + } } case class VariantOverTuple(fields: Seq[(YtLogicalType, Metadata)]) extends CompositeYtLogicalType { override def sparkType: SparkType = SingleSparkType( StructType(fields.zipWithIndex.map { case ((ytType, meta), index) => - getStructField(s"_v_${1 + index}", ytType, meta, forcedNullability = Some(true), topLevel = false) }) + getStructField(s"_v_${1 + index}", ytType, meta, forcedNullability = Some(true), topLevel = false) + }) ) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.variantOverTuple( fields.map { case (e, _) => e.tiType }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Variant + + override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToYson { + val getter = new VariantGetter(fields.map(_._1), ytGetter) + + override def getSize(list: ArrayData): Int = list.numElements() + + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + getter.get(list.getStruct(i, fields.size), ysonConsumer) + } + + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + val getter = new VariantGetter(fields.map(_._1), ytGetter) + + override def 
getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) + } } case object Variant extends CompositeYtLogicalTypeAlias(TypeName.Variant.getWireName) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala index c2b3599f..6a184c9f 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala @@ -1,6 +1,6 @@ package tech.ytsaurus.spyt.serializers -import org.apache.spark.sql.types.Metadata +import org.apache.spark.sql.types.{DecimalType, Metadata} import tech.ytsaurus.ysontree.{YTree, YTreeBuilder, YTreeMapNode, YTreeNode, YTreeStringNode} import scala.collection.JavaConverters.asScalaBufferConverter @@ -105,10 +105,8 @@ object YtLogicalTypeSerializer { case YtLogicalType.Optional => YtLogicalType.Optional(deserializeTypeV3(m.getOrThrow("item"))) case YtLogicalType.Decimal => - YtLogicalType.Decimal( - m.getOrThrow("precision").intValue(), - m.getOrThrow("scale").intValue() - ) + val decimalType = DecimalType(m.getOrThrow("precision").intValue(), m.getOrThrow("scale").intValue()) + YtLogicalType.Decimal(decimalType.precision, decimalType.scale, decimalType) case YtLogicalType.Dict => YtLogicalType.Dict( deserializeTypeV3(m.getOrThrow("key")), From 161536efa868be3db0c1b397603ee943593e86fc Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Tue, 19 Nov 2024 15:48:25 +0100 Subject: [PATCH 02/12] PR comments: clean up imports arrow_write_enabled config --- .../ytsaurus/spyt/format/YtOutputWriter.scala | 19 ++- .../format/conf/YtTableSparkSettings.scala | 2 + .../serializers/InternalRowSerializer.scala | 144 ++++++++++++++++++ .../spyt/serializers/YtLogicalType.scala | 21 +-- 4 files changed, 
159 insertions(+), 27 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala index 3ad532d0..00321da9 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala @@ -9,7 +9,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory -import tech.ytsaurus.client.request.{TransactionalOptions, WriteTable} +import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable} import tech.ytsaurus.client.{ArrowWriteSerializationContext, CompoundClient, InternalRowYTGetters, TableWriter} import tech.ytsaurus.core.GUID import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration @@ -156,12 +156,17 @@ class YtOutputWriter(richPath: YPathEnriched, val internalRowGetters = new InternalRowYTGetters() val request = WriteTable.builder[InternalRow]() .setPath(appendPath) - .setSerializationContext(new ArrowWriteSerializationContext[InternalRow, ArrayData, MapData, InternalRowYTGetters]( - WriteSchemaConverter(options).ytLogicalTypeStruct(schema).fields.zipWithIndex.map { - case ((name, ytLogicalType, _), i) => - util.Map.entry(name, ytLogicalType.ytGettersFromStruct(internalRowGetters, i)) - }.asJava - )) + .setSerializationContext( + if (options.ytConf(ArrowWriteEnabled)) + new ArrowWriteSerializationContext[InternalRow, ArrayData, MapData, InternalRowYTGetters]( + WriteSchemaConverter(options).ytLogicalTypeStruct(schema).fields.zipWithIndex.map { + case ((name, ytLogicalType, _), i) => + util.Map.entry(name, ytLogicalType.ytGettersFromStruct(internalRowGetters, i)) + }.asJava + ) + else + new WriteSerializationContext(new InternalRowSerializer(schema, 
WriteSchemaConverter(options))) + ) .setTransactionalOptions(new TransactionalOptions(GUID.valueOf(transactionGuid))) .setNeedRetries(false) .build() diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala index 84eb69a0..25879ee1 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala @@ -65,6 +65,8 @@ object YtTableSparkSettings { case object ArrowEnabled extends ConfigEntry[Boolean]("arrow_enabled", Some(true)) + case object ArrowWriteEnabled extends ConfigEntry[Boolean]("arrow_write_enabled", Some(true)) + case object KeyPartitioned extends ConfigEntry[Boolean]("key_partitioned") case object Dynamic extends ConfigEntry[Boolean]("dynamic") diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala index 1da39bb7..2e833dab 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala @@ -3,16 +3,143 @@ package tech.ytsaurus.spyt.serializers import org.apache.spark.metrics.yt.YtMetricsRegister import org.apache.spark.metrics.yt.YtMetricsRegister.ytMetricsSource._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.spyt.types._ +import org.apache.spark.sql.types._ +import org.slf4j.LoggerFactory import tech.ytsaurus.client.TableWriter +import tech.ytsaurus.client.rows.{WireProtocolWriteable, WireRowSerializer} +import tech.ytsaurus.core.tables.{ColumnValueType, TableSchema} +import tech.ytsaurus.spyt.serialization.YsonEncoder +import tech.ytsaurus.spyt.serializers.InternalRowSerializer._ 
+import tech.ytsaurus.spyt.serializers.SchemaConverter.{Unordered, decimalToBinary} +import tech.ytsaurus.spyt.types.YTsaurusTypes +import tech.ytsaurus.spyt.wrapper.LogLazy +import tech.ytsaurus.typeinfo.TiType import java.util.concurrent.{Executors, TimeUnit} import scala.annotation.tailrec +import scala.collection.mutable import scala.concurrent.duration.Duration import scala.concurrent.{ExecutionContext, Future} +class InternalRowSerializer(schema: StructType, writeSchemaConverter: WriteSchemaConverter) extends WireRowSerializer[InternalRow] with LogLazy { + + private val log = LoggerFactory.getLogger(getClass) + + private val tableSchema = writeSchemaConverter.tableSchema(schema, Unordered) + + override def getSchema: TableSchema = tableSchema + + private def getColumnType(i: Int): ColumnValueType = { + def isComposite(t: TiType): Boolean = t.isList || t.isDict || t.isStruct || t.isTuple || t.isVariant + + if (writeSchemaConverter.typeV3Format) { + val column = tableSchema.getColumnSchema(i) + val t = column.getTypeV3 + if (t.isOptional) { + val inner = t.asOptional().getItem + if (inner.isOptional || isComposite(inner)) { + ColumnValueType.COMPOSITE + } else { + column.getType + } + } else if (isComposite(t)) { + ColumnValueType.COMPOSITE + } else { + column.getType + } + } else { + tableSchema.getColumnType(i) + } + } + + override def serializeRow(row: InternalRow, + writeable: WireProtocolWriteable, + keyFieldsOnly: Boolean, + aggregate: Boolean, + idMapping: Array[Int]): Unit = { + writeable.writeValueCount(row.numFields) + for { + i <- 0 until row.numFields + } { + if (row.isNullAt(i)) { + writeable.writeValueHeader(valueId(i, idMapping), ColumnValueType.NULL, aggregate, 0) + } else { + val sparkField = schema(i) + val ytFieldHint = if (writeSchemaConverter.typeV3Format) Some(tableSchema.getColumnSchema(i).getTypeV3) else None + sparkField.dataType match { + case BinaryType => + writeBytes(writeable, idMapping, aggregate, i, row.getBinary(i), 
getColumnType) + case StringType => + writeBytes(writeable, idMapping, aggregate, i, row.getUTF8String(i).getBytes, getColumnType) + case d: DecimalType => + val value = row.getDecimal(i, d.precision, d.scale) + if (writeSchemaConverter.typeV3Format) { + val binary = decimalToBinary(ytFieldHint, d, value) + writeBytes(writeable, idMapping, aggregate, i, binary, getColumnType) + } else { + val targetColumnType = getColumnType(i) + targetColumnType match { + case ColumnValueType.INT64 | ColumnValueType.UINT64 | ColumnValueType.DOUBLE | ColumnValueType.STRING => + writeHeader(writeable, idMapping, aggregate, i, 0, _ => targetColumnType) + targetColumnType match { + case ColumnValueType.INT64 | ColumnValueType.UINT64 => + writeable.onInteger(value.toLong) + case ColumnValueType.DOUBLE => + writeable.onDouble(value.toDouble) + case ColumnValueType.STRING => + writeable.onBytes(value.toString().getBytes) + } + case _ => + throw new IllegalArgumentException("Writing decimal type without enabled type_v3 is not supported") + } + } + case t@(ArrayType(_, _) | StructType(_) | MapType(_, _, _)) => + val skipNulls = sparkField.metadata.contains("skipNulls") && sparkField.metadata.getBoolean("skipNulls") + writeBytes(writeable, idMapping, aggregate, i, + YsonEncoder.encode(row.get(i, sparkField.dataType), t, skipNulls, writeSchemaConverter.typeV3Format, ytFieldHint), + getColumnType) + case otherType => + val isExtendedType = YTsaurusTypes + .instance + .wireWriteRow(otherType, row, writeable, aggregate, idMapping, i, getColumnType) + if (!isExtendedType) { + writeHeader(writeable, idMapping, aggregate, i, 0, getColumnType) + otherType match { + case ByteType => writeable.onInteger(row.getByte(i)) + case ShortType => writeable.onInteger(row.getShort(i)) + case IntegerType => writeable.onInteger(row.getInt(i)) + case LongType => writeable.onInteger(row.getLong(i)) + case BooleanType => writeable.onBoolean(row.getBoolean(i)) + case FloatType => writeable.onDouble(row.getFloat(i)) 
+ case DoubleType => writeable.onDouble(row.getDouble(i)) + case DateType => writeable.onInteger(row.getLong(i)) + case _: DatetimeType => writeable.onInteger(row.getLong(i)) + case TimestampType => writeable.onInteger(row.getLong(i)) + case _: Date32Type => writeable.onInteger(row.getInt(i)) + case _: Datetime64Type => writeable.onInteger(row.getLong(i)) + case _: Timestamp64Type => writeable.onInteger(row.getLong(i)) + case _: Interval64Type => writeable.onInteger(row.getLong(i)) + } + } + } + } + } + } +} + object InternalRowSerializer { + private val deserializers: ThreadLocal[mutable.Map[StructType, InternalRowSerializer]] = ThreadLocal.withInitial(() => mutable.ListMap.empty) private val context = ExecutionContext.fromExecutor(Executors.newFixedThreadPool(4)) + def getOrCreate(schema: StructType, + schemaHint: Map[String, YtLogicalType], + filters: Array[Filter] = Array.empty, + typeV3Format: Boolean = false): InternalRowSerializer = { + deserializers.get().getOrElseUpdate(schema, new InternalRowSerializer(schema, new WriteSchemaConverter(schemaHint, typeV3Format))) + } + final def writeRows(writer: TableWriter[InternalRow], rows: java.util.ArrayList[InternalRow], timeout: Duration): Future[Unit] = { @@ -32,6 +159,23 @@ object InternalRowSerializer { writeRowsRecursive(writer, rows, timeout) } } + + private def valueId(id: Int, idMapping: Array[Int]): Int = { + if (idMapping != null) { + idMapping(id) + } else id + } + + def writeHeader(writeable: WireProtocolWriteable, idMapping: Array[Int], aggregate: Boolean, + i: Int, length: Int, getColumnType: Int => ColumnValueType): Unit = { + writeable.writeValueHeader(valueId(i, idMapping), getColumnType(i), aggregate, length) + } + + def writeBytes(writeable: WireProtocolWriteable, idMapping: Array[Int], aggregate: Boolean, + i: Int, bytes: Array[Byte], getColumnType: Int => ColumnValueType): Unit = { + writeHeader(writeable, idMapping, aggregate, i, bytes.length, getColumnType) + writeable.onBytes(bytes) + } } 
diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index c2668dff..ca6ad416 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -1,31 +1,12 @@ package tech.ytsaurus.spyt.serializers -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.spyt.types._ import org.apache.spark.sql.types._ -import tech.ytsaurus.client.{InternalRowYTGetters, YTGetters} +import tech.ytsaurus.client.InternalRowYTGetters import tech.ytsaurus.core.tables.ColumnValueType import tech.ytsaurus.spyt.serializers.SchemaConverter.MetadataFields -import tech.ytsaurus.spyt.serializers.YsonRowConverter.{isNull, serializeValue} -import tech.ytsaurus.spyt.serializers.YtLogicalType.Binary.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Boolean.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Date.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Datetime.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Double.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Float.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Int16.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Int32.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Int64.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Int8.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Interval.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Null.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.String.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Timestamp.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Uint32.tiType -import 
tech.ytsaurus.spyt.serializers.YtLogicalType.Uint64.tiType -import tech.ytsaurus.spyt.serializers.YtLogicalType.Utf8.tiType import tech.ytsaurus.typeinfo.StructType.Member import tech.ytsaurus.typeinfo.{TiType, TypeName} import tech.ytsaurus.yson.YsonConsumer From 898a56eb1e27c3d93e5a9670910f221492a632dc Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Tue, 19 Nov 2024 16:39:37 +0100 Subject: [PATCH 03/12] fix SchemaConverterTest --- .../scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java index 61ac223b..bf3174a5 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -177,6 +177,7 @@ private Field field(String name, ArrowType arrowType) { private ArrowGetterFromList nonComplexArrowGetter(String name, Getters.FromList getter) { var tiType = getter.getTiType(); switch (tiType.getTypeName()) { + case Utf8: case String: { var stringGetter = (Getters.FromListToString) getter; return new ArrowGetterFromList( @@ -630,6 +631,7 @@ void setFromList(List list, int i) { private ArrowGetterFromStruct nonComplexArrowGetter(String name, Getters.FromStruct getter) { var tiType = getter.getTiType(); switch (tiType.getTypeName()) { + case Utf8: case String: { var stringGetter = (Getters.FromStructToString) getter; return new ArrowGetterFromStruct(field(name, new ArrowType.Binary())) { From 089017c7d8fd2f49a8160da7af03189580c80ffe Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Tue, 19 Nov 2024 17:14:49 +0100 Subject: [PATCH 04/12] fix YtInputSplitTest --- .../scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java | 1 + 1 file changed, 1 insertion(+) diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java 
b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java index bf3174a5..19cb226f 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -1166,6 +1166,7 @@ protected void writeRows(ByteBuf buf, TRowsetDescriptor descriptor, java.util.Li try (var batch = unloader.getRecordBatch()) { MessageSerializer.serialize(writeChannel, batch); } + writeChannel.writeZeros(4); } catch (IOException e) { throw new RuntimeException(e); } From 3552200f40bae8e38ad546c3897f2bdf7d41382e Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Wed, 20 Nov 2024 12:41:58 +0100 Subject: [PATCH 05/12] fix UInt64DecimalTest --- .../ytsaurus/spyt/format/YtOutputWriter.scala | 10 +- .../serializers/WriteSchemaConverter.scala | 5 +- .../spyt/serializers/YtLogicalType.scala | 227 ++++++++++-------- 3 files changed, 136 insertions(+), 106 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala index 00321da9..fb0f42f6 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala @@ -154,15 +154,17 @@ class YtOutputWriter(richPath: YPathEnriched, val appendPath = richPath.withAttr("append", "true").toYPath log.debugLazy(s"Initialize new write: $appendPath, transaction: $transactionGuid") val internalRowGetters = new InternalRowYTGetters() + val writeSchemaConverter = WriteSchemaConverter(options) val request = WriteTable.builder[InternalRow]() .setPath(appendPath) .setSerializationContext( if (options.ytConf(ArrowWriteEnabled)) new ArrowWriteSerializationContext[InternalRow, ArrayData, MapData, InternalRowYTGetters]( - WriteSchemaConverter(options).ytLogicalTypeStruct(schema).fields.zipWithIndex.map { - case ((name, ytLogicalType, _), i) 
=> - util.Map.entry(name, ytLogicalType.ytGettersFromStruct(internalRowGetters, i)) - }.asJava + schema.fields.zipWithIndex.map { case (field, i) => + util.Map.entry(field.name, writeSchemaConverter.ytLogicalTypeV3(field).ytGettersFromStruct( + internalRowGetters, field.dataType, i + )) + }.toSeq.asJava ) else new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options))) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala index e8f992da..4c5813bd 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala @@ -56,9 +56,12 @@ class WriteSchemaConverter( case FloatType => YtLogicalType.Float case DoubleType => YtLogicalType.Double case BooleanType => YtLogicalType.Boolean - case d: DecimalType => + case d: DecimalType => if (hint != null) { + hint + } else { val dT = if (d.precision > 35) applyYtLimitToSparkDecimal(d) else d YtLogicalType.Decimal(dT.precision, dT.scale, d) + } case aT: ArrayType => YtLogicalType.Array(wrapSparkAttributes(ytLogicalTypeV3(aT.elementType), aT.containsNull)) case _: StructType if hint != null => hint diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index ca6ad416..98d2b751 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -3,6 +3,7 @@ package tech.ytsaurus.spyt.serializers import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.spyt.types._ +import org.apache.spark.sql.types import org.apache.spark.sql.types._ import 
tech.ytsaurus.client.InternalRowYTGetters import tech.ytsaurus.core.tables.ColumnValueType @@ -59,9 +60,9 @@ sealed trait YtLogicalType { def arrowSupported: Boolean = true - def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList + def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList - def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct + def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct } sealed trait YtLogicalTypeAlias { @@ -104,7 +105,7 @@ object YtLogicalType { import tech.ytsaurus.spyt.types.YTsaurusTypes.instance.sparkTypeFor case object Null extends AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToNull { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToNull { override def getTiType: TiType = tiType override def getSize(list: ArrayData): Int = list.numElements() @@ -112,7 +113,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToNull { override def getTiType: TiType = tiType @@ -121,7 +122,7 @@ object YtLogicalType { } case object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), LongType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new 
ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -131,7 +132,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getLong(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -141,22 +142,42 @@ object YtLogicalType { } case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = dataType match { + case decimalType: DecimalType => new ytGetter.FromListToLong { + override def getLong(list: ArrayData, i: Int): Long = list.getDecimal(i, decimalType.precision, decimalType.scale).toLong - override def getSize(list: ArrayData): Int = list.numElements() + override def getSize(list: ArrayData): Int = list.numElements() - override def getTiType: TiType = tiType + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + case _ => new ytGetter.FromListToLong { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getSize(list: ArrayData): Int = list.numElements() - override def getYson(list: ArrayData, i: 
Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) + override def getTiType: TiType = tiType + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) + } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = dataType match { + case decimalType: DecimalType => new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toLong - override def getTiType: TiType = tiType + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) + } + case _ => new ytGetter.FromStructToLong { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) + override def getTiType: TiType = tiType + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) + } } } @@ -164,7 +185,7 @@ object YtLogicalType { "float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false, ) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToFloat { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToFloat { override def getFloat(list: ArrayData, i: Int): Float = list.getFloat(i) override def 
getSize(list: ArrayData): Int = list.numElements() @@ -174,7 +195,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToFloat { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToFloat { override def getFloat(struct: InternalRow): Float = struct.getFloat(ordinal) override def getTiType: TiType = tiType @@ -184,7 +205,7 @@ object YtLogicalType { } case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToDouble { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToDouble { override def getDouble(list: ArrayData, i: Int): Double = list.getDouble(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -194,7 +215,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDouble { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDouble { override def getDouble(struct: InternalRow): Double = struct.getDouble(ordinal) override def getTiType: TiType = tiType @@ -204,7 +225,7 @@ object YtLogicalType { } case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) { - override def ytGettersFromList(ytGetter: 
InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToBoolean { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToBoolean { override def getBoolean(list: ArrayData, i: Int): Boolean = list.getBoolean(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -215,7 +236,7 @@ object YtLogicalType { ysonConsumer.onBoolean(list.getBoolean(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBoolean { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBoolean { override def getBoolean(struct: InternalRow): Boolean = struct.getBoolean(ordinal) override def getTiType: TiType = tiType @@ -232,7 +253,7 @@ object YtLogicalType { } case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToString { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToString { override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer override def getSize(list: ArrayData): Int = list.numElements() @@ -245,7 +266,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer override def getTiType: TiType = tiType @@ -264,7 +285,7 @@ object YtLogicalType { if (inner) alias.name 
else "string" } - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToString { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToString { override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) override def getSize(list: ArrayData): Int = list.numElements() @@ -279,7 +300,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) override def getTiType: TiType = tiType @@ -296,7 +317,7 @@ object YtLogicalType { case object Any extends AtomicYtLogicalType("any", 0x11, ColumnValueType.ANY, TiType.yson(), sparkTypeFor(TiType.yson()), Seq("yson")) { override def nullable: Boolean = true - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToYson { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToYson { override def getSize(list: ArrayData): Int = list.numElements() override def getTiType: TiType = tiType @@ -305,7 +326,7 @@ object YtLogicalType { YTreeBinarySerializer.deserialize(new ByteArrayInputStream(list.getBinary(i)), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { override def getTiType: TiType = tiType override def getYson(struct: 
InternalRow, ysonConsumer: YsonConsumer): Unit = @@ -314,7 +335,7 @@ object YtLogicalType { } case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToByte { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToByte { override def getByte(list: ArrayData, i: Int): Byte = list.getByte(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -324,7 +345,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getByte(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte { override def getByte(struct: InternalRow): Byte = struct.getByte(ordinal) override def getTiType: TiType = tiType @@ -334,7 +355,7 @@ object YtLogicalType { } case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToByte { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToByte { override def getByte(list: ArrayData, i: Int): Byte = list.getShort(i).toByte override def getSize(list: ArrayData): Int = list.numElements() @@ -344,7 +365,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getByte(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = 
new ytGetter.FromStructToByte { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte { override def getByte(struct: InternalRow): Byte = struct.getShort(ordinal).toByte override def getTiType: TiType = tiType @@ -354,7 +375,7 @@ object YtLogicalType { } case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToShort { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToShort { override def getShort(list: ArrayData, i: Int): Short = list.getShort(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -364,7 +385,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getShort(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort { override def getShort(struct: InternalRow): Short = struct.getShort(ordinal) override def getTiType: TiType = tiType @@ -374,7 +395,7 @@ object YtLogicalType { } case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToShort { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToShort { override def getShort(list: ArrayData, i: Int): Short = list.getInt(i).toShort override def getSize(list: ArrayData): Int = 
list.numElements() @@ -384,7 +405,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getShort(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort { override def getShort(struct: InternalRow): Short = struct.getInt(ordinal).toShort override def getTiType: TiType = tiType @@ -394,7 +415,7 @@ object YtLogicalType { } case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -404,7 +425,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getInt(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) override def getTiType: TiType = tiType @@ -414,7 +435,7 @@ object YtLogicalType { } case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt { 
+ override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { override def getInt(list: ArrayData, i: Int): Int = list.getLong(i).toInt override def getSize(list: ArrayData): Int = list.numElements() @@ -424,7 +445,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getInt(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { override def getInt(struct: InternalRow): Int = struct.getLong(ordinal).toInt override def getTiType: TiType = tiType @@ -434,7 +455,7 @@ object YtLogicalType { } case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToString { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToString { override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer override def getSize(list: ArrayData): Int = list.numElements() @@ -447,7 +468,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer override def getTiType: TiType = tiType @@ -461,7 +482,7 @@ object YtLogicalType { // Unsupported types are listed here: 
yt/yt/client/arrow/arrow_row_stream_encoder.cpp case object Date extends AtomicYtLogicalType("date", 0x1008, ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -472,7 +493,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getInt(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) override def getTiType: TiType = tiType @@ -483,7 +504,7 @@ object YtLogicalType { } case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -494,7 +515,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct 
= new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -505,7 +526,7 @@ object YtLogicalType { } case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -516,7 +537,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -527,7 +548,7 @@ object YtLogicalType { } case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -538,7 +559,7 @@ object YtLogicalType { ysonConsumer.onInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): 
ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -549,7 +570,7 @@ object YtLogicalType { } case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToNull { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToNull { override def getTiType: TiType = tiType override def getSize(list: ArrayData): Int = list.numElements() @@ -557,7 +578,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToNull { override def getTiType: TiType = tiType @@ -566,7 +587,7 @@ object YtLogicalType { } case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -577,7 +598,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getInt(list, i)) } - override def 
ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) override def getTiType: TiType = tiType @@ -588,7 +609,7 @@ object YtLogicalType { } case object Datetime64 extends AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -599,7 +620,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -610,7 +631,7 @@ object YtLogicalType { } case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long 
= list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -621,7 +642,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -632,7 +653,7 @@ object YtLogicalType { } case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -643,7 +664,7 @@ object YtLogicalType { ysonConsumer.onInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -660,7 +681,7 @@ object YtLogicalType { override def tiType: TiType = TiType.decimal(precision, scale) - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToBigDecimal { + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): 
ytGetter.FromList = new ytGetter.FromListToBigDecimal { override def getBigDecimal(list: ArrayData, i: Int): java.math.BigDecimal = list.getDecimal(i, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) @@ -674,7 +695,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBigDecimal { + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBigDecimal { override def getBigDecimal(struct: InternalRow): java.math.BigDecimal = struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) @@ -706,8 +727,8 @@ object YtLogicalType { override def arrowSupported: Boolean = inner.arrowSupported - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToOptional { - private val notEmptyGetter = inner.ytGettersFromList(ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToOptional { + private val notEmptyGetter = inner.ytGettersFromList(ytGetter, dataType) override def getNotEmptyGetter: ytGetter.FromList = notEmptyGetter @@ -731,8 +752,8 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToOptional { - private val notEmptyGetter = inner.ytGettersFromStruct(ytGetter, ordinal) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToOptional { + private val notEmptyGetter = inner.ytGettersFromStruct(ytGetter, dataType, ordinal) override def getNotEmptyGetter: ytGetter.FromStruct = notEmptyGetter @@ -767,9 +788,9 @@ object YtLogicalType { MapType(dictKey.sparkType.innerLevel, 
dictValue.sparkType.innerLevel, dictValue.nullable) ) - private def newGetter(ytGetter: InternalRowYTGetters): ytGetter.FromDict = new ytGetter.FromDict { - private val keyGetter = dictKey.ytGettersFromList(ytGetter) - private val valueGetter = dictValue.ytGettersFromList(ytGetter) + private def newGetter(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromDict = new ytGetter.FromDict { + private val keyGetter = dictKey.ytGettersFromList(ytGetter, dataType.asInstanceOf[MapType].keyType) + private val valueGetter = dictValue.ytGettersFromList(ytGetter, dataType.asInstanceOf[MapType].valueType) override def getKeyGetter: ytGetter.FromList = keyGetter @@ -808,8 +829,8 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Dict - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToDict { - private val getter = newGetter(ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToDict { + private val getter = newGetter(ytGetter, dataType) private val ysonSerializer = newYsonSerializer(getter) override def getGetter(): ytGetter.FromDict = getter @@ -824,8 +845,8 @@ object YtLogicalType { ysonSerializer(list.getMap(i), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDict { - private val getter = newGetter(ytGetter) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDict { + private val getter = newGetter(ytGetter, dataType) private val ysonSerializer = newYsonSerializer(getter) override def getGetter(): ytGetter.FromDict = getter @@ -849,8 +870,8 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Array - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new 
ytGetter.FromListToList { - val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToList { + val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter, dataType.asInstanceOf[ArrayType].elementType) override def getSize(list: ArrayData): Int = list.numElements() @@ -871,8 +892,8 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToList { - val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToList { + val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter, dataType.asInstanceOf[ArrayType].elementType) override def getElementGetter: ytGetter.FromList = elementGetter @@ -906,9 +927,9 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Struct - def newMembersGetters(ytGetter: InternalRowYTGetters): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = - fields.zipWithIndex.map { case (field, i) => - java.util.Map.entry(field._1, field._2.ytGettersFromStruct(ytGetter, i)) + def newMembersGetters(ytGetter: InternalRowYTGetters, dataType: DataType): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + fields.zip(dataType.asInstanceOf[StructType].fields).zipWithIndex.map { case ((field, structField), i) => + java.util.Map.entry(field._1, field._2.ytGettersFromStruct(ytGetter, structField.dataType, i)) }.asJava def yson(ytGetter: InternalRowYTGetters)( @@ -923,8 +944,8 @@ object YtLogicalType { ysonConsumer.onEndList() } - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToStruct { - private val membersGetters = 
newMembersGetters(ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToStruct { + private val membersGetters = newMembersGetters(ytGetter, dataType) override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters @@ -939,8 +960,8 @@ object YtLogicalType { yson(ytGetter)(membersGetters, list.getStruct(i, membersGetters.size()), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { - private val membersGetters = newMembersGetters(ytGetter) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { + private val membersGetters = newMembersGetters(ytGetter, dataType) override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters @@ -969,9 +990,10 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Tuple - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToStruct { - private val membersGetters = entries.zipWithIndex.map { case ((name, logicalType), i) => - java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, i)) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToStruct { + private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { + case (((name, logicalType), structField), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, structField.dataType, i)) }.asJava override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters @@ -993,9 +1015,10 @@ object YtLogicalType { } } - override def 
ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { - private val membersGetters = entries.zipWithIndex.map { case ((name, logicalType), i) => - java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, i)) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { + private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { + case (((name, logicalType), structField), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, structField.dataType, i)) }.asJava override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters @@ -1025,15 +1048,17 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Tagged - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = inner.ytGettersFromList(ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = inner.ytGettersFromList(ytGetter, dataType) - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = inner.ytGettersFromStruct(ytGetter, ordinal) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = inner.ytGettersFromStruct(ytGetter, dataType, ordinal) } case object Tagged extends CompositeYtLogicalTypeAlias(TypeName.Tagged.getWireName) - private class VariantGetter(fields: Seq[YtLogicalType], ytGetter: InternalRowYTGetters) { - private val getters = fields.zipWithIndex.map { case (field, i) => field.ytGettersFromStruct(ytGetter, i) } + private class VariantGetter(fields: Seq[YtLogicalType], ytGetter: InternalRowYTGetters, dataType: DataType) { + private val getters = 
fields.zip(dataType.asInstanceOf[StructType].fields).zipWithIndex.map { + case ((field, structField), i) => field.ytGettersFromStruct(ytGetter, structField.dataType, i) + } def get(row: InternalRow, ysonConsumer: YsonConsumer): Unit = { val notNulls = (0 until row.numFields).filter(!row.isNullAt(_)) @@ -1066,8 +1091,8 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Variant - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToYson { - val getter = new VariantGetter(fields.map(_._2), ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToYson { + val getter = new VariantGetter(fields.map(_._2), ytGetter, dataType) override def getSize(list: ArrayData): Int = list.numElements() @@ -1077,8 +1102,8 @@ object YtLogicalType { getter.get(list.getStruct(i, fields.size), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { - val getter = new VariantGetter(fields.map(_._2), ytGetter) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + val getter = new VariantGetter(fields.map(_._2), ytGetter, dataType) override def getTiType: TiType = tiType @@ -1102,8 +1127,8 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Variant - override def ytGettersFromList(ytGetter: InternalRowYTGetters): ytGetter.FromList = new ytGetter.FromListToYson { - val getter = new VariantGetter(fields.map(_._1), ytGetter) + override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToYson { + val getter = new VariantGetter(fields.map(_._1), ytGetter, dataType) override def getSize(list: ArrayData): Int = list.numElements() @@ -1113,8 +1138,8 @@ object 
YtLogicalType { getter.get(list.getStruct(i, fields.size), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { - val getter = new VariantGetter(fields.map(_._1), ytGetter) + override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + val getter = new VariantGetter(fields.map(_._1), ytGetter, dataType) override def getTiType: TiType = tiType From 23a19a60379bbd4ceb92eaa2b48294170841b96d Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Wed, 20 Nov 2024 16:49:51 +0100 Subject: [PATCH 06/12] clean-up imports --- .../tech/ytsaurus/client/ArrowTableRowsSerializer.java | 2 -- .../scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala | 7 +++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java index 19cb226f..38c31065 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -20,10 +20,8 @@ import tech.ytsaurus.rpcproxy.ERowsetFormat; import tech.ytsaurus.rpcproxy.TRowsetDescriptor; import tech.ytsaurus.spyt.format.batch.ArrowUtils; -import tech.ytsaurus.spyt.serialization.YsonDecoder; import tech.ytsaurus.typeinfo.DecimalType; import tech.ytsaurus.yson.YsonBinaryWriter; -import tech.ytsaurus.ysontree.YTreeBuilder; import java.io.ByteArrayOutputStream; import java.io.IOException; diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala index fb0f42f6..731f18dc 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala +++ 
b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala @@ -158,7 +158,10 @@ class YtOutputWriter(richPath: YPathEnriched, val request = WriteTable.builder[InternalRow]() .setPath(appendPath) .setSerializationContext( - if (options.ytConf(ArrowWriteEnabled)) + if (options.ytConf(ArrowWriteEnabled)) { + if (!writeSchemaConverter.typeV3Format) { + throw new RuntimeException("arrow writer is only supported with typeV3") + } new ArrowWriteSerializationContext[InternalRow, ArrayData, MapData, InternalRowYTGetters]( schema.fields.zipWithIndex.map { case (field, i) => util.Map.entry(field.name, writeSchemaConverter.ytLogicalTypeV3(field).ytGettersFromStruct( @@ -166,7 +169,7 @@ class YtOutputWriter(richPath: YPathEnriched, )) }.toSeq.asJava ) - else + } else new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options))) ) .setTransactionalOptions(new TransactionalOptions(GUID.valueOf(transactionGuid))) From 56d9770358eee65e928d517cd214f9d6f32d26a5 Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Wed, 20 Nov 2024 16:50:13 +0100 Subject: [PATCH 07/12] disable arrow writing by default --- .../tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala index 25879ee1..7866e0b5 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala @@ -65,7 +65,7 @@ object YtTableSparkSettings { case object ArrowEnabled extends ConfigEntry[Boolean]("arrow_enabled", Some(true)) - case object ArrowWriteEnabled extends ConfigEntry[Boolean]("arrow_write_enabled", Some(true)) + case object ArrowWriteEnabled extends ConfigEntry[Boolean]("arrow_write_enabled", Some(false)) case 
object KeyPartitioned extends ConfigEntry[Boolean]("key_partitioned") From 1e6d0212b5494fc2a4444a47cb491542823ebb6e Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Wed, 20 Nov 2024 17:16:27 +0100 Subject: [PATCH 08/12] fix YtFileFormatTest --- .../client/ArrowTableRowsSerializer.java | 34 +++++++++++++++++++ .../serializers/WriteSchemaConverter.scala | 22 +++++------- 2 files changed, 42 insertions(+), 14 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java index 38c31065..14a47fc6 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -175,6 +175,23 @@ private Field field(String name, ArrowType arrowType) { private ArrowGetterFromList nonComplexArrowGetter(String name, Getters.FromList getter) { var tiType = getter.getTiType(); switch (tiType.getTypeName()) { + case Null: + case Void: { + return new ArrowGetterFromList( + new Field(name, new FieldType(false, new ArrowType.Null(), null), new ArrayList<>()) + ) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var nullVector = (NullVector) valueVector; + return new ArrowWriterFromList() { + @Override + void setFromList(List list, int i) { + nullVector.setValueCount(nullVector.getValueCount() + 1); + } + }; + } + }; + } case Utf8: case String: { var stringGetter = (Getters.FromListToString) getter; @@ -629,6 +646,23 @@ void setFromList(List list, int i) { private ArrowGetterFromStruct nonComplexArrowGetter(String name, Getters.FromStruct getter) { var tiType = getter.getTiType(); switch (tiType.getTypeName()) { + case Null: + case Void: { + return new ArrowGetterFromStruct( + new Field(name, new FieldType(false, new ArrowType.Null(), null), new ArrayList<>()) + ) { + @Override + public ArrowWriterFromStruct writer(ValueVector 
valueVector) { + var nullVector = (NullVector) valueVector; + return new ArrowWriterFromStruct() { + @Override + void setFromStruct(Struct struct) { + nullVector.setValueCount(nullVector.getValueCount() + 1); + } + }; + } + }; + } case Utf8: case String: { var stringGetter = (Getters.FromStructToString) getter; diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala index 4c5813bd..76d4e865 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala @@ -19,25 +19,20 @@ class WriteSchemaConverter( ) { private def ytLogicalTypeV3Variant(struct: StructType): YtLogicalType = { if (isVariantOverTuple(struct)) { - YtLogicalType.VariantOverTuple { - struct.fields.map(tF => - (wrapSparkAttributes(ytLogicalTypeV3(tF), tF.nullable, Some(tF.metadata)), tF.metadata)) - } + YtLogicalType.VariantOverTuple { struct.fields.map(tF => (ytLogicalTypeV3(tF), tF.metadata)) } } else { - YtLogicalType.VariantOverStruct { - struct.fields.map(sf => (sf.name.drop(2), - wrapSparkAttributes(ytLogicalTypeV3(sf), sf.nullable, Some(sf.metadata)), sf.metadata)) - } + YtLogicalType.VariantOverStruct { struct.fields.map(sf => (sf.name.drop(2), ytLogicalTypeV3(sf), sf.metadata)) } } } def ytLogicalTypeStruct(structType: StructType): YtLogicalType.Struct = YtLogicalType.Struct { - structType.fields.map(sf => (sf.name, - wrapSparkAttributes(ytLogicalTypeV3(sf), sf.nullable, Some(sf.metadata)), sf.metadata)) + structType.fields.map(sf => (sf.name, ytLogicalTypeV3(sf), sf.metadata)) } - def ytLogicalTypeV3(structField: StructField): YtLogicalType = - ytLogicalTypeV3(structField.dataType, hint.getOrElse(structField.name, null)) + def ytLogicalTypeV3(structField: StructField): YtLogicalType = wrapSparkAttributes( + ytLogicalTypeV3(structField.dataType, 
hint.getOrElse(structField.name, null)), + structField.nullable, Some(structField.metadata), + ) def ytLogicalTypeV3(sparkType: DataType, hint: YtLogicalType = null): YtLogicalType = sparkType match { case NullType => YtLogicalType.Null @@ -96,8 +91,7 @@ class WriteSchemaConverter( val builder = YTree.builder .beginMap .key("name").value(field.name) - val fieldType = hint.getOrElse(field.name, - wrapSparkAttributes(ytLogicalTypeV3(field), field.nullable, Some(field.metadata))) + val fieldType = hint.getOrElse(field.name, ytLogicalTypeV3(field)) if (typeV3Format) { builder .key("type_v3").value(serializeTypeV3(fieldType)) From dbd6cd96a0d8367c772b20eef8b8d4ced924053b Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Thu, 21 Nov 2024 11:51:58 +0100 Subject: [PATCH 09/12] YTGetters interfaces should be static --- .../client/ArrowTableRowsSerializer.java | 798 +++++++++--------- .../ArrowWriteSerializationContext.java | 8 +- .../ytsaurus/client/InternalRowYTGetters.java | 8 - .../ytsaurus/client/TableWriterBaseImpl.java | 2 +- .../scala/tech/ytsaurus/client/YTGetters.java | 165 ++-- .../ytsaurus/spyt/format/YtOutputWriter.scala | 8 +- .../spyt/serializers/YtLogicalType.scala | 314 +++---- 7 files changed, 660 insertions(+), 643 deletions(-) delete mode 100644 data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java index 14a47fc6..1f4d8d74 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -31,8 +31,8 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -public class ArrowTableRowsSerializer> extends TableRowsSerializer implements AutoCloseable { - private abstract class ArrowGetterFromStruct { +public class ArrowTableRowsSerializer 
extends TableRowsSerializer implements AutoCloseable { + private static abstract class ArrowGetterFromStruct { public final Field field; public final ArrowType arrowType; @@ -46,14 +46,14 @@ public final ArrowType getArrowType() { return arrowType; } - public abstract ArrowWriterFromStruct writer(ValueVector valueVector); + public abstract ArrowWriterFromStruct writer(ValueVector valueVector); } - private abstract class ArrowWriterFromStruct { - abstract void setFromStruct(Struct struct); + private static abstract class ArrowWriterFromStruct { + abstract void setFromStruct(Row struct); } - private abstract class ArrowGetterFromList { + private static abstract class ArrowGetterFromList { public final Field field; public final ArrowType arrowType; @@ -66,50 +66,50 @@ public final ArrowType getArrowType() { return arrowType; } - public abstract ArrowWriterFromList writer(ValueVector valueVector); + public abstract ArrowWriterFromList writer(ValueVector valueVector); } - private abstract class ArrowWriterFromList { + private static abstract class ArrowWriterFromList { abstract void setFromList(List list, int i); } - private ArrowGetterFromList arrowGetter(String name, Getters.FromList getter) { + private ArrowGetterFromList arrowGetter(String name, YTGetters.FromList getter) { var optionalGetter = getter instanceof YTGetters.FromListToOptional - ? (Getters.FromListToOptional) getter + ? (YTGetters.FromListToOptional) getter : null; - var nonEmptyGetter = optionalGetter != null ? (Getters.FromList) optionalGetter.getNotEmptyGetter() : getter; + var nonEmptyGetter = optionalGetter != null ? optionalGetter.getNotEmptyGetter() : getter; var arrowGetter = nonComplexArrowGetter(name, nonEmptyGetter); if (arrowGetter != null) { - return optionalGetter == null ? arrowGetter : new ArrowGetterFromList(new Field(name, new FieldType( + return optionalGetter == null ? 
arrowGetter : new ArrowGetterFromList<>(new Field(name, new FieldType( true, arrowGetter.field.getType(), null ), arrowGetter.field.getChildren())) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var nonOptionalWriter = arrowGetter.writer(valueVector); - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - public void setFromList(List list, int i) { - nonOptionalWriter.setFromList(optionalGetter.isEmpty(list, i) ? null : list, i); + public void setFromList(Array array, int i) { + nonOptionalWriter.setFromList(optionalGetter.isEmpty(array, i) ? null : array, i); } }; } }; } - return new ArrowGetterFromList(new Field(name, new FieldType( + return new ArrowGetterFromList<>(new Field(name, new FieldType( optionalGetter != null, new ArrowType.Binary(), null ), new ArrayList<>())) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var varBinaryVector = (VarBinaryVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - public void setFromList(List list, int i) { - if (optionalGetter != null && optionalGetter.isEmpty(list, i)) { + public void setFromList(Array array, int i) { + if (optionalGetter != null && optionalGetter.isEmpty(array, i)) { varBinaryVector.setNull(varBinaryVector.getValueCount()); } else { var byteArrayOutputStream = new ByteArrayOutputStream(); try (var ysonBinaryWriter = new YsonBinaryWriter(byteArrayOutputStream)) { - nonEmptyGetter.getYson(list, i, ysonBinaryWriter); + nonEmptyGetter.getYson(array, i, ysonBinaryWriter); } varBinaryVector.set(varBinaryVector.getValueCount(), byteArrayOutputStream.toByteArray()); } @@ -120,20 +120,20 @@ public void setFromList(List list, int i) { }; } - private ArrowGetterFromStruct arrowGetter(String name, Getters.FromStruct getter) { + private 
ArrowGetterFromStruct arrowGetter(String name, YTGetters.FromStruct getter) { var optionalGetter = getter instanceof YTGetters.FromStructToOptional - ? (Getters.FromStructToOptional) getter + ? (YTGetters.FromStructToOptional) getter : null; - var nonEmptyGetter = optionalGetter != null ? (Getters.FromStruct) optionalGetter.getNotEmptyGetter() : getter; + var nonEmptyGetter = optionalGetter != null ? optionalGetter.getNotEmptyGetter() : getter; var arrowGetter = nonComplexArrowGetter(name, nonEmptyGetter); if (arrowGetter != null) { - return optionalGetter == null ? arrowGetter : new ArrowGetterFromStruct(new Field(name, new FieldType( + return optionalGetter == null ? arrowGetter : new ArrowGetterFromStruct<>(new Field(name, new FieldType( true, arrowGetter.field.getType(), null ), arrowGetter.field.getChildren())) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var nonOptionalWriter = arrowGetter.writer(valueVector); - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override public void setFromStruct(Struct struct) { nonOptionalWriter.setFromStruct(optionalGetter.isEmpty(struct) ? 
null : struct); @@ -142,13 +142,13 @@ public void setFromStruct(Struct struct) { } }; } else { - return new ArrowGetterFromStruct(new Field(name, new FieldType( + return new ArrowGetterFromStruct<>(new Field(name, new FieldType( optionalGetter != null, new ArrowType.Binary(), null ), new ArrayList<>())) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var varBinaryVector = (VarBinaryVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (optionalGetter != null && optionalGetter.isEmpty(struct)) { @@ -172,20 +172,20 @@ private Field field(String name, ArrowType arrowType) { return new Field(name, new FieldType(false, arrowType, null), Collections.emptyList()); } - private ArrowGetterFromList nonComplexArrowGetter(String name, Getters.FromList getter) { + private ArrowGetterFromList nonComplexArrowGetter(String name, YTGetters.FromList getter) { var tiType = getter.getTiType(); switch (tiType.getTypeName()) { case Null: case Void: { - return new ArrowGetterFromList( + return new ArrowGetterFromList<>( new Field(name, new FieldType(false, new ArrowType.Null(), null), new ArrayList<>()) ) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var nullVector = (NullVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { nullVector.setValueCount(nullVector.getValueCount() + 1); } }; @@ -194,16 +194,16 @@ void setFromList(List list, int i) { } case Utf8: case String: { - var stringGetter = (Getters.FromListToString) getter; - return new ArrowGetterFromList( + var stringGetter = (YTGetters.FromListToString) getter; + return new ArrowGetterFromList<>( new Field(name, new 
FieldType(false, new ArrowType.Binary(), null), new ArrayList<>()) ) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var varBinaryVector = (VarBinaryVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { varBinaryVector.setNull(varBinaryVector.getValueCount()); } else { @@ -220,14 +220,14 @@ void setFromList(List list, int i) { }; } case Int8: { - var byteGetter = (Getters.FromListToByte) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(8, true))) { + var byteGetter = (YTGetters.FromListToByte) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(8, true))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var tinyIntVector = (TinyIntVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { tinyIntVector.setNull(tinyIntVector.getValueCount()); } else { @@ -240,14 +240,14 @@ void setFromList(List list, int i) { }; } case Uint8: { - var byteGetter = (Getters.FromListToByte) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(8, false))) { + var byteGetter = (YTGetters.FromListToByte) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(8, false))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var uInt1Vector = (UInt1Vector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { 
uInt1Vector.setNull(uInt1Vector.getValueCount()); } else { @@ -260,14 +260,14 @@ void setFromList(List list, int i) { }; } case Int16: { - var shortGetter = (Getters.FromListToShort) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(16, true))) { + var shortGetter = (YTGetters.FromListToShort) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(16, true))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var smallIntVector = (SmallIntVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { smallIntVector.setNull(smallIntVector.getValueCount()); } else { @@ -280,14 +280,14 @@ void setFromList(List list, int i) { }; } case Uint16: { - var shortGetter = (Getters.FromListToShort) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(16, false))) { + var shortGetter = (YTGetters.FromListToShort) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(16, false))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var uInt2Vector = (UInt2Vector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { uInt2Vector.setNull(uInt2Vector.getValueCount()); } else { @@ -300,14 +300,14 @@ void setFromList(List list, int i) { }; } case Int32: { - var intGetter = (Getters.FromListToInt) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(32, true))) { + var intGetter = (YTGetters.FromListToInt) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(32, true))) { @Override - public ArrowWriterFromList 
writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var intVector = (IntVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { intVector.setNull(intVector.getValueCount()); } else { @@ -320,14 +320,14 @@ void setFromList(List list, int i) { }; } case Uint32: { - var intGetter = (Getters.FromListToInt) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(32, false))) { + var intGetter = (YTGetters.FromListToInt) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(32, false))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var uInt4Vector = (UInt4Vector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { uInt4Vector.setNull(uInt4Vector.getValueCount()); } else { @@ -342,14 +342,14 @@ void setFromList(List list, int i) { case Interval: case Interval64: case Int64: { - var longGetter = (Getters.FromListToLong) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(64, true))) { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(64, true))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var bigIntVector = (BigIntVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { bigIntVector.setNull(bigIntVector.getValueCount()); } else { @@ -362,14 +362,14 @@ void setFromList(List list, int i) { }; 
} case Uint64: { - var longGetter = (Getters.FromListToLong) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Int(64, false))) { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(64, false))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var uInt8Vector = (UInt8Vector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { uInt8Vector.setNull(uInt8Vector.getValueCount()); } else { @@ -382,14 +382,14 @@ void setFromList(List list, int i) { }; } case Bool: { - var booleanGetter = (Getters.FromListToBoolean) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Bool())) { + var booleanGetter = (YTGetters.FromListToBoolean) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Bool())) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var bitVector = (BitVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { bitVector.setNull(bitVector.getValueCount()); } else { @@ -402,14 +402,14 @@ void setFromList(List list, int i) { }; } case Float: { - var floatGetter = (Getters.FromListToFloat) getter; - return new ArrowGetterFromList(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { + var floatGetter = (YTGetters.FromListToFloat) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) 
{ var float4Vector = (Float4Vector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { float4Vector.setNull(float4Vector.getValueCount()); } else { @@ -422,14 +422,14 @@ void setFromList(List list, int i) { }; } case Double: { - var doubleGetter = (Getters.FromListToDouble) getter; - return new ArrowGetterFromList(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { + var doubleGetter = (YTGetters.FromListToDouble) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var float8Vector = (Float8Vector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { float8Vector.setNull(float8Vector.getValueCount()); } else { @@ -442,17 +442,17 @@ void setFromList(List list, int i) { }; } case Decimal: { - var decimalGetter = (Getters.FromListToBigDecimal) getter; + var decimalGetter = (YTGetters.FromListToBigDecimal) getter; var decimalType = (DecimalType) decimalGetter.getTiType(); - return new ArrowGetterFromList(field(name, new ArrowType.Decimal( + return new ArrowGetterFromList<>(field(name, new ArrowType.Decimal( decimalType.getPrecision(), decimalType.getScale() ))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var decimalVector = (DecimalVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { 
decimalVector.setNull(decimalVector.getValueCount()); } else { @@ -466,14 +466,14 @@ void setFromList(List list, int i) { } case Date: case Date32: { - var intGetter = (Getters.FromListToInt) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Date(DateUnit.DAY))) { + var intGetter = (YTGetters.FromListToInt) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Date(DateUnit.DAY))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var dateDayVector = (DateDayVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { dateDayVector.setNull(dateDayVector.getValueCount()); } else { @@ -487,14 +487,14 @@ void setFromList(List list, int i) { } case Datetime: case Datetime64: { - var longGetter = (Getters.FromListToLong) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var dateMilliVector = (DateMilliVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { dateMilliVector.setNull(dateMilliVector.getValueCount()); } else { @@ -508,14 +508,14 @@ void setFromList(List list, int i) { } case Timestamp: case Timestamp64: { - var longGetter = (Getters.FromListToLong) getter; - return new ArrowGetterFromList(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { + var longGetter = (YTGetters.FromListToLong) getter; + return new 
ArrowGetterFromList<>(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { + public ArrowWriterFromList writer(ValueVector valueVector) { var timeStampMicroVector = (TimeStampMicroVector) valueVector; - return new ArrowWriterFromList() { + return new ArrowWriterFromList<>() { @Override - void setFromList(List list, int i) { + void setFromList(Array list, int i) { if (list == null) { timeStampMicroVector.setNull(timeStampMicroVector.getValueCount()); } else { @@ -528,133 +528,148 @@ void setFromList(List list, int i) { }; } case List: { - var listGetter = (Getters.FromListToList) getter; - var elementGetter = listGetter.getElementGetter(); - var itemGetter = arrowGetter("item", (Getters.FromList) elementGetter); - return new ArrowGetterFromList(new Field(name, new FieldType( - false, new ArrowType.List(), null - ), Collections.singletonList(itemGetter.field))) { + return getArrowGetterFromList(name, (YTGetters.FromListToList) getter); + } + case Dict: { + return getArrowGetterFromList(name, (YTGetters.FromListToDict) getter); + } + case Struct: { + return getArrowGetterFromList(name, (YTGetters.FromListToStruct) getter); + } + default: + return null; + } + } + + private ArrowGetterFromList getArrowGetterFromList( + String name, YTGetters.FromListToStruct structGetter + ) { + var members = structGetter.getMembersGetters(); + var membersGetters = new ArrayList>(members.size()); + for (var member : members) { + membersGetters.add(arrowGetter(member.getKey(), member.getValue())); + } + return new ArrowGetterFromList<>(new Field( + name, new FieldType(false, new ArrowType.Struct(), null), + membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) + )) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var structVector = (StructVector) valueVector; + var membersWriters = new ArrayList>(members.size()); + for (int i = 0; i < 
members.size(); i++) { + membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); + } + return new ArrowWriterFromList<>() { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { - var listVector = (ListVector) valueVector; - var dataWriter = itemGetter.writer(listVector.getDataVector()); - return new ArrowWriterFromList() { - @Override - void setFromList(List list, int i) { - var value = list == null ? null : (List) listGetter.getList(list, i); - if (value != null) { - int size = elementGetter.getSize(value); - listVector.startNewValue(listVector.getValueCount()); - for (int j = 0; j < size; j++) { - dataWriter.setFromList(value, j); - } - listVector.endValue(listVector.getValueCount(), size); - } - listVector.setValueCount(listVector.getValueCount() + 1); + void setFromList(Array list, int i) { + if (list == null) { + for (int j = 0; j < members.size(); j++) { + membersWriters.get(j).setFromStruct(null); } - }; + } else { + var struct = structGetter.getStruct(list, i); + structVector.setIndexDefined(structVector.getValueCount()); + for (int j = 0; j < members.size(); j++) { + membersWriters.get(j).setFromStruct(struct); + } + } + structVector.setValueCount(structVector.getValueCount() + 1); } }; } - case Dict: { - var dictGetter = (Getters.FromListToDict) getter; - var fromDictGetter = dictGetter.getGetter(); - var keyGetter = nonComplexArrowGetter("key", (Getters.FromList) fromDictGetter.getKeyGetter()); - var valueGetter = arrowGetter("value", (Getters.FromList) fromDictGetter.getValueGetter()); - if (keyGetter == null || valueGetter == null) { - return null; - } - return new ArrowGetterFromList(new Field( - name, new FieldType(false, new ArrowType.Map(false), null), - Collections.singletonList(new Field( - "entries", new FieldType(false, new ArrowType.Struct(), null), - Arrays.asList(keyGetter.field, valueGetter.field) - )) - )) { + }; + } + + private ArrowGetterFromList getArrowGetterFromList( + String name, 
YTGetters.FromListToDict dictGetter + ) { + var fromDictGetter = dictGetter.getGetter(); + var keyGetter = nonComplexArrowGetter("key", fromDictGetter.getKeyGetter()); + var valueGetter = arrowGetter("value", fromDictGetter.getValueGetter()); + if (keyGetter == null) { + return null; + } + return new ArrowGetterFromList<>(new Field( + name, new FieldType(false, new ArrowType.Map(false), null), + Collections.singletonList(new Field( + "entries", new FieldType(false, new ArrowType.Struct(), null), + Arrays.asList(keyGetter.field, valueGetter.field) + )) + )) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var mapVector = (MapVector) valueVector; + var structVector = (StructVector) mapVector.getDataVector(); + var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); + var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); + return new ArrowWriterFromList<>() { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { - var mapVector = (MapVector) valueVector; - var structVector = (StructVector) mapVector.getDataVector(); - var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); - var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); - return new ArrowWriterFromList() { - @Override - void setFromList(List list, int i) { - var dict = list == null ? 
null : dictGetter.getDict(list, i); - if (dict != null) { - int size = fromDictGetter.getSize(dict); - var keys = fromDictGetter.getKeys(dict); - var values = fromDictGetter.getValues(dict); - mapVector.startNewValue(mapVector.getValueCount()); - for (int j = 0; j < size; j++) { - structVector.setIndexDefined(structVector.getValueCount()); - keyWriter.setFromList((List) keys, j); - valueWriter.setFromList((List) values, j); - structVector.setValueCount(structVector.getValueCount() + 1); - } - mapVector.endValue(mapVector.getValueCount(), size); - } - mapVector.setValueCount(mapVector.getValueCount() + 1); + void setFromList(Array list, int i) { + var dict = list == null ? null : dictGetter.getDict(list, i); + if (dict != null) { + int size = fromDictGetter.getSize(dict); + var keys = fromDictGetter.getKeys(dict); + var values = fromDictGetter.getValues(dict); + mapVector.startNewValue(mapVector.getValueCount()); + for (int j = 0; j < size; j++) { + structVector.setIndexDefined(structVector.getValueCount()); + keyWriter.setFromList(keys, j); + valueWriter.setFromList(values, j); + structVector.setValueCount(structVector.getValueCount() + 1); } - }; + mapVector.endValue(mapVector.getValueCount(), size); + } + mapVector.setValueCount(mapVector.getValueCount() + 1); } }; } - case Struct: { - var structGetter = (Getters.FromListToStruct) getter; - var members = (java.util.List>) structGetter.getMembersGetters(); - var membersGetters = new ArrayList(members.size()); - for (Map.Entry member : members) { - membersGetters.add(arrowGetter(member.getKey(), member.getValue())); - } - return new ArrowGetterFromList(new Field( - name, new FieldType(false, new ArrowType.Struct(), null), - membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) - )) { + }; + } + + private ArrowGetterFromList getArrowGetterFromList( + String name, YTGetters.FromListToList listGetter + ) { + var elementGetter = listGetter.getElementGetter(); + var itemGetter = 
arrowGetter("item", elementGetter); + return new ArrowGetterFromList<>(new Field(name, new FieldType( + false, new ArrowType.List(), null + ), Collections.singletonList(itemGetter.field))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var listVector = (ListVector) valueVector; + var dataWriter = itemGetter.writer(listVector.getDataVector()); + return new ArrowWriterFromList<>() { @Override - public ArrowWriterFromList writer(ValueVector valueVector) { - var structVector = (StructVector) valueVector; - var membersWriters = new ArrayList(members.size()); - for (int i = 0; i < members.size(); i++) { - membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); - } - return new ArrowWriterFromList() { - @Override - void setFromList(List list, int i) { - if (list == null) { - for (int j = 0; j < members.size(); j++) { - membersWriters.get(j).setFromStruct(null); - } - } else { - var struct = (Struct) structGetter.getStruct(list, i); - structVector.setIndexDefined(structVector.getValueCount()); - for (int j = 0; j < members.size(); j++) { - membersWriters.get(j).setFromStruct(struct); - } - } - structVector.setValueCount(structVector.getValueCount() + 1); + void setFromList(Array list, int i) { + var value = list == null ? 
null : listGetter.getList(list, i); + if (value != null) { + int size = elementGetter.getSize(value); + listVector.startNewValue(listVector.getValueCount()); + for (int j = 0; j < size; j++) { + dataWriter.setFromList(value, j); } - }; + listVector.endValue(listVector.getValueCount(), size); + } + listVector.setValueCount(listVector.getValueCount() + 1); } }; } - default: - return null; - } + }; } - private ArrowGetterFromStruct nonComplexArrowGetter(String name, Getters.FromStruct getter) { + private ArrowGetterFromStruct nonComplexArrowGetter(String name, YTGetters.FromStruct getter) { var tiType = getter.getTiType(); switch (tiType.getTypeName()) { case Null: case Void: { - return new ArrowGetterFromStruct( + return new ArrowGetterFromStruct<>( new Field(name, new FieldType(false, new ArrowType.Null(), null), new ArrayList<>()) ) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var nullVector = (NullVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { nullVector.setValueCount(nullVector.getValueCount() + 1); @@ -665,12 +680,12 @@ void setFromStruct(Struct struct) { } case Utf8: case String: { - var stringGetter = (Getters.FromStructToString) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Binary())) { + var stringGetter = (YTGetters.FromStructToString) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Binary())) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var varBinaryVector = (VarBinaryVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -689,12 +704,12 @@ void setFromStruct(Struct struct) { }; } case Int8: { - var 
byteGetter = (Getters.FromStructToByte) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(8, true))) { + var byteGetter = (YTGetters.FromStructToByte) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(8, true))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var tinyIntVector = (TinyIntVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -709,12 +724,12 @@ void setFromStruct(Struct struct) { }; } case Uint8: { - var byteGetter = (Getters.FromStructToByte) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(8, false))) { + var byteGetter = (YTGetters.FromStructToByte) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(8, false))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var uInt1Vector = (UInt1Vector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -729,12 +744,12 @@ void setFromStruct(Struct struct) { }; } case Int16: { - var shortGetter = (Getters.FromStructToShort) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(16, true))) { + var shortGetter = (YTGetters.FromStructToShort) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(16, true))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var smallIntVector = (SmallIntVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -749,12 +764,12 @@ void 
setFromStruct(Struct struct) { }; } case Uint16: { - var shortGetter = (Getters.FromStructToShort) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(16, false))) { + var shortGetter = (YTGetters.FromStructToShort) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(16, false))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var uInt2Vector = (UInt2Vector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -769,12 +784,12 @@ void setFromStruct(Struct struct) { }; } case Int32: { - var intGetter = (Getters.FromStructToInt) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(32, true))) { + var intGetter = (YTGetters.FromStructToInt) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(32, true))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var intVector = (IntVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -789,12 +804,12 @@ void setFromStruct(Struct struct) { }; } case Uint32: { - var intGetter = (Getters.FromStructToInt) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(32, false))) { + var intGetter = (YTGetters.FromStructToInt) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(32, false))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var uInt4Vector = (UInt4Vector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct 
== null) { @@ -811,12 +826,12 @@ void setFromStruct(Struct struct) { case Interval: case Interval64: case Int64: { - var longGetter = (Getters.FromStructToLong) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(64, true))) { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(64, true))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var bigIntVector = (BigIntVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -831,12 +846,12 @@ void setFromStruct(Struct struct) { }; } case Uint64: { - var longGetter = (Getters.FromStructToLong) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Int(64, false))) { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(64, false))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var uInt8Vector = (UInt8Vector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -851,12 +866,12 @@ void setFromStruct(Struct struct) { }; } case Bool: { - var booleanGetter = (Getters.FromStructToBoolean) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Bool())) { + var booleanGetter = (YTGetters.FromStructToBoolean) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Bool())) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var bitVector = (BitVector) valueVector; - return new ArrowWriterFromStruct() { + return new 
ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -871,12 +886,12 @@ void setFromStruct(Struct struct) { }; } case Float: { - var floatGetter = (Getters.FromStructToFloat) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { + var floatGetter = (YTGetters.FromStructToFloat) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var float4Vector = (Float4Vector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -891,12 +906,12 @@ void setFromStruct(Struct struct) { }; } case Double: { - var doubleGetter = (Getters.FromStructToDouble) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { + var doubleGetter = (YTGetters.FromStructToDouble) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var float8Vector = (Float8Vector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -911,15 +926,15 @@ void setFromStruct(Struct struct) { }; } case Decimal: { - var decimalGetter = (Getters.FromStructToBigDecimal) getter; + var decimalGetter = (YTGetters.FromStructToBigDecimal) getter; var decimalType = (DecimalType) decimalGetter.getTiType(); - return new ArrowGetterFromStruct(field(name, new ArrowType.Decimal( + return new ArrowGetterFromStruct<>(field(name, new 
ArrowType.Decimal( decimalType.getPrecision(), decimalType.getScale() ))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var decimalVector = (DecimalVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -935,12 +950,12 @@ void setFromStruct(Struct struct) { } case Date: case Date32: { - var intGetter = (Getters.FromStructToInt) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Date(DateUnit.DAY))) { + var intGetter = (YTGetters.FromStructToInt) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Date(DateUnit.DAY))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var dateDayVector = (DateDayVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -956,12 +971,12 @@ void setFromStruct(Struct struct) { } case Datetime: case Datetime64: { - var longGetter = (Getters.FromStructToLong) getter; - return new ArrowGetterFromStruct(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var dateMilliVector = (DateMilliVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -977,12 +992,12 @@ void setFromStruct(Struct struct) { } case Timestamp: case Timestamp64: { - var longGetter = (Getters.FromStructToLong) getter; - 
return new ArrowGetterFromStruct(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { + public ArrowWriterFromStruct writer(ValueVector valueVector) { var timeStampMicroVector = (TimeStampMicroVector) valueVector; - return new ArrowWriterFromStruct() { + return new ArrowWriterFromStruct<>() { @Override void setFromStruct(Struct struct) { if (struct == null) { @@ -997,127 +1012,142 @@ void setFromStruct(Struct struct) { }; } case List: { - var listGetter = (Getters.FromStructToList) getter; - var elementGetter = (Getters.FromList) listGetter.getElementGetter(); - var itemGetter = arrowGetter("item", elementGetter); - return new ArrowGetterFromStruct(new Field(name, new FieldType( - false, new ArrowType.List(), null - ), Collections.singletonList(itemGetter.field))) { + return getArrowGetterFromStruct(name, (YTGetters.FromStructToList) getter); + } + case Dict: { + return getArrowGetterFromStruct(name, (YTGetters.FromStructToDict) getter); + } + case Struct: { + return getArrowGetterFromStruct(name, (YTGetters.FromStructToStruct) getter); + } + default: + return null; + } + } + + private ArrowGetterFromStruct getArrowGetterFromStruct( + String name, YTGetters.FromStructToStruct structGetter + ) { + var members = structGetter.getMembersGetters(); + var membersGetters = new ArrayList>(members.size()); + for (var member : members) { + membersGetters.add(arrowGetter(member.getKey(), member.getValue())); + } + return new ArrowGetterFromStruct<>(new Field( + name, new FieldType(false, new ArrowType.Struct(), null), + membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) + )) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var structVector = (StructVector) valueVector; 
+ var membersWriters = new ArrayList>(members.size()); + for (int i = 0; i < members.size(); i++) { + membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); + } + return new ArrowWriterFromStruct<>() { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { - var listVector = (ListVector) valueVector; - var dataWriter = itemGetter.writer(listVector.getDataVector()); - return new ArrowWriterFromStruct() { - @Override - public void setFromStruct(Struct struct) { - var list = struct == null ? null : (List) listGetter.getList(struct); - if (list != null) { - int size = elementGetter.getSize(list); - listVector.startNewValue(listVector.getValueCount()); - for (int i = 0; i < size; i++) { - dataWriter.setFromList(list, i); - } - listVector.endValue(listVector.getValueCount(), size); - } - listVector.setValueCount(listVector.getValueCount() + 1); + void setFromStruct(Struct row) { + if (row == null) { + for (int i = 0; i < members.size(); i++) { + membersWriters.get(i).setFromStruct(null); } - }; + } else { + var value = structGetter.getStruct(row); + structVector.setIndexDefined(structVector.getValueCount()); + for (int i = 0; i < members.size(); i++) { + membersWriters.get(i).setFromStruct(value); + } + } + structVector.setValueCount(structVector.getValueCount() + 1); } }; } - case Dict: { - var dictGetter = (Getters.FromStructToDict) getter; - var fromDictGetter = (Getters.FromDict) dictGetter.getGetter(); - var keyGetter = nonComplexArrowGetter("key", (Getters.FromList) fromDictGetter.getKeyGetter()); - var valueGetter = arrowGetter("value", (Getters.FromList) fromDictGetter.getValueGetter()); - if (keyGetter == null || valueGetter == null) { - return null; - } - return new ArrowGetterFromStruct(new Field( - name, new FieldType(false, new ArrowType.Map(false), null), - Collections.singletonList(new Field( - "entries", new FieldType(false, new ArrowType.Struct(), null), - Arrays.asList(keyGetter.field, valueGetter.field) 
- )) - )) { + }; + } + + private ArrowGetterFromStruct getArrowGetterFromStruct( + String name, YTGetters.FromStructToDict dictGetter + ) { + var fromDictGetter = dictGetter.getGetter(); + var keyGetter = nonComplexArrowGetter("key", fromDictGetter.getKeyGetter()); + var valueGetter = arrowGetter("value", fromDictGetter.getValueGetter()); + if (keyGetter == null) { + return null; + } + return new ArrowGetterFromStruct<>(new Field( + name, new FieldType(false, new ArrowType.Map(false), null), + Collections.singletonList(new Field( + "entries", new FieldType(false, new ArrowType.Struct(), null), + Arrays.asList(keyGetter.field, valueGetter.field) + )) + )) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var mapVector = (MapVector) valueVector; + var structVector = (StructVector) mapVector.getDataVector(); + var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); + var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); + return new ArrowWriterFromStruct<>() { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { - var mapVector = (MapVector) valueVector; - var structVector = (StructVector) mapVector.getDataVector(); - var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); - var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); - return new ArrowWriterFromStruct() { - @Override - public void setFromStruct(Struct struct) { - var dict = struct == null ? 
null : dictGetter.getDict(struct); - if (dict != null) { - int size = fromDictGetter.getSize(dict); - var keys = (List) fromDictGetter.getKeys(dict); - var values = (List) fromDictGetter.getValues(dict); - mapVector.startNewValue(mapVector.getValueCount()); - for (int i = 0; i < size; i++) { - structVector.setIndexDefined(structVector.getValueCount()); - keyWriter.setFromList(keys, i); - valueWriter.setFromList(values, i); - structVector.setValueCount(structVector.getValueCount() + 1); - } - mapVector.endValue(mapVector.getValueCount(), size); - } - mapVector.setValueCount(mapVector.getValueCount() + 1); + public void setFromStruct(Struct struct) { + var dict = struct == null ? null : dictGetter.getDict(struct); + if (dict != null) { + int size = fromDictGetter.getSize(dict); + var keys = fromDictGetter.getKeys(dict); + var values = fromDictGetter.getValues(dict); + mapVector.startNewValue(mapVector.getValueCount()); + for (int i = 0; i < size; i++) { + structVector.setIndexDefined(structVector.getValueCount()); + keyWriter.setFromList(keys, i); + valueWriter.setFromList(values, i); + structVector.setValueCount(structVector.getValueCount() + 1); } - }; + mapVector.endValue(mapVector.getValueCount(), size); + } + mapVector.setValueCount(mapVector.getValueCount() + 1); } }; } - case Struct: { - var structGetter = (Getters.FromStructToStruct) getter; - var members = (java.util.List>) structGetter.getMembersGetters(); - var membersGetters = new ArrayList(members.size()); - for (Map.Entry member : members) { - membersGetters.add(arrowGetter(member.getKey(), member.getValue())); - } - return new ArrowGetterFromStruct(new Field( - name, new FieldType(false, new ArrowType.Struct(), null), - membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) - )) { + }; + } + + private ArrowGetterFromStruct getArrowGetterFromStruct( + String name, YTGetters.FromStructToList listGetter + ) { + var elementGetter = listGetter.getElementGetter(); + var itemGetter = 
arrowGetter("item", elementGetter); + return new ArrowGetterFromStruct<>(new Field(name, new FieldType( + false, new ArrowType.List(), null + ), Collections.singletonList(itemGetter.field))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var listVector = (ListVector) valueVector; + var dataWriter = itemGetter.writer(listVector.getDataVector()); + return new ArrowWriterFromStruct<>() { @Override - public ArrowWriterFromStruct writer(ValueVector valueVector) { - var structVector = (StructVector) valueVector; - var membersWriters = new ArrayList(members.size()); - for (int i = 0; i < members.size(); i++) { - membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); - } - return new ArrowWriterFromStruct() { - @Override - void setFromStruct(Struct row) { - if (row == null) { - for (int i = 0; i < members.size(); i++) { - membersWriters.get(i).setFromStruct(null); - } - } else { - var struct = (Struct) structGetter.getStruct(row); - structVector.setIndexDefined(structVector.getValueCount()); - for (int i = 0; i < members.size(); i++) { - membersWriters.get(i).setFromStruct(struct); - } - } - structVector.setValueCount(structVector.getValueCount() + 1); + public void setFromStruct(Struct struct) { + var list = struct == null ? 
null : listGetter.getList(struct); + if (list != null) { + int size = elementGetter.getSize(list); + listVector.startNewValue(listVector.getValueCount()); + for (int i = 0; i < size; i++) { + dataWriter.setFromList(list, i); } - }; + listVector.endValue(listVector.getValueCount(), size); + } + listVector.setValueCount(listVector.getValueCount() + 1); } }; } - default: - return null; - } + }; } - private final java.util.List fieldGetters; + private final java.util.List> fieldGetters; private final Schema schema; private final BufferAllocator allocator = ArrowUtils.rootAllocator().newChildAllocator("toBatchIterator", 0, Long.MAX_VALUE); - public ArrowTableRowsSerializer(java.util.List> structsGetter) { + public ArrowTableRowsSerializer(java.util.List>> structsGetter) { super(ERowsetFormat.RF_FORMAT); fieldGetters = structsGetter.stream().map(memberGetter -> arrowGetter( memberGetter.getKey(), memberGetter.getValue() @@ -1168,13 +1198,13 @@ protected void writeMeta(ByteBuf buf, ByteBuf serializedRows, int rowsCount) { @Override protected void writeRowsWithoutCount( - ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping + ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping ) { writeRows(buf, descriptor, rows, idMapping); } @Override - protected void writeRows(ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping) { + protected void writeRows(ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping) { try { var writeChannel = new WriteChannel(new ByteBufWritableByteChannel(buf)); MessageSerializer.serialize(writeChannel, schema); diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java index 6779e849..bc3eab48 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java +++ 
b/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java @@ -7,18 +7,18 @@ import java.util.HashMap; import java.util.Map; -public class ArrowWriteSerializationContext> extends SerializationContext { - private final java.util.List> rowGetters; +public class ArrowWriteSerializationContext extends SerializationContext { + private final java.util.List>> rowGetters; public ArrowWriteSerializationContext( - java.util.List> rowGetters + java.util.List>> rowGetters ) { this.rowsetFormat = ERowsetFormat.RF_FORMAT; this.format = new Format("arrow", new HashMap<>()); this.rowGetters = rowGetters; } - public java.util.List> getRowGetters() { + public java.util.List>> getRowGetters() { return rowGetters; } } diff --git a/data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java b/data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java deleted file mode 100644 index b823af88..00000000 --- a/data-source/src/main/scala/tech/ytsaurus/client/InternalRowYTGetters.java +++ /dev/null @@ -1,8 +0,0 @@ -package tech.ytsaurus.client; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; - -public class InternalRowYTGetters extends YTGetters { -} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java b/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java index 21bac7a9..fc8ad5b4 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java @@ -70,7 +70,7 @@ public CompletableFuture> startUploadImpl() { var format = this.req.getSerializationContext().getFormat(); if (format.isPresent() && "arrow".equals(format.get().getType())) { tableRowsSerializer = new ArrowTableRowsSerializer<>( - ((ArrowWriteSerializationContext) this.req.getSerializationContext()).getRowGetters() + 
((ArrowWriteSerializationContext) this.req.getSerializationContext()).getRowGetters() ); } } diff --git a/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java b/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java index 6c75768f..173c98f4 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java @@ -7,171 +7,162 @@ import java.nio.ByteBuffer; import java.util.Map; -public class YTGetters { - public abstract class Getter { - private Getter() { - } - - public abstract TiType getTiType(); +public class YTGetters { + public interface GetTiType { + TiType getTiType(); } - public abstract class FromStruct extends Getter { - private FromStruct() { - } - - public abstract void getYson(Struct struct, YsonConsumer ysonConsumer); + public interface FromStruct extends GetTiType { + void getYson(Struct struct, YsonConsumer ysonConsumer); } - public abstract class FromList extends Getter { - private FromList() { - } - - public abstract int getSize(List list); + public interface FromList extends GetTiType { + int getSize(List list); - public abstract void getYson(List list, int i, YsonConsumer ysonConsumer); + void getYson(List list, int i, YsonConsumer ysonConsumer); } - public abstract class FromStructToYson extends FromStruct { + public interface FromStructToYson extends FromStruct { } - public abstract class FromListToYson extends FromList { + public interface FromListToYson extends FromList { } - public abstract class FromDict extends Getter { - public abstract FromList getKeyGetter(); + public interface FromDict extends GetTiType { + FromList getKeyGetter(); - public abstract FromList getValueGetter(); + FromList getValueGetter(); - public abstract int getSize(Dict dict); + int getSize(Dict dict); - public abstract List getKeys(Dict dict); + Keys getKeys(Dict dict); - public abstract List getValues(Dict dict); + Values getValues(Dict dict); } - public abstract class 
FromStructToNull extends FromStruct { + public interface FromStructToNull extends FromStruct { } - public abstract class FromListToNull extends FromList { + public interface FromListToNull extends FromList { } - public abstract class FromStructToOptional extends FromStruct { - public abstract FromStruct getNotEmptyGetter(); + public interface FromStructToOptional extends FromStruct { + FromStruct getNotEmptyGetter(); - public abstract boolean isEmpty(Struct struct); + boolean isEmpty(Struct struct); } - public abstract class FromListToOptional extends FromList { - public abstract FromList getNotEmptyGetter(); + public interface FromListToOptional extends FromList { + FromList getNotEmptyGetter(); - public abstract boolean isEmpty(List list, int i); + boolean isEmpty(List list, int i); } - public abstract class FromStructToString extends FromStruct { - public abstract ByteBuffer getString(Struct struct); + public interface FromStructToString extends FromStruct { + ByteBuffer getString(Struct struct); } - public abstract class FromListToString extends FromList { - public abstract ByteBuffer getString(List struct, int i); + public interface FromListToString extends FromList { + ByteBuffer getString(List struct, int i); } - public abstract class FromStructToByte extends FromStruct { - public abstract byte getByte(Struct struct); + public interface FromStructToByte extends FromStruct { + byte getByte(Struct struct); } - public abstract class FromListToByte extends FromList { - public abstract byte getByte(List list, int i); + public interface FromListToByte extends FromList { + byte getByte(List list, int i); } - public abstract class FromStructToShort extends FromStruct { - public abstract short getShort(Struct struct); + public interface FromStructToShort extends FromStruct { + short getShort(Struct struct); } - public abstract class FromListToShort extends FromList { - public abstract short getShort(List list, int i); + public interface FromListToShort extends 
FromList { + short getShort(List list, int i); } - public abstract class FromStructToInt extends FromStruct { - public abstract int getInt(Struct struct); + public interface FromStructToInt extends FromStruct { + int getInt(Struct struct); } - public abstract class FromListToInt extends FromList { - public abstract int getInt(List list, int i); + public interface FromListToInt extends FromList { + int getInt(List list, int i); } - public abstract class FromStructToLong extends FromStruct { - public abstract long getLong(Struct struct); + public interface FromStructToLong extends FromStruct { + long getLong(Struct struct); } - public abstract class FromListToLong extends FromList { - public abstract long getLong(List list, int i); + public interface FromListToLong extends FromList { + long getLong(List list, int i); } - public abstract class FromStructToBoolean extends FromStruct { - public abstract boolean getBoolean(Struct struct); + public interface FromStructToBoolean extends FromStruct { + boolean getBoolean(Struct struct); } - public abstract class FromListToBoolean extends FromList { - public abstract boolean getBoolean(List list, int i); + public interface FromListToBoolean extends FromList { + boolean getBoolean(List list, int i); } - public abstract class FromStructToFloat extends FromStruct { - public abstract float getFloat(Struct struct); + public interface FromStructToFloat extends FromStruct { + float getFloat(Struct struct); } - public abstract class FromListToFloat extends FromList { - public abstract float getFloat(List list, int i); + public interface FromListToFloat extends FromList { + float getFloat(List list, int i); } - public abstract class FromStructToDouble extends FromStruct { - public abstract double getDouble(Struct struct); + public interface FromStructToDouble extends FromStruct { + double getDouble(Struct struct); } - public abstract class FromListToDouble extends FromList { - public abstract double getDouble(List list, int i); + 
public interface FromListToDouble extends FromList { + double getDouble(List list, int i); } - public abstract class FromStructToStruct extends FromStruct { - public abstract java.util.List> getMembersGetters(); + public interface FromStructToStruct extends FromStruct { + java.util.List>> getMembersGetters(); - public abstract Struct getStruct(Struct struct); + Value getStruct(Struct struct); } - public abstract class FromListToStruct extends FromList { - public abstract java.util.List> getMembersGetters(); + public interface FromListToStruct extends FromList { + java.util.List>> getMembersGetters(); - public abstract Struct getStruct(List list, int i); + Value getStruct(List list, int i); } - public abstract class FromStructToList extends FromStruct { - public abstract FromList getElementGetter(); + public interface FromStructToList extends FromStruct { + FromList getElementGetter(); - public abstract List getList(Struct struct); + List getList(Struct struct); } - public abstract class FromListToList extends FromList { - public abstract FromList getElementGetter(); + public interface FromListToList extends FromList { + FromList getElementGetter(); - public abstract List getList(List list, int i); + Value getList(List list, int i); } - public abstract class FromStructToDict extends FromStruct { - public abstract FromDict getGetter(); + public interface FromStructToDict extends FromStruct { + FromDict getGetter(); - public abstract Dict getDict(Struct struct); + Dict getDict(Struct struct); } - public abstract class FromListToDict extends FromList { - public abstract FromDict getGetter(); + public interface FromListToDict extends FromList { + FromDict getGetter(); - public abstract Dict getDict(List list, int i); + Dict getDict(List list, int i); } - public abstract class FromStructToBigDecimal extends FromStruct { - public abstract BigDecimal getBigDecimal(Struct struct); + public interface FromStructToBigDecimal extends FromStruct { + BigDecimal 
getBigDecimal(Struct struct); } - public abstract class FromListToBigDecimal extends FromList { - public abstract BigDecimal getBigDecimal(List list, int i); + public interface FromListToBigDecimal extends FromList { + BigDecimal getBigDecimal(List list, int i); } } diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala index 731f18dc..cb35ab7f 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala @@ -5,12 +5,11 @@ import org.apache.spark.executor.TaskMetricUpdater import org.apache.spark.metrics.yt.YtMetricsRegister import org.apache.spark.metrics.yt.YtMetricsRegister.ytMetricsSource._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable} -import tech.ytsaurus.client.{ArrowWriteSerializationContext, CompoundClient, InternalRowYTGetters, TableWriter} +import tech.ytsaurus.client.{ArrowWriteSerializationContext, CompoundClient, TableWriter} import tech.ytsaurus.core.GUID import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings._ @@ -153,7 +152,6 @@ class YtOutputWriter(richPath: YPathEnriched, protected def initializeWriter(): TableWriter[InternalRow] = { val appendPath = richPath.withAttr("append", "true").toYPath log.debugLazy(s"Initialize new write: $appendPath, transaction: $transactionGuid") - val internalRowGetters = new InternalRowYTGetters() val writeSchemaConverter = WriteSchemaConverter(options) val request = WriteTable.builder[InternalRow]() .setPath(appendPath) @@ -162,10 +160,10 @@ 
class YtOutputWriter(richPath: YPathEnriched, if (!writeSchemaConverter.typeV3Format) { throw new RuntimeException("arrow writer is only supported with typeV3") } - new ArrowWriteSerializationContext[InternalRow, ArrayData, MapData, InternalRowYTGetters]( + new ArrowWriteSerializationContext[InternalRow]( schema.fields.zipWithIndex.map { case (field, i) => util.Map.entry(field.name, writeSchemaConverter.ytLogicalTypeV3(field).ytGettersFromStruct( - internalRowGetters, field.dataType, i + field.dataType, i )) }.toSeq.asJava ) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index 98d2b751..9ad5a207 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -5,7 +5,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.spyt.types._ import org.apache.spark.sql.types import org.apache.spark.sql.types._ -import tech.ytsaurus.client.InternalRowYTGetters +import tech.ytsaurus.client.YTGetters import tech.ytsaurus.core.tables.ColumnValueType import tech.ytsaurus.spyt.serializers.SchemaConverter.MetadataFields import tech.ytsaurus.typeinfo.StructType.Member @@ -60,9 +60,9 @@ sealed trait YtLogicalType { def arrowSupported: Boolean = true - def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList + def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] - def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct + def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] } sealed trait YtLogicalTypeAlias { @@ -105,7 +105,7 @@ object YtLogicalType { import tech.ytsaurus.spyt.types.YTsaurusTypes.instance.sparkTypeFor case object Null extends 
AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToNull { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToNull[ArrayData] { override def getTiType: TiType = tiType override def getSize(list: ArrayData): Int = list.numElements() @@ -113,8 +113,8 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = - new ytGetter.FromStructToNull { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToNull[InternalRow] { override def getTiType: TiType = tiType override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() @@ -122,7 +122,7 @@ object YtLogicalType { } case object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), LongType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -132,7 +132,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getLong(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, 
ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -142,8 +142,8 @@ object YtLogicalType { } case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = dataType match { - case decimalType: DecimalType => new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = dataType match { + case decimalType: DecimalType => new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getDecimal(i, decimalType.precision, decimalType.scale).toLong override def getSize(list: ArrayData): Int = list.numElements() @@ -152,7 +152,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) } - case _ => new ytGetter.FromListToLong { + case _ => new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -163,15 +163,15 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = dataType match { - case decimalType: DecimalType => new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = dataType match { + case decimalType: DecimalType => new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toLong override def getTiType: TiType = tiType override def 
getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) } - case _ => new ytGetter.FromStructToLong { + case _ => new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -185,7 +185,7 @@ object YtLogicalType { "float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false, ) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToFloat { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToFloat[ArrayData] { override def getFloat(list: ArrayData, i: Int): Float = list.getFloat(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -195,7 +195,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToFloat { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToFloat[InternalRow] { override def getFloat(struct: InternalRow): Float = struct.getFloat(ordinal) override def getTiType: TiType = tiType @@ -205,7 +205,7 @@ object YtLogicalType { } case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToDouble { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToDouble[ArrayData] { override def getDouble(list: ArrayData, i: Int): 
Double = list.getDouble(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -215,7 +215,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDouble { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToDouble[InternalRow] { override def getDouble(struct: InternalRow): Double = struct.getDouble(ordinal) override def getTiType: TiType = tiType @@ -225,7 +225,7 @@ object YtLogicalType { } case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToBoolean { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToBoolean[ArrayData] { override def getBoolean(list: ArrayData, i: Int): Boolean = list.getBoolean(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -236,7 +236,7 @@ object YtLogicalType { ysonConsumer.onBoolean(list.getBoolean(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBoolean { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToBoolean[InternalRow] { override def getBoolean(struct: InternalRow): Boolean = struct.getBoolean(ordinal) override def getTiType: TiType = tiType @@ -253,7 +253,7 @@ object YtLogicalType { } case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) { - override def 
ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToString { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToString[ArrayData] { override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer override def getSize(list: ArrayData): Int = list.numElements() @@ -266,7 +266,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToString[InternalRow] { override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer override def getTiType: TiType = tiType @@ -285,7 +285,7 @@ object YtLogicalType { if (inner) alias.name else "string" } - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToString { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToString[ArrayData] { override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) override def getSize(list: ArrayData): Int = list.numElements() @@ -300,7 +300,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToString[InternalRow] { override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) override def getTiType: TiType = tiType @@ -317,7 +317,7 @@ object YtLogicalType { case object Any extends 
AtomicYtLogicalType("any", 0x11, ColumnValueType.ANY, TiType.yson(), sparkTypeFor(TiType.yson()), Seq("yson")) { override def nullable: Boolean = true - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToYson { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToYson[ArrayData] { override def getSize(list: ArrayData): Int = list.numElements() override def getTiType: TiType = tiType @@ -326,7 +326,7 @@ object YtLogicalType { YTreeBinarySerializer.deserialize(new ByteArrayInputStream(list.getBinary(i)), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToYson[InternalRow] { override def getTiType: TiType = tiType override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = @@ -335,7 +335,7 @@ object YtLogicalType { } case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToByte { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToByte[ArrayData] { override def getByte(list: ArrayData, i: Int): Byte = list.getByte(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -345,7 +345,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getByte(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte { + override def 
ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToByte[InternalRow] { override def getByte(struct: InternalRow): Byte = struct.getByte(ordinal) override def getTiType: TiType = tiType @@ -355,7 +355,7 @@ object YtLogicalType { } case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToByte { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToByte[ArrayData] { override def getByte(list: ArrayData, i: Int): Byte = list.getShort(i).toByte override def getSize(list: ArrayData): Int = list.numElements() @@ -365,7 +365,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getByte(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToByte { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToByte[InternalRow] { override def getByte(struct: InternalRow): Byte = struct.getShort(ordinal).toByte override def getTiType: TiType = tiType @@ -375,7 +375,7 @@ object YtLogicalType { } case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToShort { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToShort[ArrayData] { override def getShort(list: ArrayData, i: Int): Short = list.getShort(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -385,7 
+385,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getShort(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToShort[InternalRow] { override def getShort(struct: InternalRow): Short = struct.getShort(ordinal) override def getTiType: TiType = tiType @@ -395,7 +395,7 @@ object YtLogicalType { } case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToShort { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToShort[ArrayData] { override def getShort(list: ArrayData, i: Int): Short = list.getInt(i).toShort override def getSize(list: ArrayData): Int = list.numElements() @@ -405,7 +405,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getShort(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToShort { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToShort[InternalRow] { override def getShort(struct: InternalRow): Short = struct.getInt(ordinal).toShort override def getTiType: TiType = tiType @@ -415,7 +415,7 @@ object YtLogicalType { } case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, 
dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToInt[ArrayData] { override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -425,7 +425,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getInt(i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) override def getTiType: TiType = tiType @@ -435,7 +435,7 @@ object YtLogicalType { } case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToInt[ArrayData] { override def getInt(list: ArrayData, i: Int): Int = list.getLong(i).toInt override def getSize(list: ArrayData): Int = list.numElements() @@ -445,7 +445,7 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getInt(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { override def getInt(struct: InternalRow): Int 
= struct.getLong(ordinal).toInt override def getTiType: TiType = tiType @@ -455,7 +455,7 @@ object YtLogicalType { } case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToString { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToString[ArrayData] { override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer override def getSize(list: ArrayData): Int = list.numElements() @@ -468,7 +468,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToString { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToString[InternalRow] { override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer override def getTiType: TiType = tiType @@ -482,7 +482,7 @@ object YtLogicalType { // Unsupported types are listed here: yt/yt/client/arrow/arrow_row_stream_encoder.cpp case object Date extends AtomicYtLogicalType("date", 0x1008, ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToInt[ArrayData] { override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -493,7 +493,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getInt(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, 
ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) override def getTiType: TiType = tiType @@ -504,7 +504,7 @@ object YtLogicalType { } case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -515,7 +515,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -526,7 +526,7 @@ object YtLogicalType { } case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def 
getSize(list: ArrayData): Int = list.numElements() @@ -537,7 +537,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -548,7 +548,7 @@ object YtLogicalType { } case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -559,7 +559,7 @@ object YtLogicalType { ysonConsumer.onInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -570,7 +570,7 @@ object YtLogicalType { } case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToNull { + override def 
ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToNull[ArrayData] { override def getTiType: TiType = tiType override def getSize(list: ArrayData): Int = list.numElements() @@ -578,8 +578,8 @@ object YtLogicalType { override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = - new ytGetter.FromStructToNull { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToNull[InternalRow] { override def getTiType: TiType = tiType override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() @@ -587,7 +587,7 @@ object YtLogicalType { } case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToInt { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToInt[ArrayData] { override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -598,7 +598,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getInt(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToInt { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) override def getTiType: TiType = tiType @@ -609,7 +609,7 @@ object YtLogicalType { } case object Datetime64 extends 
AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -620,7 +620,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -631,7 +631,7 @@ object YtLogicalType { } case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -642,7 +642,7 @@ object YtLogicalType { ysonConsumer.onUnsignedInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): 
YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -653,7 +653,7 @@ object YtLogicalType { } case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) { - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToLong { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) override def getSize(list: ArrayData): Int = list.numElements() @@ -664,7 +664,7 @@ object YtLogicalType { ysonConsumer.onInteger(getLong(list, i)) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToLong { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) override def getTiType: TiType = tiType @@ -681,7 +681,7 @@ object YtLogicalType { override def tiType: TiType = TiType.decimal(precision, scale) - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToBigDecimal { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToBigDecimal[ArrayData] { override def getBigDecimal(list: ArrayData, i: Int): java.math.BigDecimal = list.getDecimal(i, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) @@ -695,7 +695,7 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, 
ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToBigDecimal { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToBigDecimal[InternalRow] { override def getBigDecimal(struct: InternalRow): java.math.BigDecimal = struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) @@ -727,10 +727,10 @@ object YtLogicalType { override def arrowSupported: Boolean = inner.arrowSupported - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToOptional { - private val notEmptyGetter = inner.ytGettersFromList(ytGetter, dataType) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToOptional[ArrayData] { + private val notEmptyGetter = inner.ytGettersFromList(dataType) - override def getNotEmptyGetter: ytGetter.FromList = notEmptyGetter + override def getNotEmptyGetter: YTGetters.FromList[ArrayData] = notEmptyGetter override def isEmpty(list: ArrayData, i: Int): Boolean = list.isNullAt(i) @@ -752,10 +752,10 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToOptional { - private val notEmptyGetter = inner.ytGettersFromStruct(ytGetter, dataType, ordinal) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToOptional[InternalRow] { + private val notEmptyGetter = inner.ytGettersFromStruct(dataType, ordinal) - override def getNotEmptyGetter: ytGetter.FromStruct = notEmptyGetter + override def getNotEmptyGetter: YTGetters.FromStruct[InternalRow] = notEmptyGetter override def isEmpty(struct: InternalRow): Boolean = struct.isNullAt(ordinal) @@ -788,24 +788,25 @@ object YtLogicalType { MapType(dictKey.sparkType.innerLevel, 
dictValue.sparkType.innerLevel, dictValue.nullable) ) - private def newGetter(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromDict = new ytGetter.FromDict { - private val keyGetter = dictKey.ytGettersFromList(ytGetter, dataType.asInstanceOf[MapType].keyType) - private val valueGetter = dictValue.ytGettersFromList(ytGetter, dataType.asInstanceOf[MapType].valueType) + private def newGetter(dataType: DataType): YTGetters.FromDict[MapData, ArrayData, ArrayData] = + new YTGetters.FromDict[MapData, ArrayData, ArrayData] { + private val keyGetter = dictKey.ytGettersFromList(dataType.asInstanceOf[MapType].keyType) + private val valueGetter = dictValue.ytGettersFromList(dataType.asInstanceOf[MapType].valueType) - override def getKeyGetter: ytGetter.FromList = keyGetter + override def getKeyGetter: YTGetters.FromList[ArrayData] = keyGetter - override def getValueGetter: ytGetter.FromList = valueGetter + override def getValueGetter: YTGetters.FromList[ArrayData] = valueGetter - override def getSize(dict: MapData): Int = dict.numElements() + override def getSize(dict: MapData): Int = dict.numElements() - override def getKeys(dict: MapData): ArrayData = dict.keyArray() + override def getKeys(dict: MapData): ArrayData = dict.keyArray() - override def getValues(dict: MapData): ArrayData = dict.valueArray() + override def getValues(dict: MapData): ArrayData = dict.valueArray() - override def getTiType: TiType = tiType - } + override def getTiType: TiType = tiType + } - def newYsonSerializer(getter: InternalRowYTGetters#FromDict): (MapData, YsonConsumer) => Unit = { + def newYsonSerializer(getter: YTGetters.FromDict[MapData, ArrayData, ArrayData]): (MapData, YsonConsumer) => Unit = { val keyGetter = getter.getKeyGetter val valueGetter = getter.getValueGetter (dict, ysonConsumer) => { @@ -829,35 +830,37 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Dict - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: 
DataType): ytGetter.FromList = new ytGetter.FromListToDict { - private val getter = newGetter(ytGetter, dataType) - private val ysonSerializer = newYsonSerializer(getter) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToDict[ArrayData, MapData, ArrayData, ArrayData] { + private val getter = newGetter(dataType) + private val ysonSerializer = newYsonSerializer(getter) - override def getGetter(): ytGetter.FromDict = getter + override def getGetter(): YTGetters.FromDict[MapData, ArrayData, ArrayData] = getter - override def getTiType: TiType = tiType + override def getTiType: TiType = tiType - override def getSize(list: ArrayData): Int = list.numElements() + override def getSize(list: ArrayData): Int = list.numElements() - override def getDict(list: ArrayData, i: Int): MapData = list.getMap(i) + override def getDict(list: ArrayData, i: Int): MapData = list.getMap(i) - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonSerializer(list.getMap(i), ysonConsumer) - } + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonSerializer(list.getMap(i), ysonConsumer) + } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToDict { - private val getter = newGetter(ytGetter, dataType) - private val ysonSerializer = newYsonSerializer(getter) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToDict[InternalRow, MapData, ArrayData, ArrayData] { + private val getter = newGetter(dataType) + private val ysonSerializer = newYsonSerializer(getter) - override def getGetter(): ytGetter.FromDict = getter + override def getGetter(): YTGetters.FromDict[MapData, ArrayData, ArrayData] = getter - override def getDict(struct: InternalRow): MapData = struct.getMap(ordinal) + override def 
getDict(struct: InternalRow): MapData = struct.getMap(ordinal) - override def getTiType: TiType = tiType + override def getTiType: TiType = tiType - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonSerializer(struct.getMap(ordinal), ysonConsumer) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonSerializer(struct.getMap(ordinal), ysonConsumer) + } } case object Dict extends CompositeYtLogicalTypeAlias(TypeName.Dict.getWireName) @@ -870,47 +873,49 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Array - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToList { - val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter, dataType.asInstanceOf[ArrayType].elementType) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToList[ArrayData, ArrayData] { + val elementGetter: YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType.asInstanceOf[ArrayType].elementType) - override def getSize(list: ArrayData): Int = list.numElements() + override def getSize(list: ArrayData): Int = list.numElements() - override def getTiType: TiType = tiType + override def getTiType: TiType = tiType - override def getElementGetter: ytGetter.FromList = elementGetter + override def getElementGetter: YTGetters.FromList[ArrayData] = elementGetter - override def getList(list: ArrayData, i: Int): ArrayData = list.getArray(i) + override def getList(list: ArrayData, i: Int): ArrayData = list.getArray(i) - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val value = list.getArray(i) - ysonConsumer.onBeginList() - for (j <- 0 until value.numElements()) { - ysonConsumer.onListItem() - elementGetter.getYson(value, j, ysonConsumer) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + 
val value = list.getArray(i) + ysonConsumer.onBeginList() + for (j <- 0 until value.numElements()) { + ysonConsumer.onListItem() + elementGetter.getYson(value, j, ysonConsumer) + } + ysonConsumer.onEndList() } - ysonConsumer.onEndList() } - } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToList { - val elementGetter: ytGetter.FromList = inner.ytGettersFromList(ytGetter, dataType.asInstanceOf[ArrayType].elementType) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToList[InternalRow, ArrayData] { + val elementGetter: YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType.asInstanceOf[ArrayType].elementType) - override def getElementGetter: ytGetter.FromList = elementGetter + override def getElementGetter: YTGetters.FromList[ArrayData] = elementGetter - override def getList(struct: InternalRow): ArrayData = struct.getArray(ordinal) + override def getList(struct: InternalRow): ArrayData = struct.getArray(ordinal) - override def getTiType: TiType = tiType + override def getTiType: TiType = tiType - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val value = struct.getArray(ordinal) - ysonConsumer.onBeginList() - for (j <- 0 until value.numElements()) { - ysonConsumer.onListItem() - elementGetter.getYson(value, j, ysonConsumer) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val value = struct.getArray(ordinal) + ysonConsumer.onBeginList() + for (j <- 0 until value.numElements()) { + ysonConsumer.onListItem() + elementGetter.getYson(value, j, ysonConsumer) + } + ysonConsumer.onEndList() } - ysonConsumer.onEndList() } - } } case object Array extends CompositeYtLogicalTypeAlias(TypeName.List.getWireName) @@ -927,13 +932,13 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Struct - def 
newMembersGetters(ytGetter: InternalRowYTGetters, dataType: DataType): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + def newMembersGetters(dataType: DataType): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = fields.zip(dataType.asInstanceOf[StructType].fields).zipWithIndex.map { case ((field, structField), i) => - java.util.Map.entry(field._1, field._2.ytGettersFromStruct(ytGetter, structField.dataType, i)) + java.util.Map.entry(field._1, field._2.ytGettersFromStruct(structField.dataType, i)) }.asJava - def yson(ytGetter: InternalRowYTGetters)( - membersGetters: java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]], + def yson( + membersGetters: java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]], internalRow: InternalRow, ysonConsumer: YsonConsumer, ): Unit = { ysonConsumer.onBeginList() @@ -944,10 +949,10 @@ object YtLogicalType { ysonConsumer.onEndList() } - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToStruct { - private val membersGetters = newMembersGetters(ytGetter, dataType) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToStruct[ArrayData, InternalRow] { + private val membersGetters = newMembersGetters(dataType) - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, fields.size) @@ -957,13 +962,13 @@ object YtLogicalType { override def getTiType: TiType = tiType override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - yson(ytGetter)(membersGetters, list.getStruct(i, membersGetters.size()), ysonConsumer) + yson(membersGetters, 
list.getStruct(i, membersGetters.size()), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { - private val membersGetters = newMembersGetters(ytGetter, dataType) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToStruct[InternalRow, InternalRow] { + private val membersGetters = newMembersGetters(dataType) - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, fields.size) @@ -971,7 +976,7 @@ object YtLogicalType { override def getTiType: TiType = tiType override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - yson(ytGetter)(membersGetters, struct.getStruct(ordinal, membersGetters.size()), ysonConsumer) + yson(membersGetters, struct.getStruct(ordinal, membersGetters.size()), ysonConsumer) } } @@ -979,6 +984,7 @@ object YtLogicalType { case class Tuple(elements: Seq[(YtLogicalType, Metadata)]) extends CompositeYtLogicalType { private val entries = elements.zipWithIndex.map { case ((ytType, _), index) => (s"_${1 + index}", ytType) } + override def sparkType: SparkType = SingleSparkType(StructType(elements.zipWithIndex .map { case ((ytType, meta), index) => getStructField(s"_${1 + index}", ytType, meta, topLevel = false) })) @@ -990,13 +996,13 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Tuple - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToStruct { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToStruct[ArrayData, 
InternalRow] { private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { case (((name, logicalType), structField), i) => - java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, structField.dataType, i)) + java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) }.asJava - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, elements.size) @@ -1015,13 +1021,13 @@ object YtLogicalType { } } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToStruct { + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToStruct[InternalRow, InternalRow] { private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { case (((name, logicalType), structField), i) => - java.util.Map.entry(name, logicalType.ytGettersFromStruct(ytGetter, structField.dataType, i)) + java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) }.asJava - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, ytGetter.FromStruct]] = membersGetters + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, elements.size) @@ -1048,16 +1054,16 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Tagged - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = 
inner.ytGettersFromList(ytGetter, dataType) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType) - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = inner.ytGettersFromStruct(ytGetter, dataType, ordinal) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = inner.ytGettersFromStruct(dataType, ordinal) } case object Tagged extends CompositeYtLogicalTypeAlias(TypeName.Tagged.getWireName) - private class VariantGetter(fields: Seq[YtLogicalType], ytGetter: InternalRowYTGetters, dataType: DataType) { + private class VariantGetter(fields: Seq[YtLogicalType], dataType: DataType) { private val getters = fields.zip(dataType.asInstanceOf[StructType].fields).zipWithIndex.map { - case ((field, structField), i) => field.ytGettersFromStruct(ytGetter, structField.dataType, i) + case ((field, structField), i) => field.ytGettersFromStruct(structField.dataType, i) } def get(row: InternalRow, ysonConsumer: YsonConsumer): Unit = { @@ -1091,8 +1097,8 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Variant - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToYson { - val getter = new VariantGetter(fields.map(_._2), ytGetter, dataType) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToYson[ArrayData] { + val getter = new VariantGetter(fields.map(_._2), dataType) override def getSize(list: ArrayData): Int = list.numElements() @@ -1102,8 +1108,8 @@ object YtLogicalType { getter.get(list.getStruct(i, fields.size), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { - val getter = new VariantGetter(fields.map(_._2), ytGetter, 
dataType) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToYson[InternalRow] { + val getter = new VariantGetter(fields.map(_._2), dataType) override def getTiType: TiType = tiType @@ -1127,8 +1133,8 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Variant - override def ytGettersFromList(ytGetter: InternalRowYTGetters, dataType: DataType): ytGetter.FromList = new ytGetter.FromListToYson { - val getter = new VariantGetter(fields.map(_._1), ytGetter, dataType) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToYson[ArrayData] { + val getter = new VariantGetter(fields.map(_._1), dataType) override def getSize(list: ArrayData): Int = list.numElements() @@ -1138,8 +1144,8 @@ object YtLogicalType { getter.get(list.getStruct(i, fields.size), ysonConsumer) } - override def ytGettersFromStruct(ytGetter: InternalRowYTGetters, dataType: DataType, ordinal: Int): ytGetter.FromStruct = new ytGetter.FromStructToYson { - val getter = new VariantGetter(fields.map(_._1), ytGetter, dataType) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToYson[InternalRow] { + val getter = new VariantGetter(fields.map(_._1), dataType) override def getTiType: TiType = tiType From d230beb5620208207e7d9da6cf211c0c2971d1d4 Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Thu, 21 Nov 2024 12:10:58 +0100 Subject: [PATCH 10/12] YtLogicalType FromList and FromStruct --- .../spyt/serializers/YtLogicalType.scala | 959 ++++++++---------- 1 file changed, 415 insertions(+), 544 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index 9ad5a207..85fa6ac8 100644 --- 
a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -60,8 +60,18 @@ sealed trait YtLogicalType { def arrowSupported: Boolean = true + trait FromList extends YTGetters.FromList[ArrayData] { + override def getTiType: TiType = tiType + + override def getSize(list: ArrayData): Int = list.numElements() + } + def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] + trait FromStruct extends YTGetters.FromStruct[InternalRow] { + override def getTiType: TiType = tiType + } + def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] } @@ -105,77 +115,56 @@ object YtLogicalType { import tech.ytsaurus.spyt.types.YTsaurusTypes.instance.sparkTypeFor case object Null extends AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToNull[ArrayData] { - override def getTiType: TiType = tiType - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToNull[ArrayData] with FromList { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = - new YTGetters.FromStructToNull[InternalRow] { - override def getTiType: TiType = tiType - + new YTGetters.FromStructToNull[InternalRow] with FromStruct { override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() } } case object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), 
LongType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getLong(i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getLong(i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getLong(ordinal)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getLong(ordinal)) + } } case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) { override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = dataType match { - case decimalType: DecimalType => new YTGetters.FromListToLong[ArrayData] { + case decimalType: DecimalType => new YTGetters.FromListToLong[ArrayData] with FromList { override def 
getLong(list: ArrayData, i: Int): Long = list.getDecimal(i, decimalType.precision, decimalType.scale).toLong - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) } - case _ => new YTGetters.FromListToLong[ArrayData] { + case _ => new YTGetters.FromListToLong[ArrayData] with FromList { override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) } } override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = dataType match { - case decimalType: DecimalType => new YTGetters.FromStructToLong[InternalRow] { + case decimalType: DecimalType => new YTGetters.FromStructToLong[InternalRow] with FromStruct { override def getLong(struct: InternalRow): Long = struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toLong - override def getTiType: TiType = tiType - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) } - case _ => new YTGetters.FromStructToLong[InternalRow] { + case _ => new YTGetters.FromStructToLong[InternalRow] with FromStruct { override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getTiType: TiType = tiType - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) } } @@ -185,65 +174,53 @@ object YtLogicalType { "float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false, ) { - override def ytGettersFromList(dataType: 
DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToFloat[ArrayData] { - override def getFloat(list: ArrayData, i: Int): Float = list.getFloat(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToFloat[ArrayData] with FromList { + override def getFloat(list: ArrayData, i: Int): Float = list.getFloat(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToFloat[InternalRow] { - override def getFloat(struct: InternalRow): Float = struct.getFloat(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToFloat[InternalRow] with FromStruct { + override def getFloat(struct: InternalRow): Float = struct.getFloat(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(struct)) + } } case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToDouble[ArrayData] { - override def getDouble(list: ArrayData, i: Int): Double = list.getDouble(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: 
ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToDouble[ArrayData] with FromList { + override def getDouble(list: ArrayData, i: Int): Double = list.getDouble(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToDouble[InternalRow] { - override def getDouble(struct: InternalRow): Double = struct.getDouble(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToDouble[InternalRow] with FromStruct { + override def getDouble(struct: InternalRow): Double = struct.getDouble(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(struct)) + } } case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToBoolean[ArrayData] { - override def getBoolean(list: ArrayData, i: Int): Boolean = list.getBoolean(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onBoolean(list.getBoolean(i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToBoolean[ArrayData] with FromList { + override def 
getBoolean(list: ArrayData, i: Int): Boolean = list.getBoolean(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToBoolean[InternalRow] { - override def getBoolean(struct: InternalRow): Boolean = struct.getBoolean(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onBoolean(list.getBoolean(i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToBoolean[InternalRow] with FromStruct { + override def getBoolean(struct: InternalRow): Boolean = struct.getBoolean(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onBoolean(struct.getBoolean(ordinal)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onBoolean(struct.getBoolean(ordinal)) + } } private def getBytes(byteBuffer: ByteBuffer): scala.Array[Byte] = { @@ -253,29 +230,25 @@ object YtLogicalType { } case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToString[ArrayData] { - override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToString[ArrayData] with FromList { + override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val bytes = getBytes(getString(list, i)) - ysonConsumer.onString(bytes, 0, 
bytes.length) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(list, i)) + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToString[InternalRow] { - override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer - - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToString[InternalRow] with FromStruct { + override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val bytes = getBytes(getString(struct)) - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(struct)) + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } } case object Binary extends AtomicYtLogicalType("binary", 0x10, ColumnValueType.STRING, TiType.string(), BinaryType) { @@ -285,393 +258,328 @@ object YtLogicalType { if (inner) alias.name else "string" } - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToString[ArrayData] { - override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToString[ArrayData] with FromList { + override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) - override def getYson(list: ArrayData, i: Int, ysonConsumer: 
YsonConsumer): Unit = { - val byteBuffer = getString(list, i) - val bytes = new scala.Array[Byte](byteBuffer.remaining()) - byteBuffer.get(bytes) - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val byteBuffer = getString(list, i) + val bytes = new scala.Array[Byte](byteBuffer.remaining()) + byteBuffer.get(bytes) + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } - - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToString[InternalRow] { - override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToString[InternalRow] with FromStruct { + override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val byteBuffer = getString(struct) - val bytes = new scala.Array[Byte](byteBuffer.remaining()) - byteBuffer.get(bytes) - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val byteBuffer = getString(struct) + val bytes = new scala.Array[Byte](byteBuffer.remaining()) + byteBuffer.get(bytes) + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } } case object Any extends AtomicYtLogicalType("any", 0x11, ColumnValueType.ANY, TiType.yson(), sparkTypeFor(TiType.yson()), Seq("yson")) { override def nullable: Boolean = true - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToYson[ArrayData] { - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: 
ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - YTreeBinarySerializer.deserialize(new ByteArrayInputStream(list.getBinary(i)), ysonConsumer) - } - - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToYson[InternalRow] { - override def getTiType: TiType = tiType + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToYson[ArrayData] with FromList { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + YTreeBinarySerializer.deserialize(new ByteArrayInputStream(list.getBinary(i)), ysonConsumer) + } - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - YTreeBinarySerializer.deserialize(new ByteArrayInputStream(struct.getBinary(ordinal)), ysonConsumer) - } + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToYson[InternalRow] with FromStruct { + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + YTreeBinarySerializer.deserialize(new ByteArrayInputStream(struct.getBinary(ordinal)), ysonConsumer) + } } case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToByte[ArrayData] { - override def getByte(list: ArrayData, i: Int): Byte = list.getByte(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getByte(i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToByte[ArrayData] with FromList { + override def getByte(list: ArrayData, i: Int): Byte = list.getByte(i) - 
override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToByte[InternalRow] { - override def getByte(struct: InternalRow): Byte = struct.getByte(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getByte(i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToByte[InternalRow] with FromStruct { + override def getByte(struct: InternalRow): Byte = struct.getByte(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getByte(ordinal)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getByte(ordinal)) + } } case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToByte[ArrayData] { - override def getByte(list: ArrayData, i: Int): Byte = list.getShort(i).toByte - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getByte(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToByte[ArrayData] with FromList { + override def getByte(list: ArrayData, i: Int): Byte = list.getShort(i).toByte - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToByte[InternalRow] { - override def getByte(struct: InternalRow): Byte = struct.getShort(ordinal).toByte + override def getYson(list: ArrayData, i: Int, 
ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getByte(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToByte[InternalRow] with FromStruct { + override def getByte(struct: InternalRow): Byte = struct.getShort(ordinal).toByte - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getByte(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getByte(struct)) + } } case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToShort[ArrayData] { - override def getShort(list: ArrayData, i: Int): Short = list.getShort(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getShort(i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToShort[ArrayData] with FromList { + override def getShort(list: ArrayData, i: Int): Short = list.getShort(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToShort[InternalRow] { - override def getShort(struct: InternalRow): Short = struct.getShort(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getShort(i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToShort[InternalRow] with 
FromStruct { + override def getShort(struct: InternalRow): Short = struct.getShort(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getShort(ordinal)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(struct.getShort(ordinal)) + } } case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToShort[ArrayData] { - override def getShort(list: ArrayData, i: Int): Short = list.getInt(i).toShort - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getShort(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToShort[ArrayData] with FromList { + override def getShort(list: ArrayData, i: Int): Short = list.getInt(i).toShort - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToShort[InternalRow] { - override def getShort(struct: InternalRow): Short = struct.getInt(ordinal).toShort + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getShort(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToShort[InternalRow] with FromStruct { + override def getShort(struct: InternalRow): Short = struct.getInt(ordinal).toShort - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getShort(struct)) - } + override def 
getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getShort(struct)) + } } case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToInt[ArrayData] { - override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getInt(i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { - override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(list.getInt(i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getInt(ordinal)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(struct.getInt(ordinal)) + } } case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new 
YTGetters.FromListToInt[ArrayData] { - override def getInt(list: ArrayData, i: Int): Int = list.getLong(i).toInt - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getInt(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getLong(i).toInt - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { - override def getInt(struct: InternalRow): Int = struct.getLong(ordinal).toInt + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getLong(ordinal).toInt - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getInt(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getInt(struct)) + } } case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToString[ArrayData] { - override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType + override def 
ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToString[ArrayData] with FromList { + override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val bytes = getBytes(list.getUTF8String(i).getByteBuffer) - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(list.getUTF8String(i).getByteBuffer) + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } - - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToString[InternalRow] { - override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToString[InternalRow] with FromStruct { + override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val bytes = getBytes(struct.getUTF8String(ordinal).getByteBuffer) - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(struct.getUTF8String(ordinal).getByteBuffer) + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } } // Unsupported types are listed here: yt/yt/client/arrow/arrow_row_stream_encoder.cpp case object Date extends AtomicYtLogicalType("date", 0x1008, ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToInt[ArrayData] { - override def getInt(list: 
ArrayData, i: Int): Int = list.getInt(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getInt(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { - override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getInt(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(struct)) + } } case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - 
ysonConsumer.onUnsignedInteger(getLong(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } } case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): 
Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } } case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onInteger(getLong(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def 
getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onInteger(getLong(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(struct)) + } } case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToNull[ArrayData] { - override def getTiType: TiType = tiType - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToNull[ArrayData] with FromList { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = - new YTGetters.FromStructToNull[InternalRow] { - override def getTiType: TiType = tiType - + new YTGetters.FromStructToNull[InternalRow] with FromStruct { override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() } } case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new 
YTGetters.FromListToInt[ArrayData] { - override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getInt(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToInt[InternalRow] { - override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getInt(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(struct)) + } } case object Datetime64 extends AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: 
ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } } case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new 
YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onUnsignedInteger(getLong(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } } case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) { - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToLong[ArrayData] { - override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onInteger(getLong(list, i)) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new 
YTGetters.FromStructToLong[InternalRow] { - override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(list, i)) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - ysonConsumer.onInteger(getLong(struct)) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(struct)) + } } case class Decimal(precision: Int, scale: Int, decimalType: DecimalType) extends CompositeYtLogicalType { @@ -681,31 +589,27 @@ object YtLogicalType { override def tiType: TiType = TiType.decimal(precision, scale) - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToBigDecimal[ArrayData] { - override def getBigDecimal(list: ArrayData, i: Int): java.math.BigDecimal = - list.getDecimal(i, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToBigDecimal[ArrayData] with FromList { + override def getBigDecimal(list: ArrayData, i: Int): java.math.BigDecimal = + list.getDecimal(i, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val bytes = getBigDecimal(list, i).unscaledValue().toByteArray - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(list: ArrayData, i: Int, 
ysonConsumer: YsonConsumer): Unit = { + val bytes = getBigDecimal(list, i).unscaledValue().toByteArray + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToBigDecimal[InternalRow] { - override def getBigDecimal(struct: InternalRow): java.math.BigDecimal = - struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) - - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToBigDecimal[InternalRow] with FromStruct { + override def getBigDecimal(struct: InternalRow): java.math.BigDecimal = + struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val bytes = getBigDecimal(struct).unscaledValue().toByteArray - ysonConsumer.onString(bytes, 0, bytes.length) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBigDecimal(struct).unscaledValue().toByteArray + ysonConsumer.onString(bytes, 0, bytes.length) + } } - } } case object Decimal extends CompositeYtLogicalTypeAlias("decimal") @@ -727,53 +631,49 @@ object YtLogicalType { override def arrowSupported: Boolean = inner.arrowSupported - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToOptional[ArrayData] { - private val notEmptyGetter = inner.ytGettersFromList(dataType) - - override def getNotEmptyGetter: YTGetters.FromList[ArrayData] = notEmptyGetter - - override def isEmpty(list: ArrayData, i: Int): Boolean = list.isNullAt(i) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToOptional[ArrayData] with FromList { + private val 
notEmptyGetter = inner.ytGettersFromList(dataType) - override def getSize(list: ArrayData): Int = list.numElements() + override def getNotEmptyGetter: YTGetters.FromList[ArrayData] = notEmptyGetter - override def getTiType: TiType = tiType + override def isEmpty(list: ArrayData, i: Int): Boolean = list.isNullAt(i) - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - if (list.isNullAt(i)) { - ysonConsumer.onEntity() - } else if (inner.isInstanceOf[Optional]) { - ysonConsumer.onBeginList() - ysonConsumer.onListItem() - notEmptyGetter.getYson(list, i, ysonConsumer) - ysonConsumer.onEndList() - } else { - notEmptyGetter.getYson(list, i, ysonConsumer) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + if (list.isNullAt(i)) { + ysonConsumer.onEntity() + } else if (inner.isInstanceOf[Optional]) { + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + notEmptyGetter.getYson(list, i, ysonConsumer) + ysonConsumer.onEndList() + } else { + notEmptyGetter.getYson(list, i, ysonConsumer) + } } } - } - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToOptional[InternalRow] { - private val notEmptyGetter = inner.ytGettersFromStruct(dataType, ordinal) - - override def getNotEmptyGetter: YTGetters.FromStruct[InternalRow] = notEmptyGetter + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToOptional[InternalRow] with FromStruct { + private val notEmptyGetter = inner.ytGettersFromStruct(dataType, ordinal) - override def isEmpty(struct: InternalRow): Boolean = struct.isNullAt(ordinal) + override def getNotEmptyGetter: YTGetters.FromStruct[InternalRow] = notEmptyGetter - override def getTiType: TiType = tiType + override def isEmpty(struct: InternalRow): Boolean = struct.isNullAt(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: 
YsonConsumer): Unit = { - if (struct.isNullAt(ordinal)) { - ysonConsumer.onEntity() - } else if (inner.isInstanceOf[Optional]) { - ysonConsumer.onBeginList() - ysonConsumer.onListItem() - notEmptyGetter.getYson(struct, ysonConsumer) - ysonConsumer.onEndList() - } else { - notEmptyGetter.getYson(struct, ysonConsumer) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + if (struct.isNullAt(ordinal)) { + ysonConsumer.onEntity() + } else if (inner.isInstanceOf[Optional]) { + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + notEmptyGetter.getYson(struct, ysonConsumer) + ysonConsumer.onEndList() + } else { + notEmptyGetter.getYson(struct, ysonConsumer) + } } } - } } case object Optional extends CompositeYtLogicalTypeAlias(TypeName.Optional.getWireName) @@ -831,16 +731,12 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Dict override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = - new YTGetters.FromListToDict[ArrayData, MapData, ArrayData, ArrayData] { + new YTGetters.FromListToDict[ArrayData, MapData, ArrayData, ArrayData] with FromList { private val getter = newGetter(dataType) private val ysonSerializer = newYsonSerializer(getter) override def getGetter(): YTGetters.FromDict[MapData, ArrayData, ArrayData] = getter - override def getTiType: TiType = tiType - - override def getSize(list: ArrayData): Int = list.numElements() - override def getDict(list: ArrayData, i: Int): MapData = list.getMap(i) override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = @@ -848,7 +744,7 @@ object YtLogicalType { } override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = - new YTGetters.FromStructToDict[InternalRow, MapData, ArrayData, ArrayData] { + new YTGetters.FromStructToDict[InternalRow, MapData, ArrayData, ArrayData] with FromStruct { private val getter = newGetter(dataType) private val ysonSerializer = 
newYsonSerializer(getter) @@ -856,8 +752,6 @@ object YtLogicalType { override def getDict(struct: InternalRow): MapData = struct.getMap(ordinal) - override def getTiType: TiType = tiType - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonSerializer(struct.getMap(ordinal), ysonConsumer) } @@ -872,15 +766,10 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Array - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = - new YTGetters.FromListToList[ArrayData, ArrayData] { + new YTGetters.FromListToList[ArrayData, ArrayData] with FromList { val elementGetter: YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType.asInstanceOf[ArrayType].elementType) - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - override def getElementGetter: YTGetters.FromList[ArrayData] = elementGetter override def getList(list: ArrayData, i: Int): ArrayData = list.getArray(i) @@ -897,15 +786,13 @@ object YtLogicalType { } override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = - new YTGetters.FromStructToList[InternalRow, ArrayData] { + new YTGetters.FromStructToList[InternalRow, ArrayData] with FromStruct { val elementGetter: YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType.asInstanceOf[ArrayType].elementType) override def getElementGetter: YTGetters.FromList[ArrayData] = elementGetter override def getList(struct: InternalRow): ArrayData = struct.getArray(ordinal) - override def getTiType: TiType = tiType - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { val value = struct.getArray(ordinal) ysonConsumer.onBeginList() @@ -938,9 +825,9 @@ object YtLogicalType { }.asJava def yson( - membersGetters: java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]], - internalRow: InternalRow, ysonConsumer: YsonConsumer, - ): Unit = { 
+ membersGetters: java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]], + internalRow: InternalRow, ysonConsumer: YsonConsumer, + ): Unit = { ysonConsumer.onBeginList() for (i <- 0 until membersGetters.size()) { ysonConsumer.onListItem() @@ -949,35 +836,31 @@ object YtLogicalType { ysonConsumer.onEndList() } - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToStruct[ArrayData, InternalRow] { - private val membersGetters = newMembersGetters(dataType) - - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = - membersGetters - - override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, fields.size) - - override def getSize(list: ArrayData): Int = list.numElements() + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToStruct[ArrayData, InternalRow] with FromList { + private val membersGetters = newMembersGetters(dataType) - override def getTiType: TiType = tiType + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = + membersGetters - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - yson(membersGetters, list.getStruct(i, membersGetters.size()), ysonConsumer) - } + override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, fields.size) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToStruct[InternalRow, InternalRow] { - private val membersGetters = newMembersGetters(dataType) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + yson(membersGetters, list.getStruct(i, membersGetters.size()), ysonConsumer) + } - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, 
YTGetters.FromStruct[InternalRow]]] = - membersGetters + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToStruct[InternalRow, InternalRow] with FromStruct { + private val membersGetters = newMembersGetters(dataType) - override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, fields.size) + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = + membersGetters - override def getTiType: TiType = tiType + override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, fields.size) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - yson(membersGetters, struct.getStruct(ordinal, membersGetters.size()), ysonConsumer) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + yson(membersGetters, struct.getStruct(ordinal, membersGetters.size()), ysonConsumer) + } } case object Struct extends CompositeYtLogicalTypeAlias(TypeName.Struct.getWireName) @@ -996,53 +879,49 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Tuple - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToStruct[ArrayData, InternalRow] { - private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { - case (((name, logicalType), structField), i) => - java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) - }.asJava - - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters - - override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, elements.size) + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToStruct[ArrayData, InternalRow] with FromList { + private 
val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { + case (((name, logicalType), structField), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) + }.asJava - override def getSize(list: ArrayData): Int = list.numElements() + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters - override def getTiType: TiType = tiType + override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, elements.size) - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val value = list.getStruct(i, membersGetters.size()) - ysonConsumer.onBeginList() - membersGetters.forEach { getter => - ysonConsumer.onListItem() - getter.getValue.getYson(value, ysonConsumer) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val value = list.getStruct(i, membersGetters.size()) + ysonConsumer.onBeginList() + membersGetters.forEach { getter => + ysonConsumer.onListItem() + getter.getValue.getYson(value, ysonConsumer) + } + ysonConsumer.onEndList() } - ysonConsumer.onEndList() } - } - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToStruct[InternalRow, InternalRow] { - private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { - case (((name, logicalType), structField), i) => - java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) - }.asJava - - override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToStruct[InternalRow, InternalRow] with FromStruct { + private val membersGetters = 
entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { + case (((name, logicalType), structField), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) + }.asJava - override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, elements.size) + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters - override def getTiType: TiType = tiType + override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, elements.size) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val value = struct.getStruct(ordinal, membersGetters.size()) - ysonConsumer.onBeginList() - membersGetters.forEach { getter => - ysonConsumer.onListItem() - getter.getValue.getYson(value, ysonConsumer) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val value = struct.getStruct(ordinal, membersGetters.size()) + ysonConsumer.onBeginList() + membersGetters.forEach { getter => + ysonConsumer.onListItem() + getter.getValue.getYson(value, ysonConsumer) + } + ysonConsumer.onEndList() } - ysonConsumer.onEndList() } - } } case object Tuple extends CompositeYtLogicalTypeAlias(TypeName.Tuple.getWireName) @@ -1097,25 +976,21 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Variant - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToYson[ArrayData] { - val getter = new VariantGetter(fields.map(_._2), dataType) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - getter.get(list.getStruct(i, fields.size), ysonConsumer) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new 
YTGetters.FromListToYson[ArrayData] with FromList { + val getter = new VariantGetter(fields.map(_._2), dataType) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToYson[InternalRow] { - val getter = new VariantGetter(fields.map(_._2), dataType) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + getter.get(list.getStruct(i, fields.size), ysonConsumer) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToYson[InternalRow] with FromStruct { + val getter = new VariantGetter(fields.map(_._2), dataType) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) + } } case class VariantOverTuple(fields: Seq[(YtLogicalType, Metadata)]) extends CompositeYtLogicalType { @@ -1133,25 +1008,21 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Variant - override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = new YTGetters.FromListToYson[ArrayData] { - val getter = new VariantGetter(fields.map(_._1), dataType) - - override def getSize(list: ArrayData): Int = list.numElements() - - override def getTiType: TiType = tiType - - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = - getter.get(list.getStruct(i, fields.size), ysonConsumer) - } + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToYson[ArrayData] with FromList { + val getter = new VariantGetter(fields.map(_._1), dataType) - override def ytGettersFromStruct(dataType: DataType, ordinal: Int): 
YTGetters.FromStruct[InternalRow] = new YTGetters.FromStructToYson[InternalRow] { - val getter = new VariantGetter(fields.map(_._1), dataType) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + getter.get(list.getStruct(i, fields.size), ysonConsumer) + } - override def getTiType: TiType = tiType + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToYson[InternalRow] with FromStruct { + val getter = new VariantGetter(fields.map(_._1), dataType) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = - getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) + } } case object Variant extends CompositeYtLogicalTypeAlias(TypeName.Variant.getWireName) From f7425693deaaa118e9191ed6d4e347094ee8875d Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Thu, 21 Nov 2024 15:22:31 +0100 Subject: [PATCH 11/12] ArrowTableRowsSerializer should close stuff --- .../client/ArrowTableRowsSerializer.java | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java index 1f4d8d74..7abc490c 100644 --- a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -1208,27 +1208,28 @@ protected void writeRows(ByteBuf buf, TRowsetDescriptor descriptor, java.util.Li try { var writeChannel = new WriteChannel(new ByteBufWritableByteChannel(buf)); MessageSerializer.serialize(writeChannel, schema); - var root = VectorSchemaRoot.create(schema, allocator); - var unloader = new VectorUnloader(root); - var writers = 
IntStream.range(0, fieldGetters.size()).mapToObj(column -> { - var valueVector = root.getFieldVectors().get(column); - if (valueVector instanceof FixedWidthVector) { - ((FixedWidthVector) valueVector).allocateNew(rows.size()); - } else { - valueVector.allocateNew(); + try (var root = VectorSchemaRoot.create(schema, allocator)) { + var unloader = new VectorUnloader(root); + var writers = IntStream.range(0, fieldGetters.size()).mapToObj(column -> { + var valueVector = root.getFieldVectors().get(column); + if (valueVector instanceof FixedWidthVector) { + ((FixedWidthVector) valueVector).allocateNew(rows.size()); + } else { + valueVector.allocateNew(); + } + return fieldGetters.get(column).writer(valueVector); + }).collect(Collectors.toList()); + for (var row : rows) { + for (var writer : writers) { + writer.setFromStruct(row); + } } - return fieldGetters.get(column).writer(valueVector); - }).collect(Collectors.toList()); - for (var row : rows) { - for (var writer : writers) { - writer.setFromStruct(row); + root.setRowCount(rows.size()); + try (var batch = unloader.getRecordBatch()) { + MessageSerializer.serialize(writeChannel, batch); } + writeChannel.writeZeros(4); } - root.setRowCount(rows.size()); - try (var batch = unloader.getRecordBatch()) { - MessageSerializer.serialize(writeChannel, batch); - } - writeChannel.writeZeros(4); } catch (IOException e) { throw new RuntimeException(e); } From b221450dda837a5577538d87236304fa2ef14fe8 Mon Sep 17 00:00:00 2001 From: Nikita Sokolov Date: Thu, 21 Nov 2024 15:23:18 +0100 Subject: [PATCH 12/12] groom --- .../spyt/serializers/YtLogicalType.scala | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index 85fa6ac8..47fbf1a7 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ 
b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -263,9 +263,7 @@ object YtLogicalType { override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val byteBuffer = getString(list, i) - val bytes = new scala.Array[Byte](byteBuffer.remaining()) - byteBuffer.get(bytes) + val bytes = getBytes(getString(list, i)) ysonConsumer.onString(bytes, 0, bytes.length) } } @@ -275,9 +273,7 @@ object YtLogicalType { override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val byteBuffer = getString(struct) - val bytes = new scala.Array[Byte](byteBuffer.remaining()) - byteBuffer.get(bytes) + val bytes = getBytes(getString(struct)) ysonConsumer.onString(bytes, 0, bytes.length) } } @@ -774,15 +770,8 @@ object YtLogicalType { override def getList(list: ArrayData, i: Int): ArrayData = list.getArray(i) - override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { - val value = list.getArray(i) - ysonConsumer.onBeginList() - for (j <- 0 until value.numElements()) { - ysonConsumer.onListItem() - elementGetter.getYson(value, j, ysonConsumer) - } - ysonConsumer.onEndList() - } + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + onList(ysonConsumer, elementGetter, list.getArray(i)) } override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = @@ -793,16 +782,18 @@ object YtLogicalType { override def getList(struct: InternalRow): ArrayData = struct.getArray(ordinal) - override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { - val value = struct.getArray(ordinal) - ysonConsumer.onBeginList() - for (j <- 0 until value.numElements()) { - ysonConsumer.onListItem() - elementGetter.getYson(value, 
j, ysonConsumer) - } - ysonConsumer.onEndList() - } + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + onList(ysonConsumer, elementGetter, struct.getArray(ordinal)) } + + private def onList(ysonConsumer: YsonConsumer, elementGetter: YTGetters.FromList[ArrayData], value: ArrayData): Unit = { + ysonConsumer.onBeginList() + for (j <- 0 until value.numElements()) { + ysonConsumer.onListItem() + elementGetter.getYson(value, j, ysonConsumer) + } + ysonConsumer.onEndList() + } } case object Array extends CompositeYtLogicalTypeAlias(TypeName.List.getWireName)