Add write metrics for kudo serializer.
Signed-off-by: liurenjie1024 <[email protected]>
liurenjie1024 committed Nov 27, 2024
1 parent bf94d21 commit d8fc386
Showing 4 changed files with 131 additions and 25 deletions.
30 changes: 23 additions & 7 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializer.java
@@ -176,7 +176,7 @@ public KudoSerializer(Schema schema) {
* @param numRows number of rows to write
* @return metrics of this write operation
*/
- long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
+ WriteMetrics writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
HostColumnVector[] columns = null;
try {
columns = IntStream.range(0, table.getNumberOfColumns())
@@ -208,7 +208,7 @@ long writeToStream(Table table, OutputStream out, int rowOffset, int numRows) {
* @param numRows number of rows to write
* @return metrics of this write operation
*/
- public long writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
+ public WriteMetrics writeToStream(HostColumnVector[] columns, OutputStream out, int rowOffset, int numRows) {
ensure(numRows > 0, () -> "numRows must be > 0, but was " + numRows);
ensure(columns.length > 0, () -> "columns must not be empty, for row count only records " +
"please call writeRowCountToStream");
@@ -286,17 +286,27 @@ public Pair<Table, MergeMetrics> mergeToTable(List<KudoTable> kudoTables) throws
}
}

- private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
+ private WriteMetrics writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffset, int numRows) throws Exception {
+   WriteMetrics metrics = new WriteMetrics();
KudoTableHeaderCalc headerCalc = new KudoTableHeaderCalc(rowOffset, numRows, flattenedColumnCount);
- Visitors.visitColumns(columns, headerCalc);
+ withTime(() -> Visitors.visitColumns(columns, headerCalc), metrics::addCalcHeaderTime);
KudoTableHeader header = headerCalc.getHeader();
- header.writeTo(out);
+ withTime(() -> {
+   try {
+     header.writeTo(out);
+   } catch (IOException e) {
+     throw new RuntimeException(e);
+   }
+ }, metrics::addCopyHeaderTime);
+ metrics.addWrittenBytes(header.getSerializedSize());

long bytesWritten = 0;
for (BufferType bufferType : ALL_BUFFER_TYPES) {
- SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType, out);
+ SlicedBufferSerializer serializer = new SlicedBufferSerializer(rowOffset, numRows, bufferType,
+     out, metrics);
Visitors.visitColumns(columns, serializer);
bytesWritten += serializer.getTotalDataLen();
+ metrics.addWrittenBytes(serializer.getTotalDataLen());
}

if (bytesWritten != header.getTotalDataLen()) {
@@ -307,7 +317,7 @@ private long writeSliced(HostColumnVector[] columns, DataWriter out, int rowOffs

out.flush();

- return header.getSerializedSize() + bytesWritten;
+ return metrics;
}

private static DataWriter writerFrom(OutputStream out) {
@@ -348,6 +358,12 @@ static <T> T withTime(Supplier<T> task, LongConsumer timeConsumer) {
return ret;
}

+ static void withTime(Runnable task, LongConsumer timeConsumer) {
+   long now = System.nanoTime();
+   task.run();
+   timeConsumer.accept(System.nanoTime() - now);
+ }

/**
* This method returns the length in bytes needed to represent X number of rows
* e.g. getValidityLengthInBytes(5) => 1 byte
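The commit threads a WriteMetrics accumulator through the write path and adds a Runnable overload of withTime next to the existing Supplier one. Because Runnable.run() cannot throw checked exceptions, the diff wraps the IOException from header.writeTo(out) in a RuntimeException inside the lambda. Below is a minimal, self-contained sketch of that pattern; TimingSketch and timedWrite are illustrative names, not repository API.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.function.LongConsumer;
import java.util.function.Supplier;

final class TimingSketch {
  // Supplier variant: times a task that produces a value and returns it.
  static <T> T withTime(Supplier<T> task, LongConsumer timeConsumer) {
    long now = System.nanoTime();
    T ret = task.get();
    timeConsumer.accept(System.nanoTime() - now);
    return ret;
  }

  // Runnable variant (the overload added in this commit): times a void task.
  static void withTime(Runnable task, LongConsumer timeConsumer) {
    long now = System.nanoTime();
    task.run();
    timeConsumer.accept(System.nanoTime() - now);
  }

  // Checked exceptions must be wrapped, since Runnable cannot declare them.
  static void timedWrite(OutputStream out, byte[] data, LongConsumer timeConsumer) {
    withTime(() -> {
      try {
        out.write(data);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }, timeConsumer);
  }

  public static void main(String[] args) {
    long[] nanos = new long[1];
    timedWrite(new ByteArrayOutputStream(), new byte[]{1, 2, 3}, t -> nanos[0] += t);
    System.out.println("write took " + nanos[0] + " ns");
  }
}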
src/main/java/com/nvidia/spark/rapids/jni/kudo/SlicedBufferSerializer.java
@@ -16,19 +16,19 @@

package com.nvidia.spark.rapids.jni.kudo;

+ import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;
+ import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.withTime;

import ai.rapids.cudf.BufferType;
import ai.rapids.cudf.DType;
import ai.rapids.cudf.HostColumnVectorCore;
import ai.rapids.cudf.HostMemoryBuffer;
import com.nvidia.spark.rapids.jni.schema.HostColumnsVisitor;

import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

- import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padForHostAlignment;

/**
* This class visits a list of columns and serializes one of the buffers (validity, offset, or data) using the kudo
* format.
@@ -48,13 +48,16 @@ class SlicedBufferSerializer implements HostColumnsVisitor<Void> {
private final DataWriter writer;

private final Deque<SliceInfo> sliceInfos = new ArrayDeque<>();
+ private final WriteMetrics metrics;
private long totalDataLen;

- SlicedBufferSerializer(int rowOffset, int numRows, BufferType bufferType, DataWriter writer) {
+ SlicedBufferSerializer(int rowOffset, int numRows, BufferType bufferType, DataWriter writer,
+     WriteMetrics metrics) {
this.root = new SliceInfo(rowOffset, numRows);
this.bufferType = bufferType;
this.writer = writer;
this.sliceInfos.addLast(root);
+ this.metrics = metrics;
this.totalDataLen = 0;
}

@@ -153,36 +156,35 @@ public Void visit(HostColumnVectorCore col) {
}
}

- private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+ private long copySlicedValidity(HostColumnVectorCore column, SliceInfo sliceInfo)
+     throws IOException {
if (column.getValidity() != null && sliceInfo.getRowCount() > 0) {
HostMemoryBuffer buff = column.getValidity();
long len = sliceInfo.getValidityBufferInfo().getBufferLength();
- writer.copyDataFrom(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(),
-     len);
- return padForHostAlignment(writer, len);
+ return copyBufferAndPadForHost(buff, sliceInfo.getValidityBufferInfo().getBufferOffset(), len);
} else {
return 0;
}
}

- private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
+ private long copySlicedOffset(HostColumnVectorCore column, SliceInfo sliceInfo)
+     throws IOException {
if (sliceInfo.rowCount <= 0 || column.getOffsets() == null) {
// Don't copy anything, there are no rows
return 0;
}
long bytesToCopy = (sliceInfo.rowCount + 1) * Integer.BYTES;
long srcOffset = sliceInfo.offset * Integer.BYTES;
HostMemoryBuffer buff = column.getOffsets();
- writer.copyDataFrom(buff, srcOffset, bytesToCopy);
- return padForHostAlignment(writer, bytesToCopy);
+ return copyBufferAndPadForHost(column.getOffsets(), srcOffset, bytesToCopy);
}

private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) throws IOException {
if (sliceInfo.rowCount > 0) {
DType type = column.getType();
if (type.equals(DType.STRING)) {
long startByteOffset = column.getOffsets().getInt(sliceInfo.offset * Integer.BYTES);
- long endByteOffset = column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
+ long endByteOffset =
+     column.getOffsets().getInt((sliceInfo.offset + sliceInfo.rowCount) * Integer.BYTES);
long bytesToCopy = endByteOffset - startByteOffset;
if (column.getData() == null) {
if (bytesToCopy != 0) {
@@ -192,19 +194,28 @@ private long copySlicedData(HostColumnVectorCore column, SliceInfo sliceInfo) th

return 0;
} else {
- writer.copyDataFrom(column.getData(), startByteOffset, bytesToCopy);
- return padForHostAlignment(writer, bytesToCopy);
+ return copyBufferAndPadForHost(column.getData(), startByteOffset, bytesToCopy);
}
} else if (type.getSizeInBytes() > 0) {
long bytesToCopy = sliceInfo.rowCount * type.getSizeInBytes();
long srcOffset = sliceInfo.offset * type.getSizeInBytes();
- writer.copyDataFrom(column.getData(), srcOffset, bytesToCopy);
- return padForHostAlignment(writer, bytesToCopy);
+ return copyBufferAndPadForHost(column.getData(), srcOffset, bytesToCopy);
} else {
return 0;
}
} else {
return 0;
}
}

+ private long copyBufferAndPadForHost(HostMemoryBuffer buffer, long offset, long length) {
+   return withTime(() -> {
+     try {
+       writer.copyDataFrom(buffer, offset, length);
+       return padForHostAlignment(writer, length);
+     } catch (IOException e) {
+       throw new RuntimeException(e);
+     }
+   }, metrics::addCopyBufferTime);
+ }
}
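copyBufferAndPadForHost folds the former copy-then-pad pairs into one timed helper: copy the sliced bytes, pad the stream up to the host alignment boundary, and charge the elapsed nanoseconds to metrics::addCopyBufferTime. A rough sketch of the copy-and-pad arithmetic follows, under an assumed alignment; the real value is owned by KudoSerializer.padForHostAlignment, and the 4-byte constant and names below are illustrative only.

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;

final class PaddingSketch {
  // Illustrative alignment; the kudo format defines the real constant.
  private static final int ALIGNMENT = 4;

  // Writes length bytes from src, then zero-fills up to the next alignment
  // boundary, returning the padded number of bytes written.
  static long writePadded(OutputStream out, byte[] src, int offset, int length)
      throws IOException {
    out.write(src, offset, length);
    long padded = ((long) length + ALIGNMENT - 1) / ALIGNMENT * ALIGNMENT;
    for (long i = length; i < padded; i++) {
      out.write(0);
    }
    return padded;
  }

  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    System.out.println(writePadded(out, new byte[]{1, 2, 3, 4, 5}, 0, 5)); // 8
  }
}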
79 changes: 79 additions & 0 deletions src/main/java/com/nvidia/spark/rapids/jni/kudo/WriteMetrics.java
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids.jni.kudo;

/**
* This class contains metrics for serializing tables using the kudo format.
*/
public class WriteMetrics {
private long calcHeaderTime;
private long copyHeaderTime;
private long copyBufferTime;
private long writtenBytes;

public WriteMetrics() {
this.calcHeaderTime = 0;
this.copyHeaderTime = 0;
this.copyBufferTime = 0;
this.writtenBytes = 0;
}

/**
* Get the time spent on calculating the header.
*/
public long getCalcHeaderTime() {
return calcHeaderTime;
}

/**
* Get the time spent on copying the buffer.
*/
public long getCopyBufferTime() {
return copyBufferTime;
}

public void addCopyBufferTime(long time) {
copyBufferTime += time;
}

/**
* Get the time spent on copying the header.
*/
public long getCopyHeaderTime() {
return copyHeaderTime;
}

public void addCalcHeaderTime(long time) {
calcHeaderTime += time;
}

public void addCopyHeaderTime(long time) {
copyHeaderTime += time;
}

/**
* Get the number of bytes written.
*/
public long getWrittenBytes() {
return writtenBytes;
}

public void addWrittenBytes(long bytes) {
writtenBytes += bytes;
}
}
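Because the add* methods accumulate rather than overwrite, one WriteMetrics instance can also aggregate several writes. A hedged usage sketch, assuming serializer, columns, out, batches, and batchRows are already in scope; only the WriteMetrics calls are taken from the class above.

// Aggregate metrics across several sliced writes, then report them.
WriteMetrics total = new WriteMetrics();
for (int i = 0; i < batches; i++) {
  WriteMetrics m = serializer.writeToStream(columns, out, i * batchRows, batchRows);
  total.addWrittenBytes(m.getWrittenBytes());
  total.addCalcHeaderTime(m.getCalcHeaderTime());
  total.addCopyHeaderTime(m.getCopyHeaderTime());
  total.addCopyBufferTime(m.getCopyBufferTime());
}
System.out.printf("wrote %d bytes (header calc %d ns, header copy %d ns, buffer copy %d ns)%n",
    total.getWrittenBytes(), total.getCalcHeaderTime(),
    total.getCopyHeaderTime(), total.getCopyBufferTime());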
@@ -75,7 +75,7 @@ public void testWriteSimple() throws Exception {

try (Table t = buildSimpleTable()) {
ByteArrayOutputStream out = new ByteArrayOutputStream();
- long bytesWritten = serializer.writeToStream(t, out, 0, 4);
+ long bytesWritten = serializer.writeToStream(t, out, 0, 4).getWrittenBytes();
assertEquals(189, bytesWritten);

ByteArrayInputStream in = new ByteArrayInputStream(out.toByteArray());
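The timing counters are environment-dependent, so tests can only assert loose properties about them. A hedged sketch of an extra check building on the assertion above; JUnit's assertTrue is assumed, and the checks are illustrative rather than taken from the repository's test suite.

WriteMetrics metrics = serializer.writeToStream(t, out, 0, 4);
assertEquals(189, metrics.getWrittenBytes());
// Timings vary by machine; only check that the counters are non-negative.
assertTrue(metrics.getCalcHeaderTime() >= 0);
assertTrue(metrics.getCopyHeaderTime() >= 0);
assertTrue(metrics.getCopyBufferTime() >= 0);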
