Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write protocol should be Apache Arrow #49

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
1,237 changes: 1,237 additions & 0 deletions data-source/src/main/scala/tech/ytsaurus/client/ArrowTableRowsSerializer.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package tech.ytsaurus.client;

import tech.ytsaurus.client.request.Format;
import tech.ytsaurus.client.request.SerializationContext;
import tech.ytsaurus.rpcproxy.ERowsetFormat;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Serialization context that switches a table write to the Apache Arrow rowset format.
 *
 * <p>Instead of a row-serializer object, this context carries a list of per-column
 * getters ({@link YTGetters.FromStruct}) that extract each column's value directly
 * from a {@code Row}; presumably the list order defines the Arrow column order —
 * see how {@code ArrowTableRowsSerializer} consumes it.
 *
 * @param <Row> type of the rows being written
 */
public class ArrowWriteSerializationContext<Row> extends SerializationContext<Row> {
    private final List<? extends Map.Entry<String, ? extends YTGetters.FromStruct<Row>>> rowGetters;

    /**
     * @param rowGetters (column name, value getter) pairs, one entry per output column
     */
    public ArrowWriteSerializationContext(
            List<? extends Map.Entry<String, ? extends YTGetters.FromStruct<Row>>> rowGetters
    ) {
        // RF_FORMAT with a named Format("arrow") tells the RPC proxy that rows arrive
        // in a named rowset format rather than the default wire protocol.
        this.rowsetFormat = ERowsetFormat.RF_FORMAT;
        this.format = new Format("arrow", new HashMap<>());
        this.rowGetters = rowGetters;
    }

    /** Returns the (column name, getter) pairs this context was constructed with. */
    public List<? extends Map.Entry<String, ? extends YTGetters.FromStruct<Row>>> getRowGetters() {
        return rowGetters;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package tech.ytsaurus.client;

import java.io.IOException;
import java.util.List;
import java.util.concurrent.CompletableFuture;

import javax.annotation.Nullable;

import tech.ytsaurus.client.request.WriteTable;
import tech.ytsaurus.client.rows.UnversionedRow;
import tech.ytsaurus.client.rows.UnversionedRowSerializer;
import tech.ytsaurus.client.rpc.Compression;
import tech.ytsaurus.client.rpc.RpcUtil;
import tech.ytsaurus.core.tables.TableSchema;
import tech.ytsaurus.lang.NonNullApi;
import tech.ytsaurus.rpcproxy.TWriteTableMeta;


@NonNullApi
class TableWriterBaseImpl<T> extends RawTableWriterImpl {
protected @Nullable
TableSchema schema;
protected final WriteTable<T> req;
protected @Nullable
TableRowsSerializer<T> tableRowsSerializer;
private final SerializationResolver serializationResolver;
@Nullable
protected ApiServiceTransaction transaction;

TableWriterBaseImpl(WriteTable<T> req, SerializationResolver serializationResolver) {
super(req.getWindowSize(), req.getPacketSize());
this.req = req;
this.serializationResolver = serializationResolver;
var format = this.req.getSerializationContext().getFormat();
if (format.isEmpty() || !"arrow".equals(format.get().getType())) {
tableRowsSerializer = TableRowsSerializer.createTableRowsSerializer(
this.req.getSerializationContext(), serializationResolver
).orElse(null);
}
}

public void setTransaction(ApiServiceTransaction transaction) {
if (this.transaction != null) {
throw new IllegalStateException("Write transaction already started");
}
this.transaction = transaction;
}

public CompletableFuture<TableWriterBaseImpl<T>> startUploadImpl() {
TableWriterBaseImpl<T> self = this;

return startUpload.thenApply((attachments) -> {
if (attachments.size() != 1) {
throw new IllegalArgumentException("protocol error");
}
byte[] head = attachments.get(0);
if (head == null) {
throw new IllegalArgumentException("protocol error");
}

TWriteTableMeta metadata = RpcUtil.parseMessageBodyWithCompression(
head,
TWriteTableMeta.parser(),
Compression.None
);
self.schema = ApiServiceUtil.deserializeTableSchema(metadata.getSchema());
logger.debug("schema -> {}", schema.toYTree().toString());

{
var format = this.req.getSerializationContext().getFormat();
if (format.isPresent() && "arrow".equals(format.get().getType())) {
tableRowsSerializer = new ArrowTableRowsSerializer<>(
((ArrowWriteSerializationContext<T>) this.req.getSerializationContext()).getRowGetters()
);
}
}

if (this.tableRowsSerializer == null) {
if (this.req.getSerializationContext().getObjectClass().isEmpty()) {
throw new IllegalStateException("No object clazz");
}
Class<T> objectClazz = self.req.getSerializationContext().getObjectClass().get();
if (UnversionedRow.class.equals(objectClazz)) {
this.tableRowsSerializer =
(TableRowsSerializer<T>) new TableRowsWireSerializer<>(new UnversionedRowSerializer());
} else {
this.tableRowsSerializer = new TableRowsWireSerializer<>(
serializationResolver.createWireRowSerializer(
serializationResolver.forClass(objectClazz, self.schema))
);
}
}

return self;
});
}

public boolean write(List<T> rows, TableSchema schema) throws IOException {
byte[] serializedRows = tableRowsSerializer.serializeRows(rows, schema);
return write(serializedRows);
}

@Override
public CompletableFuture<?> close() {
return super.close()
.thenCompose(response -> {
if (transaction != null && transaction.isActive()) {
return transaction.commit()
.thenApply(unused -> response);
}
return CompletableFuture.completedFuture(response);
});
}
}
168 changes: 168 additions & 0 deletions data-source/src/main/scala/tech/ytsaurus/client/YTGetters.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package tech.ytsaurus.client;

import tech.ytsaurus.typeinfo.*;
import tech.ytsaurus.yson.YsonConsumer;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Map;

/**
 * Namespace of typed accessor ("getter") interfaces used to pull column values out of
 * user row objects for Arrow serialization.
 *
 * <p>Two parallel families exist: {@code FromStructTo*} read a value from a whole row
 * ({@code Struct}), and {@code FromListTo*} read the i-th element of a list-like value.
 * Each getter also reports its YT type via {@link GetTiType}, and can always fall back
 * to emitting YSON through {@code getYson}.
 */
public class YTGetters {
    /** Reports the YT type ({@link TiType}) of the value a getter produces. */
    public interface GetTiType {
        TiType getTiType();
    }

    /** Reads one value from a row object, emitting it as YSON. */
    public interface FromStruct<Struct> extends GetTiType {
        void getYson(Struct struct, YsonConsumer ysonConsumer);
    }

    /** Reads elements of a list-like value by index, emitting them as YSON. */
    public interface FromList<List> extends GetTiType {
        int getSize(List list);

        void getYson(List list, int i, YsonConsumer ysonConsumer);
    }

    /** Marker: the struct value is emitted as raw YSON (no narrower accessor). */
    public interface FromStructToYson<Struct> extends FromStruct<Struct> {
    }

    /** Marker: list elements are emitted as raw YSON (no narrower accessor). */
    public interface FromListToYson<List> extends FromList<List> {
    }

    /** Accesses a dict value as two parallel key/value list views. */
    public interface FromDict<Dict, Keys, Values> extends GetTiType {
        FromList<Keys> getKeyGetter();

        FromList<Values> getValueGetter();

        int getSize(Dict dict);

        Keys getKeys(Dict dict);

        Values getValues(Dict dict);
    }

    /** Marker: the struct value is always null. */
    public interface FromStructToNull<Struct> extends FromStruct<Struct> {
    }

    /** Marker: list elements are always null. */
    public interface FromListToNull<List> extends FromList<List> {
    }

    /** Optional (nullable) struct value; use {@link #getNotEmptyGetter()} when present. */
    public interface FromStructToOptional<Struct> extends FromStruct<Struct> {
        FromStruct<Struct> getNotEmptyGetter();

        boolean isEmpty(Struct struct);
    }

    /** Optional (nullable) list elements; use {@link #getNotEmptyGetter()} when present. */
    public interface FromListToOptional<List> extends FromList<List> {
        FromList<List> getNotEmptyGetter();

        boolean isEmpty(List list, int i);
    }

    public interface FromStructToString<Struct> extends FromStruct<Struct> {
        ByteBuffer getString(Struct struct);
    }

    public interface FromListToString<List> extends FromList<List> {
        ByteBuffer getString(List list, int i);
    }

    public interface FromStructToByte<Struct> extends FromStruct<Struct> {
        byte getByte(Struct struct);
    }

    public interface FromListToByte<List> extends FromList<List> {
        byte getByte(List list, int i);
    }

    public interface FromStructToShort<Struct> extends FromStruct<Struct> {
        short getShort(Struct struct);
    }

    public interface FromListToShort<List> extends FromList<List> {
        short getShort(List list, int i);
    }

    public interface FromStructToInt<Struct> extends FromStruct<Struct> {
        int getInt(Struct struct);
    }

    public interface FromListToInt<List> extends FromList<List> {
        int getInt(List list, int i);
    }

    public interface FromStructToLong<Struct> extends FromStruct<Struct> {
        long getLong(Struct struct);
    }

    public interface FromListToLong<List> extends FromList<List> {
        long getLong(List list, int i);
    }

    public interface FromStructToBoolean<Struct> extends FromStruct<Struct> {
        boolean getBoolean(Struct struct);
    }

    public interface FromListToBoolean<List> extends FromList<List> {
        boolean getBoolean(List list, int i);
    }

    public interface FromStructToFloat<Struct> extends FromStruct<Struct> {
        float getFloat(Struct struct);
    }

    public interface FromListToFloat<List> extends FromList<List> {
        float getFloat(List list, int i);
    }

    public interface FromStructToDouble<Struct> extends FromStruct<Struct> {
        double getDouble(Struct struct);
    }

    public interface FromListToDouble<List> extends FromList<List> {
        double getDouble(List list, int i);
    }

    /** Nested struct value: per-member getters plus access to the nested struct itself. */
    public interface FromStructToStruct<Struct, Value> extends FromStruct<Struct> {
        java.util.List<Map.Entry<String, FromStruct<Value>>> getMembersGetters();

        Value getStruct(Struct struct);
    }

    /** Nested struct stored as a list element. */
    public interface FromListToStruct<List, Value> extends FromList<List> {
        java.util.List<Map.Entry<String, FromStruct<Value>>> getMembersGetters();

        Value getStruct(List list, int i);
    }

    /** List-typed struct member: element getter plus access to the list itself. */
    public interface FromStructToList<Struct, List> extends FromStruct<Struct> {
        FromList<List> getElementGetter();

        List getList(Struct struct);
    }

    /** Nested list stored as a list element. */
    public interface FromListToList<List, Value> extends FromList<List> {
        FromList<Value> getElementGetter();

        Value getList(List list, int i);
    }

    /** Dict-typed struct member. */
    public interface FromStructToDict<Struct, Dict, Keys, Values> extends FromStruct<Struct> {
        FromDict<Dict, Keys, Values> getGetter();

        Dict getDict(Struct struct);
    }

    /** Dict stored as a list element. */
    public interface FromListToDict<List, Dict, Keys, Values> extends FromList<List> {
        FromDict<Dict, Keys, Values> getGetter();

        Dict getDict(List list, int i);
    }

    public interface FromStructToBigDecimal<Struct> extends FromStruct<Struct> {
        BigDecimal getBigDecimal(Struct struct);
    }

    public interface FromListToBigDecimal<List> extends FromList<List> {
        BigDecimal getBigDecimal(List list, int i);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types.StructType
import org.slf4j.LoggerFactory
import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable}
import tech.ytsaurus.client.{ArrowWriteSerializationContext, CompoundClient, TableWriter}
import tech.ytsaurus.core.GUID
import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration
import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings._
import tech.ytsaurus.spyt.fs.conf._
import tech.ytsaurus.spyt.fs.path.YPathEnriched
import tech.ytsaurus.spyt.serializers.{InternalRowSerializer, WriteSchemaConverter}
import tech.ytsaurus.spyt.wrapper.LogLazy
import tech.ytsaurus.client.request.{TransactionalOptions, WriteSerializationContext, WriteTable}
import tech.ytsaurus.client.{CompoundClient, TableWriter}
import tech.ytsaurus.core.GUID
import tech.ytsaurus.spyt.format.conf.SparkYtWriteConfiguration

import java.util
import java.util.concurrent.{CompletableFuture, TimeUnit}
import scala.collection.JavaConverters.seqAsJavaListConverter
import scala.concurrent.{Await, Future}
import scala.util.{Failure, Try}

Expand Down Expand Up @@ -151,9 +152,24 @@ class YtOutputWriter(richPath: YPathEnriched,
protected def initializeWriter(): TableWriter[InternalRow] = {
val appendPath = richPath.withAttr("append", "true").toYPath
log.debugLazy(s"Initialize new write: $appendPath, transaction: $transactionGuid")
val writeSchemaConverter = WriteSchemaConverter(options)
val request = WriteTable.builder[InternalRow]()
.setPath(appendPath)
.setSerializationContext(new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options))))
.setSerializationContext(
if (options.ytConf(ArrowWriteEnabled)) {
if (!writeSchemaConverter.typeV3Format) {
throw new RuntimeException("arrow writer is only supported with typeV3")
}
new ArrowWriteSerializationContext[InternalRow](
schema.fields.zipWithIndex.map { case (field, i) =>
util.Map.entry(field.name, writeSchemaConverter.ytLogicalTypeV3(field).ytGettersFromStruct(
field.dataType, i
))
}.toSeq.asJava
)
} else
new WriteSerializationContext(new InternalRowSerializer(schema, WriteSchemaConverter(options)))
)
.setTransactionalOptions(new TransactionalOptions(GUID.valueOf(transactionGuid)))
.setNeedRetries(false)
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ object YtTableSparkSettings {

// Gates the Arrow format; presumably for the read path (on by default) — confirm against readers.
case object ArrowEnabled extends ConfigEntry[Boolean]("arrow_enabled", Some(true))

// Opt-in flag for the new Arrow write protocol; checked by YtOutputWriter when
// choosing between ArrowWriteSerializationContext and the wire serializer. Off by default.
case object ArrowWriteEnabled extends ConfigEntry[Boolean]("arrow_write_enabled", Some(false))

case object KeyPartitioned extends ConfigEntry[Boolean]("key_partitioned")

case object Dynamic extends ConfigEntry[Boolean]("dynamic")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import org.slf4j.LoggerFactory
import tech.ytsaurus.client.TableWriter
import tech.ytsaurus.client.rows.{WireProtocolWriteable, WireRowSerializer}
import tech.ytsaurus.core.tables.{ColumnValueType, TableSchema}
import tech.ytsaurus.spyt.format.conf.YtTableSparkSettings.{WriteSchemaHint, WriteTypeV3}
import tech.ytsaurus.spyt.serialization.YsonEncoder
import tech.ytsaurus.spyt.serializers.InternalRowSerializer._
import tech.ytsaurus.spyt.serializers.SchemaConverter.{Unordered, decimalToBinary}
Expand Down
Loading
Loading