diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java new file mode 100644 index 0000000..7abc490 --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java @@ -0,0 +1,1237 @@ +package tech.ytsaurus.client; + +import io.netty.buffer.ByteBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import tech.ytsaurus.rpcproxy.ERowsetFormat; +import tech.ytsaurus.rpcproxy.TRowsetDescriptor; +import tech.ytsaurus.spyt.format.batch.ArrowUtils; +import tech.ytsaurus.typeinfo.DecimalType; +import tech.ytsaurus.yson.YsonBinaryWriter; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +public class ArrowTableRowsSerializer extends TableRowsSerializer implements AutoCloseable { + private static abstract class ArrowGetterFromStruct { + public final Field field; + public final ArrowType arrowType; + + ArrowGetterFromStruct(Field field) { + super(); + this.field = field; + this.arrowType = field.getType(); + } + + public final ArrowType getArrowType() { + return arrowType; + } + + public abstract ArrowWriterFromStruct writer(ValueVector valueVector); + } + + private static abstract class ArrowWriterFromStruct { + abstract void setFromStruct(Row struct); + } + + private static abstract class ArrowGetterFromList { + public final Field field; + public final ArrowType arrowType; + + ArrowGetterFromList(Field field) { + this.field = field; + this.arrowType = field.getType(); + } + + public final ArrowType getArrowType() { + return arrowType; + } + + public abstract ArrowWriterFromList writer(ValueVector valueVector); + } + + private static abstract class ArrowWriterFromList { + abstract void setFromList(List list, int i); + } + + private ArrowGetterFromList arrowGetter(String name, YTGetters.FromList getter) { + var optionalGetter = getter instanceof YTGetters.FromListToOptional + ? (YTGetters.FromListToOptional) getter + : null; + var nonEmptyGetter = optionalGetter != null ? optionalGetter.getNotEmptyGetter() : getter; + var arrowGetter = nonComplexArrowGetter(name, nonEmptyGetter); + if (arrowGetter != null) { + return optionalGetter == null ? arrowGetter : new ArrowGetterFromList<>(new Field(name, new FieldType( + true, arrowGetter.field.getType(), null + ), arrowGetter.field.getChildren())) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var nonOptionalWriter = arrowGetter.writer(valueVector); + return new ArrowWriterFromList<>() { + @Override + public void setFromList(Array array, int i) { + nonOptionalWriter.setFromList(optionalGetter.isEmpty(array, i) ? null : array, i); + } + }; + } + }; + } + return new ArrowGetterFromList<>(new Field(name, new FieldType( + optionalGetter != null, new ArrowType.Binary(), null + ), new ArrayList<>())) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + public void setFromList(Array array, int i) { + if (optionalGetter != null && optionalGetter.isEmpty(array, i)) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteArrayOutputStream = new ByteArrayOutputStream(); + try (var ysonBinaryWriter = new YsonBinaryWriter(byteArrayOutputStream)) { + nonEmptyGetter.getYson(array, i, ysonBinaryWriter); + } + varBinaryVector.set(varBinaryVector.getValueCount(), byteArrayOutputStream.toByteArray()); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + + private ArrowGetterFromStruct arrowGetter(String name, YTGetters.FromStruct getter) { + var optionalGetter = getter instanceof YTGetters.FromStructToOptional + ? (YTGetters.FromStructToOptional) getter + : null; + var nonEmptyGetter = optionalGetter != null ? optionalGetter.getNotEmptyGetter() : getter; + var arrowGetter = nonComplexArrowGetter(name, nonEmptyGetter); + if (arrowGetter != null) { + return optionalGetter == null ? arrowGetter : new ArrowGetterFromStruct<>(new Field(name, new FieldType( + true, arrowGetter.field.getType(), null + ), arrowGetter.field.getChildren())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var nonOptionalWriter = arrowGetter.writer(valueVector); + return new ArrowWriterFromStruct<>() { + @Override + public void setFromStruct(Struct struct) { + nonOptionalWriter.setFromStruct(optionalGetter.isEmpty(struct) ? null : struct); + } + }; + } + }; + } else { + return new ArrowGetterFromStruct<>(new Field(name, new FieldType( + optionalGetter != null, new ArrowType.Binary(), null + ), new ArrayList<>())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (optionalGetter != null && optionalGetter.isEmpty(struct)) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteArrayOutputStream = new ByteArrayOutputStream(); + try (var ysonBinaryWriter = new YsonBinaryWriter(byteArrayOutputStream)) { + nonEmptyGetter.getYson(struct, ysonBinaryWriter); + } + varBinaryVector.set(varBinaryVector.getValueCount(), byteArrayOutputStream.toByteArray()); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + } + + private Field field(String name, ArrowType arrowType) { + return new Field(name, new FieldType(false, arrowType, null), Collections.emptyList()); + } + + private ArrowGetterFromList nonComplexArrowGetter(String name, YTGetters.FromList getter) { + var tiType = getter.getTiType(); + switch (tiType.getTypeName()) { + case Null: + case Void: { + return new ArrowGetterFromList<>( + new Field(name, new FieldType(false, new ArrowType.Null(), null), new ArrayList<>()) + ) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var nullVector = (NullVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + nullVector.setValueCount(nullVector.getValueCount() + 1); + } + }; + } + }; + } + case Utf8: + case String: { + var stringGetter = (YTGetters.FromListToString) getter; + return new ArrowGetterFromList<>( + new Field(name, new FieldType(false, new ArrowType.Binary(), null), new ArrayList<>()) + ) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteBuffer = stringGetter.getString(list, i); + varBinaryVector.set( + varBinaryVector.getValueCount(), + byteBuffer, byteBuffer.position(), byteBuffer.remaining() + ); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + case Int8: { + var byteGetter = (YTGetters.FromListToByte) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(8, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var tinyIntVector = (TinyIntVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + tinyIntVector.setNull(tinyIntVector.getValueCount()); + } else { + tinyIntVector.set(tinyIntVector.getValueCount(), byteGetter.getByte(list, i)); + } + tinyIntVector.setValueCount(tinyIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint8: { + var byteGetter = (YTGetters.FromListToByte) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(8, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt1Vector = (UInt1Vector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + uInt1Vector.setNull(uInt1Vector.getValueCount()); + } else { + uInt1Vector.set(uInt1Vector.getValueCount(), byteGetter.getByte(list, i)); + } + uInt1Vector.setValueCount(uInt1Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int16: { + var shortGetter = (YTGetters.FromListToShort) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(16, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var smallIntVector = (SmallIntVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + smallIntVector.setNull(smallIntVector.getValueCount()); + } else { + smallIntVector.set(smallIntVector.getValueCount(), shortGetter.getShort(list, i)); + } + smallIntVector.setValueCount(smallIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint16: { + var shortGetter = (YTGetters.FromListToShort) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(16, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt2Vector = (UInt2Vector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + uInt2Vector.setNull(uInt2Vector.getValueCount()); + } else { + uInt2Vector.set(uInt2Vector.getValueCount(), shortGetter.getShort(list, i)); + } + uInt2Vector.setValueCount(uInt2Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int32: { + var intGetter = (YTGetters.FromListToInt) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(32, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var intVector = (IntVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + intVector.setNull(intVector.getValueCount()); + } else { + intVector.set(intVector.getValueCount(), intGetter.getInt(list, i)); + } + intVector.setValueCount(intVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint32: { + var intGetter = (YTGetters.FromListToInt) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(32, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt4Vector = (UInt4Vector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + uInt4Vector.setNull(uInt4Vector.getValueCount()); + } else { + uInt4Vector.set(uInt4Vector.getValueCount(), intGetter.getInt(list, i)); + } + uInt4Vector.setValueCount(uInt4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Interval: + case Interval64: + case Int64: { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(64, true))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var bigIntVector = (BigIntVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + bigIntVector.setNull(bigIntVector.getValueCount()); + } else { + bigIntVector.set(bigIntVector.getValueCount(), longGetter.getLong(list, i)); + } + bigIntVector.setValueCount(bigIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint64: { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Int(64, false))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var uInt8Vector = (UInt8Vector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + uInt8Vector.setNull(uInt8Vector.getValueCount()); + } else { + uInt8Vector.set(uInt8Vector.getValueCount(), longGetter.getLong(list, i)); + } + uInt8Vector.setValueCount(uInt8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Bool: { + var booleanGetter = (YTGetters.FromListToBoolean) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Bool())) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var bitVector = (BitVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + bitVector.setNull(bitVector.getValueCount()); + } else { + bitVector.set(bitVector.getValueCount(), booleanGetter.getBoolean(list, i) ? 1 : 0); + } + bitVector.setValueCount(bitVector.getValueCount() + 1); + } + }; + } + }; + } + case Float: { + var floatGetter = (YTGetters.FromListToFloat) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var float4Vector = (Float4Vector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + float4Vector.setNull(float4Vector.getValueCount()); + } else { + float4Vector.set(float4Vector.getValueCount(), floatGetter.getFloat(list, i)); + } + float4Vector.setValueCount(float4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Double: { + var doubleGetter = (YTGetters.FromListToDouble) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var float8Vector = (Float8Vector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + float8Vector.setNull(float8Vector.getValueCount()); + } else { + float8Vector.set(float8Vector.getValueCount(), doubleGetter.getDouble(list, i)); + } + float8Vector.setValueCount(float8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Decimal: { + var decimalGetter = (YTGetters.FromListToBigDecimal) getter; + var decimalType = (DecimalType) decimalGetter.getTiType(); + return new ArrowGetterFromList<>(field(name, new ArrowType.Decimal( + decimalType.getPrecision(), decimalType.getScale() + ))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var decimalVector = (DecimalVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + decimalVector.setNull(decimalVector.getValueCount()); + } else { + decimalVector.set(decimalVector.getValueCount(), decimalGetter.getBigDecimal(list, i)); + } + decimalVector.setValueCount(decimalVector.getValueCount() + 1); + } + }; + } + }; + } + case Date: + case Date32: { + var intGetter = (YTGetters.FromListToInt) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Date(DateUnit.DAY))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var dateDayVector = (DateDayVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + dateDayVector.setNull(dateDayVector.getValueCount()); + } else { + dateDayVector.set(dateDayVector.getValueCount(), intGetter.getInt(list, i)); + } + dateDayVector.setValueCount(dateDayVector.getValueCount() + 1); + } + }; + } + }; + } + case Datetime: + case Datetime64: { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var dateMilliVector = (DateMilliVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + dateMilliVector.setNull(dateMilliVector.getValueCount()); + } else { + dateMilliVector.set(dateMilliVector.getValueCount(), longGetter.getLong(list, i)); + } + dateMilliVector.setValueCount(dateMilliVector.getValueCount() + 1); + } + }; + } + }; + } + case Timestamp: + case Timestamp64: { + var longGetter = (YTGetters.FromListToLong) getter; + return new ArrowGetterFromList<>(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var timeStampMicroVector = (TimeStampMicroVector) valueVector; + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + timeStampMicroVector.setNull(timeStampMicroVector.getValueCount()); + } else { + timeStampMicroVector.set(timeStampMicroVector.getValueCount(), longGetter.getLong(list, i)); + } + timeStampMicroVector.setValueCount(timeStampMicroVector.getValueCount() + 1); + } + }; + } + }; + } + case List: { + return getArrowGetterFromList(name, (YTGetters.FromListToList) getter); + } + case Dict: { + return getArrowGetterFromList(name, (YTGetters.FromListToDict) getter); + } + case Struct: { + return getArrowGetterFromList(name, (YTGetters.FromListToStruct) getter); + } + default: + return null; + } + } + + private ArrowGetterFromList getArrowGetterFromList( + String name, YTGetters.FromListToStruct structGetter + ) { + var members = structGetter.getMembersGetters(); + var membersGetters = new ArrayList>(members.size()); + for (var member : members) { + membersGetters.add(arrowGetter(member.getKey(), member.getValue())); + } + return new ArrowGetterFromList<>(new Field( + name, new FieldType(false, new ArrowType.Struct(), null), + membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) + )) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var structVector = (StructVector) valueVector; + var membersWriters = new ArrayList>(members.size()); + for (int i = 0; i < members.size(); i++) { + membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); + } + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + if (list == null) { + for (int j = 0; j < members.size(); j++) { + membersWriters.get(j).setFromStruct(null); + } + } else { + var struct = structGetter.getStruct(list, i); + structVector.setIndexDefined(structVector.getValueCount()); + for (int j = 0; j < members.size(); j++) { + membersWriters.get(j).setFromStruct(struct); + } + } + structVector.setValueCount(structVector.getValueCount() + 1); + } + }; + } + }; + } + + private ArrowGetterFromList getArrowGetterFromList( + String name, YTGetters.FromListToDict dictGetter + ) { + var fromDictGetter = dictGetter.getGetter(); + var keyGetter = nonComplexArrowGetter("key", fromDictGetter.getKeyGetter()); + var valueGetter = arrowGetter("value", fromDictGetter.getValueGetter()); + if (keyGetter == null) { + return null; + } + return new ArrowGetterFromList<>(new Field( + name, new FieldType(false, new ArrowType.Map(false), null), + Collections.singletonList(new Field( + "entries", new FieldType(false, new ArrowType.Struct(), null), + Arrays.asList(keyGetter.field, valueGetter.field) + )) + )) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var mapVector = (MapVector) valueVector; + var structVector = (StructVector) mapVector.getDataVector(); + var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); + var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + var dict = list == null ? null : dictGetter.getDict(list, i); + if (dict != null) { + int size = fromDictGetter.getSize(dict); + var keys = fromDictGetter.getKeys(dict); + var values = fromDictGetter.getValues(dict); + mapVector.startNewValue(mapVector.getValueCount()); + for (int j = 0; j < size; j++) { + structVector.setIndexDefined(structVector.getValueCount()); + keyWriter.setFromList(keys, j); + valueWriter.setFromList(values, j); + structVector.setValueCount(structVector.getValueCount() + 1); + } + mapVector.endValue(mapVector.getValueCount(), size); + } + mapVector.setValueCount(mapVector.getValueCount() + 1); + } + }; + } + }; + } + + private ArrowGetterFromList getArrowGetterFromList( + String name, YTGetters.FromListToList listGetter + ) { + var elementGetter = listGetter.getElementGetter(); + var itemGetter = arrowGetter("item", elementGetter); + return new ArrowGetterFromList<>(new Field(name, new FieldType( + false, new ArrowType.List(), null + ), Collections.singletonList(itemGetter.field))) { + @Override + public ArrowWriterFromList writer(ValueVector valueVector) { + var listVector = (ListVector) valueVector; + var dataWriter = itemGetter.writer(listVector.getDataVector()); + return new ArrowWriterFromList<>() { + @Override + void setFromList(Array list, int i) { + var value = list == null ? null : listGetter.getList(list, i); + if (value != null) { + int size = elementGetter.getSize(value); + listVector.startNewValue(listVector.getValueCount()); + for (int j = 0; j < size; j++) { + dataWriter.setFromList(value, j); + } + listVector.endValue(listVector.getValueCount(), size); + } + listVector.setValueCount(listVector.getValueCount() + 1); + } + }; + } + }; + } + + private ArrowGetterFromStruct nonComplexArrowGetter(String name, YTGetters.FromStruct getter) { + var tiType = getter.getTiType(); + switch (tiType.getTypeName()) { + case Null: + case Void: { + return new ArrowGetterFromStruct<>( + new Field(name, new FieldType(false, new ArrowType.Null(), null), new ArrayList<>()) + ) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var nullVector = (NullVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + nullVector.setValueCount(nullVector.getValueCount() + 1); + } + }; + } + }; + } + case Utf8: + case String: { + var stringGetter = (YTGetters.FromStructToString) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Binary())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var varBinaryVector = (VarBinaryVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + varBinaryVector.setNull(varBinaryVector.getValueCount()); + } else { + var byteBuffer = stringGetter.getString(struct); + varBinaryVector.set( + varBinaryVector.getValueCount(), + byteBuffer, byteBuffer.position(), byteBuffer.remaining() + ); + } + varBinaryVector.setValueCount(varBinaryVector.getValueCount() + 1); + } + }; + } + }; + } + case Int8: { + var byteGetter = (YTGetters.FromStructToByte) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(8, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var tinyIntVector = (TinyIntVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + tinyIntVector.setNull(tinyIntVector.getValueCount()); + } else { + tinyIntVector.set(tinyIntVector.getValueCount(), byteGetter.getByte(struct)); + } + tinyIntVector.setValueCount(tinyIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint8: { + var byteGetter = (YTGetters.FromStructToByte) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(8, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt1Vector = (UInt1Vector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt1Vector.setNull(uInt1Vector.getValueCount()); + } else { + uInt1Vector.set(uInt1Vector.getValueCount(), byteGetter.getByte(struct)); + } + uInt1Vector.setValueCount(uInt1Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int16: { + var shortGetter = (YTGetters.FromStructToShort) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(16, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var smallIntVector = (SmallIntVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + smallIntVector.setNull(smallIntVector.getValueCount()); + } else { + smallIntVector.set(smallIntVector.getValueCount(), shortGetter.getShort(struct)); + } + smallIntVector.setValueCount(smallIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint16: { + var shortGetter = (YTGetters.FromStructToShort) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(16, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt2Vector = (UInt2Vector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt2Vector.setNull(uInt2Vector.getValueCount()); + } else { + uInt2Vector.set(uInt2Vector.getValueCount(), shortGetter.getShort(struct)); + } + uInt2Vector.setValueCount(uInt2Vector.getValueCount() + 1); + } + }; + } + }; + } + case Int32: { + var intGetter = (YTGetters.FromStructToInt) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(32, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var intVector = (IntVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + intVector.setNull(intVector.getValueCount()); + } else { + intVector.set(intVector.getValueCount(), intGetter.getInt(struct)); + } + intVector.setValueCount(intVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint32: { + var intGetter = (YTGetters.FromStructToInt) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(32, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt4Vector = (UInt4Vector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt4Vector.setNull(uInt4Vector.getValueCount()); + } else { + uInt4Vector.set(uInt4Vector.getValueCount(), intGetter.getInt(struct)); + } + uInt4Vector.setValueCount(uInt4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Interval: + case Interval64: + case Int64: { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(64, true))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var bigIntVector = (BigIntVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + bigIntVector.setNull(bigIntVector.getValueCount()); + } else { + bigIntVector.set(bigIntVector.getValueCount(), longGetter.getLong(struct)); + } + bigIntVector.setValueCount(bigIntVector.getValueCount() + 1); + } + }; + } + }; + } + case Uint64: { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Int(64, false))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var uInt8Vector = (UInt8Vector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + uInt8Vector.setNull(uInt8Vector.getValueCount()); + } else { + uInt8Vector.set(uInt8Vector.getValueCount(), longGetter.getLong(struct)); + } + uInt8Vector.setValueCount(uInt8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Bool: { + var booleanGetter = (YTGetters.FromStructToBoolean) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Bool())) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var bitVector = (BitVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + bitVector.setNull(bitVector.getValueCount()); + } else { + bitVector.set(bitVector.getValueCount(), booleanGetter.getBoolean(struct) ? 1 : 0); + } + bitVector.setValueCount(bitVector.getValueCount() + 1); + } + }; + } + }; + } + case Float: { + var floatGetter = (YTGetters.FromStructToFloat) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var float4Vector = (Float4Vector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + float4Vector.setNull(float4Vector.getValueCount()); + } else { + float4Vector.set(float4Vector.getValueCount(), floatGetter.getFloat(struct)); + } + float4Vector.setValueCount(float4Vector.getValueCount() + 1); + } + }; + } + }; + } + case Double: { + var doubleGetter = (YTGetters.FromStructToDouble) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var float8Vector = (Float8Vector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + float8Vector.setNull(float8Vector.getValueCount()); + } else { + float8Vector.set(float8Vector.getValueCount(), doubleGetter.getDouble(struct)); + } + float8Vector.setValueCount(float8Vector.getValueCount() + 1); + } + }; + } + }; + } + case Decimal: { + var decimalGetter = (YTGetters.FromStructToBigDecimal) getter; + var decimalType = (DecimalType) decimalGetter.getTiType(); + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Decimal( + decimalType.getPrecision(), decimalType.getScale() + ))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var decimalVector = (DecimalVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + decimalVector.setNull(decimalVector.getValueCount()); + } else { + decimalVector.set(decimalVector.getValueCount(), decimalGetter.getBigDecimal(struct)); + } + decimalVector.setValueCount(decimalVector.getValueCount() + 1); + } + }; + } + }; + } + case Date: + case Date32: { + var intGetter = (YTGetters.FromStructToInt) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Date(DateUnit.DAY))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var dateDayVector = (DateDayVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + dateDayVector.setNull(dateDayVector.getValueCount()); + } else { + dateDayVector.set(dateDayVector.getValueCount(), intGetter.getInt(struct)); + } + dateDayVector.setValueCount(dateDayVector.getValueCount() + 1); + } + }; + } + }; + } + case Datetime: + case Datetime64: { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Date(DateUnit.MILLISECOND))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var dateMilliVector = (DateMilliVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + dateMilliVector.setNull(dateMilliVector.getValueCount()); + } else { + dateMilliVector.set(dateMilliVector.getValueCount(), longGetter.getLong(struct)); + } + dateMilliVector.setValueCount(dateMilliVector.getValueCount() + 1); + } + }; + } + }; + } + case Timestamp: + case Timestamp64: { + var longGetter = (YTGetters.FromStructToLong) getter; + return new ArrowGetterFromStruct<>(field(name, new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var timeStampMicroVector = (TimeStampMicroVector) valueVector; + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct struct) { + if (struct == null) { + timeStampMicroVector.setNull(timeStampMicroVector.getValueCount()); + } else { + timeStampMicroVector.set(timeStampMicroVector.getValueCount(), longGetter.getLong(struct)); + } + timeStampMicroVector.setValueCount(timeStampMicroVector.getValueCount() + 1); + } + }; + } + }; + } + case List: { + return getArrowGetterFromStruct(name, (YTGetters.FromStructToList) getter); + } + case Dict: { + return getArrowGetterFromStruct(name, (YTGetters.FromStructToDict) getter); + } + case Struct: { + return getArrowGetterFromStruct(name, (YTGetters.FromStructToStruct) getter); + } + default: + return null; + } + } + + private ArrowGetterFromStruct getArrowGetterFromStruct( + String name, YTGetters.FromStructToStruct structGetter + ) { + var members = structGetter.getMembersGetters(); + var membersGetters = new ArrayList>(members.size()); + for (var member : members) { + membersGetters.add(arrowGetter(member.getKey(), member.getValue())); + } + return new ArrowGetterFromStruct<>(new Field( + name, new FieldType(false, new ArrowType.Struct(), null), + membersGetters.stream().map(member -> member.field).collect(Collectors.toList()) + )) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var structVector = (StructVector) valueVector; + var membersWriters = new ArrayList>(members.size()); + for (int i = 0; i < members.size(); i++) { + membersWriters.add(membersGetters.get(i).writer(structVector.getChildByOrdinal(i))); + } + return new ArrowWriterFromStruct<>() { + @Override + void setFromStruct(Struct row) { + if (row == null) { + for (int i = 0; i < members.size(); i++) { + membersWriters.get(i).setFromStruct(null); + } + } else { + var value = structGetter.getStruct(row); + structVector.setIndexDefined(structVector.getValueCount()); + for (int i = 0; i < members.size(); i++) { + membersWriters.get(i).setFromStruct(value); + } + } + structVector.setValueCount(structVector.getValueCount() + 1); + } + }; + } + }; + } + + private ArrowGetterFromStruct getArrowGetterFromStruct( + String name, YTGetters.FromStructToDict dictGetter + ) { + var fromDictGetter = dictGetter.getGetter(); + var keyGetter = nonComplexArrowGetter("key", fromDictGetter.getKeyGetter()); + var valueGetter = arrowGetter("value", fromDictGetter.getValueGetter()); + if (keyGetter == null) { + return null; + } + return new ArrowGetterFromStruct<>(new Field( + name, new FieldType(false, new ArrowType.Map(false), null), + Collections.singletonList(new Field( + "entries", new FieldType(false, new ArrowType.Struct(), null), + Arrays.asList(keyGetter.field, valueGetter.field) + )) + )) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var mapVector = (MapVector) valueVector; + var structVector = (StructVector) mapVector.getDataVector(); + var keyWriter = keyGetter.writer(structVector.getChildByOrdinal(0)); + var valueWriter = valueGetter.writer(structVector.getChildByOrdinal(1)); + return new ArrowWriterFromStruct<>() { + @Override + public void setFromStruct(Struct struct) { + var dict = struct == null ? null : dictGetter.getDict(struct); + if (dict != null) { + int size = fromDictGetter.getSize(dict); + var keys = fromDictGetter.getKeys(dict); + var values = fromDictGetter.getValues(dict); + mapVector.startNewValue(mapVector.getValueCount()); + for (int i = 0; i < size; i++) { + structVector.setIndexDefined(structVector.getValueCount()); + keyWriter.setFromList(keys, i); + valueWriter.setFromList(values, i); + structVector.setValueCount(structVector.getValueCount() + 1); + } + mapVector.endValue(mapVector.getValueCount(), size); + } + mapVector.setValueCount(mapVector.getValueCount() + 1); + } + }; + } + }; + } + + private ArrowGetterFromStruct getArrowGetterFromStruct( + String name, YTGetters.FromStructToList listGetter + ) { + var elementGetter = listGetter.getElementGetter(); + var itemGetter = arrowGetter("item", elementGetter); + return new ArrowGetterFromStruct<>(new Field(name, new FieldType( + false, new ArrowType.List(), null + ), Collections.singletonList(itemGetter.field))) { + @Override + public ArrowWriterFromStruct writer(ValueVector valueVector) { + var listVector = (ListVector) valueVector; + var dataWriter = itemGetter.writer(listVector.getDataVector()); + return new ArrowWriterFromStruct<>() { + @Override + public void setFromStruct(Struct struct) { + var list = struct == null ? null : listGetter.getList(struct); + if (list != null) { + int size = elementGetter.getSize(list); + listVector.startNewValue(listVector.getValueCount()); + for (int i = 0; i < size; i++) { + dataWriter.setFromList(list, i); + } + listVector.endValue(listVector.getValueCount(), size); + } + listVector.setValueCount(listVector.getValueCount() + 1); + } + }; + } + }; + } + + private final java.util.List> fieldGetters; + private final Schema schema; + private final BufferAllocator allocator = + ArrowUtils.rootAllocator().newChildAllocator("toBatchIterator", 0, Long.MAX_VALUE); + + public ArrowTableRowsSerializer(java.util.List>> structsGetter) { + super(ERowsetFormat.RF_FORMAT); + fieldGetters = structsGetter.stream().map(memberGetter -> arrowGetter( + memberGetter.getKey(), memberGetter.getValue() + )).collect(Collectors.toList()); + schema = new Schema(() -> fieldGetters.stream().map(getter -> getter.field).iterator()); + } + + @Override + public void close() { + allocator.close(); + } + + private static class ByteBufWritableByteChannel implements WritableByteChannel { + private final ByteBuf buf; + + private ByteBufWritableByteChannel(ByteBuf buf) { + this.buf = buf; + } + + @Override + public int write(ByteBuffer src) { + int remaining = src.remaining(); + buf.writeBytes(src); + return remaining - src.remaining(); + } + + @Override + public boolean isOpen() { + return buf.isWritable(); + } + + @Override + public void close() { + } + } + + @Override + protected void writeMeta(ByteBuf buf, ByteBuf serializedRows, int rowsCount) { + try { + var writeChannel = new WriteChannel(new ByteBufWritableByteChannel(buf)); + MessageSerializer.serialize(writeChannel, schema); + writeChannel.write(serializedRows.nioBuffer()); + ArrowStreamWriter.writeEndOfStream(writeChannel, new IpcOption()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + protected void writeRowsWithoutCount( + ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping + ) { + writeRows(buf, descriptor, rows, idMapping); + } + + @Override + protected void writeRows(ByteBuf buf, TRowsetDescriptor descriptor, java.util.List rows, int[] idMapping) { + try { + var writeChannel = new WriteChannel(new ByteBufWritableByteChannel(buf)); + MessageSerializer.serialize(writeChannel, schema); + try (var root = VectorSchemaRoot.create(schema, allocator)) { + var unloader = new VectorUnloader(root); + var writers = IntStream.range(0, fieldGetters.size()).mapToObj(column -> { + var valueVector = root.getFieldVectors().get(column); + if (valueVector instanceof FixedWidthVector) { + ((FixedWidthVector) valueVector).allocateNew(rows.size()); + } else { + valueVector.allocateNew(); + } + return fieldGetters.get(column).writer(valueVector); + }).collect(Collectors.toList()); + for (var row : rows) { + for (var writer : writers) { + writer.setFromStruct(row); + } + } + root.setRowCount(rows.size()); + try (var batch = unloader.getRecordBatch()) { + MessageSerializer.serialize(writeChannel, batch); + } + writeChannel.writeZeros(4); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java b/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java new file mode 100644 index 0000000..bc3eab4 --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/ArrowWriteSerializationContext.java @@ -0,0 +1,24 @@ +package tech.ytsaurus.client; + +import tech.ytsaurus.client.request.Format; +import tech.ytsaurus.client.request.SerializationContext; +import tech.ytsaurus.rpcproxy.ERowsetFormat; + +import java.util.HashMap; +import java.util.Map; + +public class ArrowWriteSerializationContext extends SerializationContext { + private final java.util.List>> rowGetters; + + public ArrowWriteSerializationContext( + java.util.List>> rowGetters + ) { + this.rowsetFormat = ERowsetFormat.RF_FORMAT; + this.format = new Format("arrow", new HashMap<>()); + this.rowGetters = rowGetters; + } + + public java.util.List>> getRowGetters() { + return rowGetters; + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java b/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java new file mode 100644 index 0000000..fc8ad5b --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/TableWriterBaseImpl.java @@ -0,0 +1,114 @@ +package tech.ytsaurus.client; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.CompletableFuture; + +import javax.annotation.Nullable; + +import tech.ytsaurus.client.request.WriteTable; +import tech.ytsaurus.client.rows.UnversionedRow; +import tech.ytsaurus.client.rows.UnversionedRowSerializer; +import tech.ytsaurus.client.rpc.Compression; +import tech.ytsaurus.client.rpc.RpcUtil; +import tech.ytsaurus.core.tables.TableSchema; +import tech.ytsaurus.lang.NonNullApi; +import tech.ytsaurus.rpcproxy.TWriteTableMeta; + + +@NonNullApi +class TableWriterBaseImpl extends RawTableWriterImpl { + protected @Nullable + TableSchema schema; + protected final WriteTable req; + protected @Nullable + TableRowsSerializer tableRowsSerializer; + private final SerializationResolver serializationResolver; + @Nullable + protected ApiServiceTransaction transaction; + + TableWriterBaseImpl(WriteTable req, SerializationResolver serializationResolver) { + super(req.getWindowSize(), req.getPacketSize()); + this.req = req; + this.serializationResolver = serializationResolver; + var format = this.req.getSerializationContext().getFormat(); + if (format.isEmpty() || !"arrow".equals(format.get().getType())) { + tableRowsSerializer = TableRowsSerializer.createTableRowsSerializer( + this.req.getSerializationContext(), serializationResolver + ).orElse(null); + } + } + + public void setTransaction(ApiServiceTransaction transaction) { + if (this.transaction != null) { + throw new IllegalStateException("Write transaction already started"); + } + this.transaction = transaction; + } + + public CompletableFuture> startUploadImpl() { + TableWriterBaseImpl self = this; + + return startUpload.thenApply((attachments) -> { + if (attachments.size() != 1) { + throw new IllegalArgumentException("protocol error"); + } + byte[] head = attachments.get(0); + if (head == null) { + throw new IllegalArgumentException("protocol error"); + } + + TWriteTableMeta metadata = RpcUtil.parseMessageBodyWithCompression( + head, + TWriteTableMeta.parser(), + Compression.None + ); + self.schema = ApiServiceUtil.deserializeTableSchema(metadata.getSchema()); + logger.debug("schema -> {}", schema.toYTree().toString()); + + { + var format = this.req.getSerializationContext().getFormat(); + if (format.isPresent() && "arrow".equals(format.get().getType())) { + tableRowsSerializer = new ArrowTableRowsSerializer<>( + ((ArrowWriteSerializationContext) this.req.getSerializationContext()).getRowGetters() + ); + } + } + + if (this.tableRowsSerializer == null) { + if (this.req.getSerializationContext().getObjectClass().isEmpty()) { + throw new IllegalStateException("No object clazz"); + } + Class objectClazz = self.req.getSerializationContext().getObjectClass().get(); + if (UnversionedRow.class.equals(objectClazz)) { + this.tableRowsSerializer = + (TableRowsSerializer) new TableRowsWireSerializer<>(new UnversionedRowSerializer()); + } else { + this.tableRowsSerializer = new TableRowsWireSerializer<>( + serializationResolver.createWireRowSerializer( + serializationResolver.forClass(objectClazz, self.schema)) + ); + } + } + + return self; + }); + } + + public boolean write(List rows, TableSchema schema) throws IOException { + byte[] serializedRows = tableRowsSerializer.serializeRows(rows, schema); + return write(serializedRows); + } + + @Override + public CompletableFuture close() { + return super.close() + .thenCompose(response -> { + if (transaction != null && transaction.isActive()) { + return transaction.commit() + .thenApply(unused -> response); + } + return CompletableFuture.completedFuture(response); + }); + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java b/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java new file mode 100644 index 0000000..173c98f --- /dev/null +++ b/data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java @@ -0,0 +1,168 @@ +package tech.ytsaurus.client; + +import tech.ytsaurus.typeinfo.*; +import tech.ytsaurus.yson.YsonConsumer; + +import java.math.BigDecimal; +import java.nio.ByteBuffer; +import java.util.Map; + +public class YTGetters { + public interface GetTiType { + TiType getTiType(); + } + + public interface FromStruct extends GetTiType { + void getYson(Struct struct, YsonConsumer ysonConsumer); + } + + public interface FromList extends GetTiType { + int getSize(List list); + + void getYson(List list, int i, YsonConsumer ysonConsumer); + } + + public interface FromStructToYson extends FromStruct { + } + + public interface FromListToYson extends FromList { + } + + public interface FromDict extends GetTiType { + FromList getKeyGetter(); + + FromList getValueGetter(); + + int getSize(Dict dict); + + Keys getKeys(Dict dict); + + Values getValues(Dict dict); + } + + public interface FromStructToNull extends FromStruct { + } + + public interface FromListToNull extends FromList { + } + + public interface FromStructToOptional extends FromStruct { + FromStruct getNotEmptyGetter(); + + boolean isEmpty(Struct struct); + } + + public interface FromListToOptional extends FromList { + FromList getNotEmptyGetter(); + + boolean isEmpty(List list, int i); + } + + public interface FromStructToString extends FromStruct { + ByteBuffer getString(Struct struct); + } + + public interface FromListToString extends FromList { + ByteBuffer getString(List struct, int i); + } + + public interface FromStructToByte extends FromStruct { + byte getByte(Struct struct); + } + + public interface FromListToByte extends FromList { + byte getByte(List list, int i); + } + + public interface FromStructToShort extends FromStruct { + short getShort(Struct struct); + } + + public interface FromListToShort extends FromList { + short getShort(List list, int i); + } + + public interface FromStructToInt extends FromStruct { + int getInt(Struct struct); + } + + public interface FromListToInt extends FromList { + int getInt(List list, int i); + } + + public interface FromStructToLong extends FromStruct { + long getLong(Struct struct); + } + + public interface FromListToLong extends FromList { + long getLong(List list, int i); + } + + public interface FromStructToBoolean extends FromStruct { + boolean getBoolean(Struct struct); + } + + public interface FromListToBoolean extends FromList { + boolean getBoolean(List list, int i); + } + + public interface FromStructToFloat extends FromStruct { + float getFloat(Struct struct); + } + + public interface FromListToFloat extends FromList { + float getFloat(List list, int i); + } + + public interface FromStructToDouble extends FromStruct { + double getDouble(Struct struct); + } + + public interface FromListToDouble extends FromList { + double getDouble(List list, int i); + } + + public interface FromStructToStruct extends FromStruct { + java.util.List>> getMembersGetters(); + + Value getStruct(Struct struct); + } + + public interface FromListToStruct extends FromList { + java.util.List>> getMembersGetters(); + + Value getStruct(List list, int i); + } + + public interface FromStructToList extends FromStruct { + FromList getElementGetter(); + + List getList(Struct struct); + } + + public interface FromListToList extends FromList { + FromList getElementGetter(); + + Value getList(List list, int i); + } + + public interface FromStructToDict extends FromStruct { + FromDict getGetter(); + + Dict getDict(Struct struct); + } + + public interface FromListToDict extends FromList { + FromDict getGetter(); + + Dict getDict(List list, int i); + } + + public interface FromStructToBigDecimal extends FromStruct { + BigDecimal getBigDecimal(Struct struct); + } + + public interface FromListToBigDecimal extends FromList { + BigDecimal getBigDecimal(List list, int i); + } +} diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala index 5e87d47..cb35ab7 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/YtOutputWriter.scala @@ -8,18 +8,19 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types.StructType import org.slf4j.LoggerFactory +import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable} +import tech.ytsaurus.client.{ArrowWriteSerializationContext, CompoundClient, TableWriter} +import tech.ytsaurus.core.GUID +import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings._ import tech.ytsaurus.spyt.fs.conf._ import tech.ytsaurus.spyt.fs.path.YPathEnriched import tech.ytsaurus.spyt.serializers.{InternalRowSerializer, WriteSchemaConverter} import tech.ytsaurus.spyt.wrapper.LogLazy -import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable} -import tech.ytsaurus.client.{CompoundClient, TableWriter} -import tech.ytsaurus.core.GUID -import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration import java.util import java.util.concurrent.{CompletableFuture, TimeUnit} +import scala.collection.JavaConverters.seqAsJavaListConverter import scala.concurrent.{Await, Future} import scala.util.{Failure, Try} @@ -151,9 +152,24 @@ class YtOutputWriter(richPath: YPathEnriched, protected def initializeWriter(): TableWriter[InternalRow] = { val appendPath = richPath.withAttr("append", "true").toYPath log.debugLazy(s"Initialize new write: $appendPath, transaction: $transactionGuid") + val writeSchemaConverter = WriteSchemaConverter(options) val request = WriteTable.builder[InternalRow]() .setPath(appendPath) - .setSerializationContext(new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options)))) + .setSerializationContext( + if (options.ytConf(ArrowWriteEnabled)) { + if (!writeSchemaConverter.typeV3Format) { + throw new RuntimeException("arrow writer is only supported with typeV3") + } + new ArrowWriteSerializationContext[InternalRow]( + schema.fields.zipWithIndex.map { case (field, i) => + util.Map.entry(field.name, writeSchemaConverter.ytLogicalTypeV3(field).ytGettersFromStruct( + field.dataType, i + )) + }.toSeq.asJava + ) + } else + new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options))) + ) .setTransactionalOptions(new TransactionalOptions(GUID.valueOf(transactionGuid))) .setNeedRetries(false) .build() diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala index 84eb69a..7866e0b 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/format/conf/YtTableSparkSettings.scala @@ -65,6 +65,8 @@ object YtTableSparkSettings { case object ArrowEnabled extends ConfigEntry[Boolean]("arrow_enabled", Some(true)) + case object ArrowWriteEnabled extends ConfigEntry[Boolean]("arrow_write_enabled", Some(false)) + case object KeyPartitioned extends ConfigEntry[Boolean]("key_partitioned") case object Dynamic extends ConfigEntry[Boolean]("dynamic") diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala index 71922c4..2e833da 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/InternalRowSerializer.scala @@ -10,7 +10,6 @@ import org.slf4j.LoggerFactory import tech.ytsaurus.client.TableWriter import tech.ytsaurus.client.rows.{WireProtocolWriteable, WireRowSerializer} import tech.ytsaurus.core.tables.{ColumnValueType, TableSchema} -import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings.{WriteSchemaHint, WriteTypeV3} import tech.ytsaurus.spyt.serialization.YsonEncoder import tech.ytsaurus.spyt.serializers.InternalRowSerializer._ import tech.ytsaurus.spyt.serializers.SchemaConverter.{Unordered, decimalToBinary} diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala index 23a0b85..76d4e86 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/WriteSchemaConverter.scala @@ -19,25 +19,20 @@ class WriteSchemaConverter( ) { private def ytLogicalTypeV3Variant(struct: StructType): YtLogicalType = { if (isVariantOverTuple(struct)) { - YtLogicalType.VariantOverTuple { - struct.fields.map(tF => - (wrapSparkAttributes(ytLogicalTypeV3(tF), tF.nullable, Some(tF.metadata)), tF.metadata)) - } + YtLogicalType.VariantOverTuple { struct.fields.map(tF => (ytLogicalTypeV3(tF), tF.metadata)) } } else { - YtLogicalType.VariantOverStruct { - struct.fields.map(sf => (sf.name.drop(2), - wrapSparkAttributes(ytLogicalTypeV3(sf), sf.nullable, Some(sf.metadata)), sf.metadata)) - } + YtLogicalType.VariantOverStruct { struct.fields.map(sf => (sf.name.drop(2), ytLogicalTypeV3(sf), sf.metadata)) } } } def ytLogicalTypeStruct(structType: StructType): YtLogicalType.Struct = YtLogicalType.Struct { - structType.fields.map(sf => (sf.name, - wrapSparkAttributes(ytLogicalTypeV3(sf), sf.nullable, Some(sf.metadata)), sf.metadata)) + structType.fields.map(sf => (sf.name, ytLogicalTypeV3(sf), sf.metadata)) } - def ytLogicalTypeV3(structField: StructField): YtLogicalType = - ytLogicalTypeV3(structField.dataType, hint.getOrElse(structField.name, null)) + def ytLogicalTypeV3(structField: StructField): YtLogicalType = wrapSparkAttributes( + ytLogicalTypeV3(structField.dataType, hint.getOrElse(structField.name, null)), + structField.nullable, Some(structField.metadata), + ) def ytLogicalTypeV3(sparkType: DataType, hint: YtLogicalType = null): YtLogicalType = sparkType match { case NullType => YtLogicalType.Null @@ -56,11 +51,15 @@ class WriteSchemaConverter( case FloatType => YtLogicalType.Float case DoubleType => YtLogicalType.Double case BooleanType => YtLogicalType.Boolean - case d: DecimalType => + case d: DecimalType => if (hint != null) { + hint + } else { val dT = if (d.precision > 35) applyYtLimitToSparkDecimal(d) else d - YtLogicalType.Decimal(dT.precision, dT.scale) + YtLogicalType.Decimal(dT.precision, dT.scale, d) + } case aT: ArrayType => YtLogicalType.Array(wrapSparkAttributes(ytLogicalTypeV3(aT.elementType), aT.containsNull)) + case _: StructType if hint != null => hint case sT: StructType if isTuple(sT) => YtLogicalType.Tuple { sT.fields.map(tF => @@ -92,8 +91,7 @@ class WriteSchemaConverter( val builder = YTree.builder .beginMap .key("name").value(field.name) - val fieldType = hint.getOrElse(field.name, - wrapSparkAttributes(ytLogicalTypeV3(field), field.nullable, Some(field.metadata))) + val fieldType = hint.getOrElse(field.name, ytLogicalTypeV3(field)) if (typeV3Format) { builder .key("type_v3").value(serializeTypeV3(fieldType)) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala index 84d6ed0..47fbf1a 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalType.scala @@ -1,21 +1,32 @@ package tech.ytsaurus.spyt.serializers +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.spyt.types._ +import org.apache.spark.sql.types import org.apache.spark.sql.types._ +import tech.ytsaurus.client.YTGetters import tech.ytsaurus.core.tables.ColumnValueType import tech.ytsaurus.spyt.serializers.SchemaConverter.MetadataFields import tech.ytsaurus.typeinfo.StructType.Member import tech.ytsaurus.typeinfo.{TiType, TypeName} +import tech.ytsaurus.yson.YsonConsumer +import tech.ytsaurus.ysontree.YTreeBinarySerializer +import java.io.ByteArrayInputStream +import java.nio.ByteBuffer import scala.annotation.tailrec sealed trait SparkType { def topLevel: DataType + def innerLevel: DataType } + case class SingleSparkType(topLevel: DataType) extends SparkType { override def innerLevel: DataType = topLevel } + case class TopInnerSparkTypes(topLevel: DataType, innerLevel: DataType) extends SparkType sealed trait YtLogicalType { @@ -48,10 +59,25 @@ sealed trait YtLogicalType { def alias: YtLogicalTypeAlias def arrowSupported: Boolean = true + + trait FromList extends YTGetters.FromList[ArrayData] { + override def getTiType: TiType = tiType + + override def getSize(list: ArrayData): Int = list.numElements() + } + + def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] + + trait FromStruct extends YTGetters.FromStruct[InternalRow] { + override def getTiType: TiType = tiType + } + + def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] } sealed trait YtLogicalTypeAlias { def name: String = aliases.head + def aliases: Seq[String] } @@ -69,6 +95,7 @@ sealed abstract class AtomicYtLogicalType(name: String, this(name, value, columnValueType, tiType, SingleSparkType(sparkType), otherAliases, arrowSupported) override def alias: YtLogicalTypeAlias = this + override def aliases: Seq[String] = name +: otherAliases } @@ -84,60 +111,501 @@ sealed abstract class CompositeYtLogicalTypeAlias(name: String, } object YtLogicalType { + import tech.ytsaurus.spyt.types.YTsaurusTypes.instance.sparkTypeFor - case object Null extends AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) + case object Null extends AtomicYtLogicalType("null", 0x02, ColumnValueType.NULL, TiType.nullType(), NullType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToNull[ArrayData] with FromList { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToNull[InternalRow] with FromStruct { + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } + } + + case object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), LongType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getLong(i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getLong(ordinal)) + } + } + + case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = dataType match { + case decimalType: DecimalType => new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getDecimal(i, decimalType.precision, decimalType.scale).toLong + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + case _ => new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = dataType match { + case decimalType: DecimalType => new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toLong + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) + } + case _ => new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + } + + case object Float extends AtomicYtLogicalType( + "float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), + TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false, + ) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToFloat[ArrayData] with FromList { + override def getFloat(list: ArrayData, i: Int): Float = list.getFloat(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToFloat[InternalRow] with FromStruct { + override def getFloat(struct: InternalRow): Float = struct.getFloat(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getFloat(struct)) + } + } + + case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToDouble[ArrayData] with FromList { + override def getDouble(list: ArrayData, i: Int): Double = list.getDouble(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToDouble[InternalRow] with FromStruct { + override def getDouble(struct: InternalRow): Double = struct.getDouble(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onDouble(getDouble(struct)) + } + } + + case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToBoolean[ArrayData] with FromList { + override def getBoolean(list: ArrayData, i: Int): Boolean = list.getBoolean(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onBoolean(list.getBoolean(i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToBoolean[InternalRow] with FromStruct { + override def getBoolean(struct: InternalRow): Boolean = struct.getBoolean(ordinal) - case object Int64 extends AtomicYtLogicalType("int64", 0x03, ColumnValueType.INT64, TiType.int64(), LongType) - case object Uint64 extends AtomicYtLogicalType("uint64", 0x04, ColumnValueType.UINT64, TiType.uint64(), sparkTypeFor(TiType.uint64())) - case object Float extends AtomicYtLogicalType("float", 0x05, ColumnValueType.DOUBLE, TiType.floatType(), - TopInnerSparkTypes(FloatType, DoubleType), Seq.empty, arrowSupported = false) - case object Double extends AtomicYtLogicalType("double", 0x05, ColumnValueType.DOUBLE, TiType.doubleType(), DoubleType) - case object Boolean extends AtomicYtLogicalType("boolean", 0x06, ColumnValueType.BOOLEAN, TiType.bool(), BooleanType, Seq("bool")) + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onBoolean(struct.getBoolean(ordinal)) + } + } + + private def getBytes(byteBuffer: ByteBuffer): scala.Array[Byte] = { + val bytes = new scala.Array[Byte](byteBuffer.remaining()) + byteBuffer.get(bytes) + bytes + } + + case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToString[ArrayData] with FromList { + override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(list, i)) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToString[InternalRow] with FromStruct { + override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(struct)) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + } - case object String extends AtomicYtLogicalType("string", 0x10, ColumnValueType.STRING, TiType.string(), StringType) case object Binary extends AtomicYtLogicalType("binary", 0x10, ColumnValueType.STRING, TiType.string(), BinaryType) { override def getName(isColumnType: Boolean): String = columnValueType.getName override def getNameV3(inner: Boolean): String = { if (inner) alias.name else "string" } + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToString[ArrayData] with FromList { + override def getString(list: ArrayData, i: Int): ByteBuffer = ByteBuffer.wrap(list.getBinary(i)) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(list, i)) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToString[InternalRow] with FromStruct { + override def getString(struct: InternalRow): ByteBuffer = ByteBuffer.wrap(struct.getBinary(ordinal)) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(getString(struct)) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } } + case object Any extends AtomicYtLogicalType("any", 0x11, ColumnValueType.ANY, TiType.yson(), sparkTypeFor(TiType.yson()), Seq("yson")) { override def nullable: Boolean = true + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToYson[ArrayData] with FromList { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + YTreeBinarySerializer.deserialize(new ByteArrayInputStream(list.getBinary(i)), ysonConsumer) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToYson[InternalRow] with FromStruct { + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + YTreeBinarySerializer.deserialize(new ByteArrayInputStream(struct.getBinary(ordinal)), ysonConsumer) + } + } + + case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToByte[ArrayData] with FromList { + override def getByte(list: ArrayData, i: Int): Byte = list.getByte(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getByte(i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToByte[InternalRow] with FromStruct { + override def getByte(struct: InternalRow): Byte = struct.getByte(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(struct.getByte(ordinal)) + } + } + + case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToByte[ArrayData] with FromList { + override def getByte(list: ArrayData, i: Int): Byte = list.getShort(i).toByte + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getByte(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToByte[InternalRow] with FromStruct { + override def getByte(struct: InternalRow): Byte = struct.getShort(ordinal).toByte + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getByte(struct)) + } + } + + case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToShort[ArrayData] with FromList { + override def getShort(list: ArrayData, i: Int): Short = list.getShort(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onInteger(list.getShort(i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToShort[InternalRow] with FromStruct { + override def getShort(struct: InternalRow): Short = struct.getShort(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(struct.getShort(ordinal)) + } } - case object Int8 extends AtomicYtLogicalType("int8", 0x1000, ColumnValueType.INT64, TiType.int8(), ByteType) - case object Uint8 extends AtomicYtLogicalType("uint8", 0x1001, ColumnValueType.INT64, TiType.uint8(), ShortType) + case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToShort[ArrayData] with FromList { + override def getShort(list: ArrayData, i: Int): Short = list.getInt(i).toShort + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getShort(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToShort[InternalRow] with FromStruct { + override def getShort(struct: InternalRow): Short = struct.getInt(ordinal).toShort + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getShort(struct)) + } + } + + case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(list.getInt(i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(struct.getInt(ordinal)) + } + } - case object Int16 extends AtomicYtLogicalType("int16", 0x1003, ColumnValueType.INT64, TiType.int16(), ShortType) - case object Uint16 extends AtomicYtLogicalType("uint16", 0x1004, ColumnValueType.INT64, TiType.uint16(), IntegerType) + case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getLong(i).toInt - case object Int32 extends AtomicYtLogicalType("int32", 0x1005, ColumnValueType.INT64, TiType.int32(), IntegerType) - case object Uint32 extends AtomicYtLogicalType("uint32", 0x1006, ColumnValueType.INT64, TiType.uint32(), LongType) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } - case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType) + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getLong(ordinal).toInt + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onUnsignedInteger(getInt(struct)) + } + } + + case object Utf8 extends AtomicYtLogicalType("utf8", 0x1007, ColumnValueType.STRING, TiType.utf8(), StringType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToString[ArrayData] with FromList { + override def getString(list: ArrayData, i: Int): ByteBuffer = list.getUTF8String(i).getByteBuffer + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(list.getUTF8String(i).getByteBuffer) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToString[InternalRow] with FromStruct { + override def getString(struct: InternalRow): ByteBuffer = struct.getUTF8String(ordinal).getByteBuffer + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBytes(struct.getUTF8String(ordinal).getByteBuffer) + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + } // Unsupported types are listed here: yt/yt/client/arrow/arrow_row_stream_encoder.cpp - case object Date extends AtomicYtLogicalType("date", 0x1008, ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) - case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) - case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) - case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) + case object Date extends AtomicYtLogicalType("date", 0x1008, ColumnValueType.UINT64, TiType.date(), DateType, arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(struct)) + } + } + + case object Datetime extends AtomicYtLogicalType("datetime", 0x1009, ColumnValueType.UINT64, TiType.datetime(), new DatetimeType(), arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } - case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) //? + case object Timestamp extends AtomicYtLogicalType("timestamp", 0x100a, ColumnValueType.UINT64, TiType.timestamp(), TimestampType, arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) - case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) - case object Datetime64 extends AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) - case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) - case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) - case class Decimal(precision: Int, scale: Int) extends CompositeYtLogicalType { + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + + case object Interval extends AtomicYtLogicalType("interval", 0x100b, ColumnValueType.INT64, TiType.interval(), LongType, arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(struct)) + } + } + + case object Void extends AtomicYtLogicalType("void", 0x100c, ColumnValueType.NULL, TiType.voidType(), NullType) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToNull[ArrayData] with FromList { + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToNull[InternalRow] with FromStruct { + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = ysonConsumer.onEntity() + } + } + + case object Date32 extends AtomicYtLogicalType("date32", 0x1018, ColumnValueType.INT64, TiType.date32(), new Date32Type(), arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToInt[ArrayData] with FromList { + override def getInt(list: ArrayData, i: Int): Int = list.getInt(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToInt[InternalRow] with FromStruct { + override def getInt(struct: InternalRow): Int = struct.getInt(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getInt(struct)) + } + } + + case object Datetime64 extends AtomicYtLogicalType("datetime64", 0x1019, ColumnValueType.INT64, TiType.datetime64(), new Datetime64Type(), arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + + case object Timestamp64 extends AtomicYtLogicalType("timestamp64", 0x101a, ColumnValueType.INT64, TiType.timestamp64(), new Timestamp64Type(), arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onUnsignedInteger(getLong(struct)) + } + } + + case object Interval64 extends AtomicYtLogicalType("interval64", 0x101b, ColumnValueType.INT64, TiType.interval64(), new Interval64Type(), arrowSupported = false) { + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToLong[ArrayData] with FromList { + override def getLong(list: ArrayData, i: Int): Long = list.getLong(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(list, i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToLong[InternalRow] with FromStruct { + override def getLong(struct: InternalRow): Long = struct.getLong(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonConsumer.onInteger(getLong(struct)) + } + } + + case class Decimal(precision: Int, scale: Int, decimalType: DecimalType) extends CompositeYtLogicalType { override def sparkType: SparkType = SingleSparkType(DecimalType(precision, scale)) override def alias: CompositeYtLogicalTypeAlias = Decimal override def tiType: TiType = TiType.decimal(precision, scale) + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToBigDecimal[ArrayData] with FromList { + override def getBigDecimal(list: ArrayData, i: Int): java.math.BigDecimal = + list.getDecimal(i, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBigDecimal(list, i).unscaledValue().toByteArray + ysonConsumer.onString(bytes, 0, bytes.length) + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToBigDecimal[InternalRow] with FromStruct { + override def getBigDecimal(struct: InternalRow): java.math.BigDecimal = + struct.getDecimal(ordinal, decimalType.precision, decimalType.scale).toJavaBigDecimal.setScale(scale) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val bytes = getBigDecimal(struct).unscaledValue().toByteArray + ysonConsumer.onString(bytes, 0, bytes.length) + } + } } case object Decimal extends CompositeYtLogicalTypeAlias("decimal") @@ -158,6 +626,50 @@ object YtLogicalType { override def alias: CompositeYtLogicalTypeAlias = Optional override def arrowSupported: Boolean = inner.arrowSupported + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToOptional[ArrayData] with FromList { + private val notEmptyGetter = inner.ytGettersFromList(dataType) + + override def getNotEmptyGetter: YTGetters.FromList[ArrayData] = notEmptyGetter + + override def isEmpty(list: ArrayData, i: Int): Boolean = list.isNullAt(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + if (list.isNullAt(i)) { + ysonConsumer.onEntity() + } else if (inner.isInstanceOf[Optional]) { + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + notEmptyGetter.getYson(list, i, ysonConsumer) + ysonConsumer.onEndList() + } else { + notEmptyGetter.getYson(list, i, ysonConsumer) + } + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToOptional[InternalRow] with FromStruct { + private val notEmptyGetter = inner.ytGettersFromStruct(dataType, ordinal) + + override def getNotEmptyGetter: YTGetters.FromStruct[InternalRow] = notEmptyGetter + + override def isEmpty(struct: InternalRow): Boolean = struct.isNullAt(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + if (struct.isNullAt(ordinal)) { + ysonConsumer.onEntity() + } else if (inner.isInstanceOf[Optional]) { + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + notEmptyGetter.getYson(struct, ysonConsumer) + ysonConsumer.onEndList() + } else { + notEmptyGetter.getYson(struct, ysonConsumer) + } + } + } } case object Optional extends CompositeYtLogicalTypeAlias(TypeName.Optional.getWireName) @@ -172,19 +684,116 @@ object YtLogicalType { MapType(dictKey.sparkType.innerLevel, dictValue.sparkType.innerLevel, dictValue.nullable) ) + private def newGetter(dataType: DataType): YTGetters.FromDict[MapData, ArrayData, ArrayData] = + new YTGetters.FromDict[MapData, ArrayData, ArrayData] { + private val keyGetter = dictKey.ytGettersFromList(dataType.asInstanceOf[MapType].keyType) + private val valueGetter = dictValue.ytGettersFromList(dataType.asInstanceOf[MapType].valueType) + + override def getKeyGetter: YTGetters.FromList[ArrayData] = keyGetter + + override def getValueGetter: YTGetters.FromList[ArrayData] = valueGetter + + override def getSize(dict: MapData): Int = dict.numElements() + + override def getKeys(dict: MapData): ArrayData = dict.keyArray() + + override def getValues(dict: MapData): ArrayData = dict.valueArray() + + override def getTiType: TiType = tiType + } + + def newYsonSerializer(getter: YTGetters.FromDict[MapData, ArrayData, ArrayData]): (MapData, YsonConsumer) => Unit = { + val keyGetter = getter.getKeyGetter + val valueGetter = getter.getValueGetter + (dict, ysonConsumer) => { + ysonConsumer.onBeginList() + val keys = dict.keyArray() + val values = dict.valueArray() + for (i <- 0 until dict.numElements()) { + ysonConsumer.onListItem() + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + keyGetter.getYson(keys, i, ysonConsumer) + ysonConsumer.onListItem() + valueGetter.getYson(values, i, ysonConsumer) + ysonConsumer.onEndList() + } + ysonConsumer.onEndList() + } + } + override def tiType: TiType = TiType.dict(dictKey.tiType, dictValue.tiType) override def alias: CompositeYtLogicalTypeAlias = Dict + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToDict[ArrayData, MapData, ArrayData, ArrayData] with FromList { + private val getter = newGetter(dataType) + private val ysonSerializer = newYsonSerializer(getter) + + override def getGetter(): YTGetters.FromDict[MapData, ArrayData, ArrayData] = getter + + override def getDict(list: ArrayData, i: Int): MapData = list.getMap(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + ysonSerializer(list.getMap(i), ysonConsumer) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToDict[InternalRow, MapData, ArrayData, ArrayData] with FromStruct { + private val getter = newGetter(dataType) + private val ysonSerializer = newYsonSerializer(getter) + + override def getGetter(): YTGetters.FromDict[MapData, ArrayData, ArrayData] = getter + + override def getDict(struct: InternalRow): MapData = struct.getMap(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + ysonSerializer(struct.getMap(ordinal), ysonConsumer) + } } case object Dict extends CompositeYtLogicalTypeAlias(TypeName.Dict.getWireName) case class Array(inner: YtLogicalType) extends CompositeYtLogicalType { - override def sparkType: SparkType = SingleSparkType( ArrayType(inner.sparkType.innerLevel, inner.nullable)) + override def sparkType: SparkType = SingleSparkType(ArrayType(inner.sparkType.innerLevel, inner.nullable)) override def tiType: TiType = TiType.list(inner.tiType) override def alias: CompositeYtLogicalTypeAlias = Array + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToList[ArrayData, ArrayData] with FromList { + val elementGetter: YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType.asInstanceOf[ArrayType].elementType) + + override def getElementGetter: YTGetters.FromList[ArrayData] = elementGetter + + override def getList(list: ArrayData, i: Int): ArrayData = list.getArray(i) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + onList(ysonConsumer, elementGetter, list.getArray(i)) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToList[InternalRow, ArrayData] with FromStruct { + val elementGetter: YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType.asInstanceOf[ArrayType].elementType) + + override def getElementGetter: YTGetters.FromList[ArrayData] = elementGetter + + override def getList(struct: InternalRow): ArrayData = struct.getArray(ordinal) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + onList(ysonConsumer, elementGetter, struct.getArray(ordinal)) + } + + private def onList(ysonConsumer: YsonConsumer, elementGetter: YTGetters.FromList[ArrayData], value: ArrayData): Unit = { + ysonConsumer.onBeginList() + for (j <- 0 until value.numElements()) { + ysonConsumer.onListItem() + elementGetter.getYson(value, j, ysonConsumer) + } + ysonConsumer.onEndList() + } } case object Array extends CompositeYtLogicalTypeAlias(TypeName.List.getWireName) @@ -194,25 +803,116 @@ object YtLogicalType { .map { case (name, ytType, meta) => getStructField(name, ytType, meta, topLevel = false) })) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.struct( - fields.map{ case (name, ytType, _) => new Member(name, ytType.tiType)}.asJava + fields.map { case (name, ytType, _) => new Member(name, ytType.tiType) }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Struct + + def newMembersGetters(dataType: DataType): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = + fields.zip(dataType.asInstanceOf[StructType].fields).zipWithIndex.map { case ((field, structField), i) => + java.util.Map.entry(field._1, field._2.ytGettersFromStruct(structField.dataType, i)) + }.asJava + + def yson( + membersGetters: java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]], + internalRow: InternalRow, ysonConsumer: YsonConsumer, + ): Unit = { + ysonConsumer.onBeginList() + for (i <- 0 until membersGetters.size()) { + ysonConsumer.onListItem() + membersGetters.get(i).getValue.getYson(internalRow, ysonConsumer) + } + ysonConsumer.onEndList() + } + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToStruct[ArrayData, InternalRow] with FromList { + private val membersGetters = newMembersGetters(dataType) + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = + membersGetters + + override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, fields.size) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + yson(membersGetters, list.getStruct(i, membersGetters.size()), ysonConsumer) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToStruct[InternalRow, InternalRow] with FromStruct { + private val membersGetters = newMembersGetters(dataType) + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = + membersGetters + + override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, fields.size) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + yson(membersGetters, struct.getStruct(ordinal, membersGetters.size()), ysonConsumer) + } } case object Struct extends CompositeYtLogicalTypeAlias(TypeName.Struct.getWireName) case class Tuple(elements: Seq[(YtLogicalType, Metadata)]) extends CompositeYtLogicalType { + private val entries = elements.zipWithIndex.map { case ((ytType, _), index) => (s"_${1 + index}", ytType) } + override def sparkType: SparkType = SingleSparkType(StructType(elements.zipWithIndex .map { case ((ytType, meta), index) => getStructField(s"_${1 + index}", ytType, meta, topLevel = false) })) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.tuple( - elements.map { case (e, _) => e.tiType } .asJava + elements.map { case (e, _) => e.tiType }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Tuple + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToStruct[ArrayData, InternalRow] with FromList { + private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { + case (((name, logicalType), structField), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) + }.asJava + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters + + override def getStruct(list: ArrayData, i: Int): InternalRow = list.getStruct(i, elements.size) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = { + val value = list.getStruct(i, membersGetters.size()) + ysonConsumer.onBeginList() + membersGetters.forEach { getter => + ysonConsumer.onListItem() + getter.getValue.getYson(value, ysonConsumer) + } + ysonConsumer.onEndList() + } + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToStruct[InternalRow, InternalRow] with FromStruct { + private val membersGetters = entries.zip(dataType.asInstanceOf[types.StructType]).zipWithIndex.map { + case (((name, logicalType), structField), i) => + java.util.Map.entry(name, logicalType.ytGettersFromStruct(structField.dataType, i)) + }.asJava + + override def getMembersGetters(): java.util.List[java.util.Map.Entry[String, YTGetters.FromStruct[InternalRow]]] = membersGetters + + override def getStruct(struct: InternalRow): InternalRow = struct.getStruct(ordinal, elements.size) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val value = struct.getStruct(ordinal, membersGetters.size()) + ysonConsumer.onBeginList() + membersGetters.forEach { getter => + ysonConsumer.onListItem() + getter.getValue.getYson(value, ysonConsumer) + } + ysonConsumer.onEndList() + } + } } case object Tuple extends CompositeYtLogicalTypeAlias(TypeName.Tuple.getWireName) @@ -223,34 +923,97 @@ object YtLogicalType { override def tiType: TiType = TiType.tagged(inner.tiType, tag) override def alias: CompositeYtLogicalTypeAlias = Tagged + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = inner.ytGettersFromList(dataType) + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = inner.ytGettersFromStruct(dataType, ordinal) } case object Tagged extends CompositeYtLogicalTypeAlias(TypeName.Tagged.getWireName) + private class VariantGetter(fields: Seq[YtLogicalType], dataType: DataType) { + private val getters = fields.zip(dataType.asInstanceOf[StructType].fields).zipWithIndex.map { + case ((field, structField), i) => field.ytGettersFromStruct(structField.dataType, i) + } + + def get(row: InternalRow, ysonConsumer: YsonConsumer): Unit = { + val notNulls = (0 until row.numFields).filter(!row.isNullAt(_)) + if (notNulls.isEmpty) { + throw new IllegalArgumentException("All elements in variant is null") + } else if (notNulls.size > 1) { + throw new IllegalArgumentException("Not null element must be single") + } else { + val index = notNulls.head + ysonConsumer.onBeginList() + ysonConsumer.onListItem() + ysonConsumer.onInteger(index) + ysonConsumer.onListItem() + getters(index).getYson(row, ysonConsumer) + ysonConsumer.onEndList() + } + } + } + case class VariantOverStruct(fields: Seq[(String, YtLogicalType, Metadata)]) extends CompositeYtLogicalType { override def sparkType: SparkType = SingleSparkType(StructType(fields.map { case (name, ytType, meta) => - getStructField(s"_v$name", ytType, meta, forcedNullability = Some(true), topLevel = false) })) + getStructField(s"_v$name", ytType, meta, forcedNullability = Some(true), topLevel = false) + })) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.variantOverStruct( - fields.map{ case (name, ytType, _) => new Member(name, ytType.tiType)}.asJava + fields.map { case (name, ytType, _) => new Member(name, ytType.tiType) }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Variant + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToYson[ArrayData] with FromList { + val getter = new VariantGetter(fields.map(_._2), dataType) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + getter.get(list.getStruct(i, fields.size), ysonConsumer) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToYson[InternalRow] with FromStruct { + val getter = new VariantGetter(fields.map(_._2), dataType) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) + } } case class VariantOverTuple(fields: Seq[(YtLogicalType, Metadata)]) extends CompositeYtLogicalType { override def sparkType: SparkType = SingleSparkType( StructType(fields.zipWithIndex.map { case ((ytType, meta), index) => - getStructField(s"_v_${1 + index}", ytType, meta, forcedNullability = Some(true), topLevel = false) }) + getStructField(s"_v_${1 + index}", ytType, meta, forcedNullability = Some(true), topLevel = false) + }) ) import scala.collection.JavaConverters._ + override def tiType: TiType = TiType.variantOverTuple( fields.map { case (e, _) => e.tiType }.asJava ) override def alias: CompositeYtLogicalTypeAlias = Variant + + override def ytGettersFromList(dataType: DataType): YTGetters.FromList[ArrayData] = + new YTGetters.FromListToYson[ArrayData] with FromList { + val getter = new VariantGetter(fields.map(_._1), dataType) + + override def getYson(list: ArrayData, i: Int, ysonConsumer: YsonConsumer): Unit = + getter.get(list.getStruct(i, fields.size), ysonConsumer) + } + + override def ytGettersFromStruct(dataType: DataType, ordinal: Int): YTGetters.FromStruct[InternalRow] = + new YTGetters.FromStructToYson[InternalRow] with FromStruct { + val getter = new VariantGetter(fields.map(_._1), dataType) + + override def getYson(struct: InternalRow, ysonConsumer: YsonConsumer): Unit = + getter.get(struct.getStruct(ordinal, fields.size), ysonConsumer) + } } case object Variant extends CompositeYtLogicalTypeAlias(TypeName.Variant.getWireName) diff --git a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala index c2b3599..6a184c9 100644 --- a/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala +++ b/data-source/src/main/scala/tech/ytsaurus/spyt/serializers/YtLogicalTypeSerializer.scala @@ -1,6 +1,6 @@ package tech.ytsaurus.spyt.serializers -import org.apache.spark.sql.types.Metadata +import org.apache.spark.sql.types.{DecimalType, Metadata} import tech.ytsaurus.ysontree.{YTree, YTreeBuilder, YTreeMapNode, YTreeNode, YTreeStringNode} import scala.collection.JavaConverters.asScalaBufferConverter @@ -105,10 +105,8 @@ object YtLogicalTypeSerializer { case YtLogicalType.Optional => YtLogicalType.Optional(deserializeTypeV3(m.getOrThrow("item"))) case YtLogicalType.Decimal => - YtLogicalType.Decimal( - m.getOrThrow("precision").intValue(), - m.getOrThrow("scale").intValue() - ) + val decimalType = DecimalType(m.getOrThrow("precision").intValue(), m.getOrThrow("scale").intValue()) + YtLogicalType.Decimal(decimalType.precision, decimalType.scale, decimalType) case YtLogicalType.Dict => YtLogicalType.Dict( deserializeTypeV3(m.getOrThrow("key")),