diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
index 6370531428..4860cd7648 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -176,7 +176,7 @@ public KudoSerializer(Schema schema) {
* @param numRows number of rows to write
* @return number of bytes written
*/
- long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
+ WriteMetrics writeToStreamWithMetrics(Table table, OutputStream out, int rowOffset, int numRows) {
HostColumnVector[] columns = null;
try {
columns = IntStream.range(0, table.getNumberOfColumns())
@@ -185,7 +185,7 @@ long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
.toArray(HostColumnVector[]::new);
Cuda.DEFAULT_STREAM.sync();
- return writeToStream(columns, out, rowOffset, numRows);
+ return writeToStreamWithMetrics(columns, out, rowOffset, numRows);
} finally {
if (columns != null) {
for (HostColumnVector column : columns) {
@@ -195,6 +195,16 @@ long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
}
}
+ /**
+ * Writes a partition of an array of {@link HostColumnVector} to an output stream.
+ * See {@link #writeToStreamWithMetrics(HostColumnVector[], OutputStream, int, int)} for more
+ * details.
+ * @return number of bytes written
+ */
+ public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
+ return writeToStreamWithMetrics(columns, out, rowOffset, numRows).getWrittenBytes();
+ }
+
/**
* Write partition of an array of {@link HostColumnVector} to an output stream.
*
@@ -208,7 +218,7 @@ long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
* @param numRows number of rows to write
- * @return number of bytes written
+ * @return metrics for the write, including the number of bytes written
*/
- public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
+ public WriteMetrics writeToStreamWithMetrics(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows);
ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " +
"please call writeRowCountToStream");
@@ -286,17 +296,26 @@ public Pair<Table, MergeMetrics> mergeToTable(List<KudoTable> kudoTables) throws
}
}
- private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
+ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
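+ // Metrics for this write: header calc/copy times, buffer copy time, and bytes written.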
+ WriteMetrics metrics = new WriteMetrics();
KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
- Visitors.visitColumns(columns, headerCalc);
+ withTime(() -> Visitors.visitColumns(columns, headerCalc), metrics::addCalcHeaderTime);
KudoTableHeader header = headerCalc.getHeader();
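+ // Time the header write separately from the buffer copies below.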
+ long currentTime = System.nanoTime();
header.writeTo(out);
+ metrics.addCopyHeaderTime(System.nanoTime() - currentTime);
+ metrics.addWrittenBytes(header.getSerializedSize());
long bytesWritten = 0;
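+ // Serialize each buffer type (validity, offsets, data) across all columns, tallying bytes per pass.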
for (BufferType bufferType : ALL_BUFFER_TYPES) {
- SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out);
+ SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType,
+ out, metrics);
Visitors.visitColumns(columns, serializer);
bytesWritten += serializer.getTotalDataLen();
+ metrics.addWrittenBytes(serializer.getTotalDataLen());
}
if (bytesWritten != header.getTotalDataLen()) {
@@ -307,7 +323,7 @@ private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
out.flush();
- return header.getSerializedSize() + bytesWritten;
+ return metrics;
}
private static DataWriter writerFrom(OutputStream out) {
@@ -348,6 +364,15 @@ static <T> T withTime(Supplier<T> task, LongConsumer timeConsumer) {
return ret;
}
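+ /**
+ * Runs {@code task} and reports its elapsed time, in nanoseconds, to {@code timeConsumer}.
+ */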
+ static void withTime(Runnable task, LongConsumer timeConsumer) {
+ long now = System.nanoTime();
+ task.run();
+ timeConsumer.accept(System.nanoTime() - now);
+ }
+
/**
* This method returns the length in bytes needed to represent X number of rows
* e.g. getValidityLengthInBytes(5) => 1 byte
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
index e22a523855..86d51116b6 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
@@ -16,19 +16,19 @@
package com.nvidia.spark.rapids.jni.kudo;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
+import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.withTime;
+
import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor;
-
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
-import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
-
/**
- * This class visits a list of columns and serialize one of the buffers (validity, offset, or data) into with kudo
- * format.
+ * This class visits a list of columns and serializes one of the buffers (validity, offset, or data) in the kudo
+ * format.
@@ -48,13 +48,17 @@ class SlicedBufferSerializer implements HostColumnsVisitor<Void> {
private final DataWriter writer;
private final Deque<SliceInfo> sliceInfos = new ArrayDeque<>();
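+ // Metrics sink for buffer-copy timings, shared with the calling KudoSerializer.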
+ private final WriteMetrics metrics;
private long totalDataLen;
- SlicedBufferSerializer(int rowOffset, int numRows, BufferType bufferType, DataWriter writer) {
+ SlicedBufferSerializer(int rowOffset, int numRows, BufferType bufferType, DataWriter writer,
+ WriteMetrics metrics) {
this.root = new SliceInfo(rowOffset, numRows);
this.bufferType = bufferType;
this.writer = writer;
this.sliceInfos.addLast(root);
+ this.metrics = metrics;
this.totalDataLen = 0;
}
@@ -153,28 +156,26 @@ public Void visit(HostColumnVectorCore col) {
}
}
- private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+ private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo)
+ throws IOException {
if (column.getValidity() != null && sliceInfo.getRowCount() > 0) {
HostMemoryBuffer buff = column.getValidity();
long len = sliceInfo.getValidityBufferInfo().getBufferLength();
- writer.copyDataFrom(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(),
- len);
- return padForHostAlignment(writer, len);
+ return copyBufferAndPadForHost(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(), len);
} else {
return 0;
}
}
- private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+ private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo)
+ throws IOException {
if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) {
// Don't copy anything, there are no rows
return 0;
}
long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES;
long srcOffset = sliceInfo.offset * Integer.BYTES;
- HostMemoryBuffer buff = column.getOffsets();
- writer.copyDataFrom(buff, srcOffset, bytesToCopy);
- return padForHostAlignment(writer, bytesToCopy);
+ return copyBufferAndPadForHost(column.getOffsets(), srcOffset, bytesToCopy);
}
private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
@@ -182,7 +183,8 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
DType type = column.getType();
if (type.equals(DType.STRING)) {
long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES);
- long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
+ long endByteOffset =
+ column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
long bytesToCopy = endByteOffset - startByteOffset;
if (column.getData() == null) {
if (bytesToCopy != 0) {
@@ -192,14 +194,12 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
return 0;
} else {
- writer.copyDataFrom(column.getData(), startByteOffset, bytesToCopy);
- return padForHostAlignment(writer, bytesToCopy);
+ return copyBufferAndPadForHost(column.getData(), startByteOffset, bytesToCopy);
}
} else if (type.getSizeInBytes() > 0) {
long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes();
long srcOffset = sliceInfo.offset * type.getSizeInBytes();
- writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy);
- return padForHostAlignment(writer, bytesToCopy);
+ return copyBufferAndPadForHost(column.getData(), srcOffset, bytesToCopy);
} else {
return 0;
}
@@ -207,4 +207,17 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
return 0;
}
}
+
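+ /**
+ * Copies {@code length} bytes from {@code buffer} starting at {@code offset}, pads the output
+ * to host alignment, and records the elapsed copy time in {@link WriteMetrics}.
+ */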
+ private long copyBufferAndPadForHost(HostMemoryBuffer buffer, long offset, long length)
+ throws IOException {
+ long now = System.nanoTime();
+ writer.copyDataFrom(buffer, offset, length);
+ long ret = padForHostAlignment(writer, length);
+ metrics.addCopyBufferTime(System.nanoTime() - now);
+ return ret;
+ }
}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/WriteMetrics.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/WriteMetrics.java
new file mode 100644
index 0000000000..d34564e776
--- /dev/null
+++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/WriteMetrics.java
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark.rapids.jni.kudo;
+
+/**
+ * This class contains metrics for serializing a table using the kudo format. All recorded
+ * times are in nanoseconds, as measured with {@code System.nanoTime()}.
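+ *
+ * <p>For example (usage sketch):
+ * <pre>{@code
+ * WriteMetrics metrics = serializer.writeToStreamWithMetrics(columns, out, rowOffset, numRows);
+ * long bytesWritten = metrics.getWrittenBytes();
+ * }</pre>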
+ */
+public class WriteMetrics {
+ private long calcHeaderTime;
+ private long copyHeaderTime;
+ private long copyBufferTime;
+ private long writtenBytes;
+
+ public WriteMetrics() {
+ this.calcHeaderTime = 0;
+ this.copyHeaderTime = 0;
+ this.copyBufferTime = 0;
+ this.writtenBytes = 0;
+ }
+
+ /**
+ * Get the time, in nanoseconds, spent calculating the header.
+ */
+ public long getCalcHeaderTime() {
+ return calcHeaderTime;
+ }
+
+ public void addCalcHeaderTime(long time) {
+ calcHeaderTime += time;
+ }
+
+ /**
+ * Get the time, in nanoseconds, spent copying the header.
+ */
+ public long getCopyHeaderTime() {
+ return copyHeaderTime;
+ }
+
+ public void addCopyHeaderTime(long time) {
+ copyHeaderTime += time;
+ }
+
+ /**
+ * Get the time, in nanoseconds, spent copying buffers.
+ */
+ public long getCopyBufferTime() {
+ return copyBufferTime;
+ }
+
+ public void addCopyBufferTime(long time) {
+ copyBufferTime += time;
+ }
+
+ /**
+ * Get the number of bytes written.
+ */
+ public long getWrittenBytes() {
+ return writtenBytes;
+ }
+
+ public void addWrittenBytes(long bytes) {
+ writtenBytes += bytes;
+ }
+}
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
index 210777accf..3ffcb5e61b 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java
@@ -75,7 +75,7 @@ public void testWriteSimple() throws Exception {
try (Table t = buildSimpleTable()) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
- long bytesWritten = serializer.writeToStream(t, out, 0, 4);
+ long bytesWritten = serializer.writeToStreamWithMetrics(t, out, 0, 4).getWrittenBytes();
assertEquals(189, bytesWritten);
ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
@@ -365,7 +365,7 @@ private static void checkMergeTable(Table expected, List<TableSlice> tableSlices
ByteArrayOutputStream bout = new ByteArrayOutputStream();
for (TableSlice slice : tableSlices) {
- serializer.writeToStream(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows());
+ serializer.writeToStreamWithMetrics(slice.getBaseTable(), bout, slice.getStartRow(), slice.getNumRows());
}
bout.flush();