From d661ed326a1127e8b4154156000d84d69f931631 Mon Sep 17 00:00:00 2001 From: Gaojie Liu Date: Thu, 5 Oct 2023 13:49:36 -0700 Subject: [PATCH] [fast-avro] Added a few serde optimization (#517) Serializer: This code change will try to reuse the backed bytes if the float list is 'BufferBackedPrimitiveFloatList' when writing float list. If an instance of `BufferBackedPrimitiveFloatList` is changed after deserialization: `readPrimitiveFloatArray`, Fast Serializer won't use the backed bytes because of the divergence. Deserializer: Use `reset` instead of `clear` when reusing `GenericArray` since `reset` is cheaper than `clear` and the behavior difference is that `reset` won't nullify the elements in current array, but just resize the array length to be 0. --- .../FastGenericSerializerGeneratorTest.java | 86 +++++++++++++++++++ .../BufferBackedPrimitiveFloatList.java | 52 ++++++++++- .../avro/fastserde/CompositeByteBuffer.java | 8 ++ .../fastserde/FastDeserializerGenerator.java | 15 +++- .../fastserde/FastSerializerGenerator.java | 20 ++++- .../BufferBackedPrimitiveFloatList.java | 44 +++++++++- 6 files changed, 218 insertions(+), 7 deletions(-) diff --git a/fastserde/avro-fastserde-tests-common/src/test/java/com/linkedin/avro/fastserde/FastGenericSerializerGeneratorTest.java b/fastserde/avro-fastserde-tests-common/src/test/java/com/linkedin/avro/fastserde/FastGenericSerializerGeneratorTest.java index 436d0575e..d648b5065 100644 --- a/fastserde/avro-fastserde-tests-common/src/test/java/com/linkedin/avro/fastserde/FastGenericSerializerGeneratorTest.java +++ b/fastserde/avro-fastserde-tests-common/src/test/java/com/linkedin/avro/fastserde/FastGenericSerializerGeneratorTest.java @@ -12,6 +12,7 @@ import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; import java.io.ByteArrayOutputStream; import java.io.File; +import java.io.IOException; import java.net.URL; import java.net.URLClassLoader; import java.nio.ByteBuffer; @@ -562,6 +563,81 @@ public double getPrimitive(int index) { Assert.assertTrue(primitiveApiCalled.get()); } + @Test(groups = {"serializationTest"}) + public void shouldPassThroughByteBufferForArrayOfFloats() { + String arrayOfFloatFieldName = "array_of_float"; + Schema recordSchema = createRecord("TestArrayOfFloats", new Schema.Field(arrayOfFloatFieldName, Schema.createArray(Schema.create(Schema.Type.FLOAT)), null, null)); + GenericRecord record = new GenericData.Record(recordSchema); + record.put(arrayOfFloatFieldName, Arrays.asList(1.0f, 2.0f, 3.0f)); + + /** + * Deserialize it first by fast deserializer to check whether {@link BufferBackedPrimitiveFloatList} is being used or not. + */ + GenericRecord decodedRecord = decodeRecordFast(recordSchema, dataAsBinaryDecoder(record)); + Assert.assertTrue(decodedRecord.get(arrayOfFloatFieldName) instanceof BufferBackedPrimitiveFloatList); + + class TestBufferBackedPrimitiveFloatList extends BufferBackedPrimitiveFloatList { + + boolean writeFloatsCalled = false; + boolean writeFloatsByBackedBytesCalled = false; + public TestBufferBackedPrimitiveFloatList(BufferBackedPrimitiveFloatList floatList) { + super(0); + floatList.copyInternalState(this); + } + + public void resetFlag() { + this.writeFloatsCalled = false; + this.writeFloatsByBackedBytesCalled = false; + } + + @Override + public void writeFloats(Encoder encoder) throws IOException { + writeFloatsCalled = true; + super.writeFloats(encoder); + } + + @Override + protected void writeFloatsByBackedBytes(Encoder encoder) throws IOException { + writeFloatsByBackedBytesCalled = true; + super.writeFloatsByBackedBytes(encoder); + } + } + + TestBufferBackedPrimitiveFloatList floatListWithHook = new TestBufferBackedPrimitiveFloatList((BufferBackedPrimitiveFloatList)decodedRecord.get(arrayOfFloatFieldName)); + + // Replace the record field by the object with hook function + decodedRecord.put(arrayOfFloatFieldName, floatListWithHook); + + // Serialize it with fast serializer + Decoder anotherDecoder = dataAsBinaryDecoder(decodedRecord, recordSchema); + + Assert.assertTrue(floatListWithHook.writeFloatsCalled); + Assert.assertTrue(floatListWithHook.writeFloatsByBackedBytesCalled); + + // Deserialize it by vanilla Avro to verify data + GenericRecord decodedRecord1 = decodeRecord(recordSchema, anotherDecoder); + List decodedFloatList1 = (List)decodedRecord1.get(arrayOfFloatFieldName); + Assert.assertEquals(decodedFloatList1.get(0), 1.0f); + Assert.assertEquals(decodedFloatList1.get(1), 2.0f); + Assert.assertEquals(decodedFloatList1.get(2), 3.0f); + + /** + * Change the elements of {@link BufferBackedPrimitiveFloatList}, then pass-through bytes won't be used in fast serializer. + */ + floatListWithHook.set(0, 10.0f); + floatListWithHook.resetFlag(); + // Serialize it again with fast serializer + GenericRecord decodedRecord2 = decodeRecord(recordSchema, dataAsBinaryDecoder(decodedRecord, recordSchema)); + + Assert.assertTrue(floatListWithHook.writeFloatsCalled); + Assert.assertFalse(floatListWithHook.writeFloatsByBackedBytesCalled); + List decodedFloatList2 = (List)decodedRecord2.get(arrayOfFloatFieldName); + Assert.assertEquals(decodedFloatList2.get(0), 10.0f); + Assert.assertEquals(decodedFloatList2.get(1), 2.0f); + Assert.assertEquals(decodedFloatList2.get(2), 3.0f); + + } + @Test(groups = {"serializationTest"}) public void shouldWriteArrayOfFloats() { // given @@ -697,4 +773,14 @@ public T decodeRecord(Schema schema, Decoder decoder) { throw new RuntimeException(e); } } + + public T decodeRecordFast(Schema schema, Decoder decoder) { + try { + FastGenericDeserializerGenerator fastGenericDeserializerGenerator = new FastGenericDeserializerGenerator<>(schema, schema, tempDir, classLoader, null, null); + FastDeserializer fastDeserializer = fastGenericDeserializerGenerator.generateDeserializer(); + return fastDeserializer.deserialize(decoder); + } catch (Exception e) { + throw new RuntimeException(e); + } + } } diff --git a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java index 32284d1ff..7d6f55c9c 100644 --- a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java +++ b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java @@ -6,10 +6,12 @@ import java.util.AbstractList; import java.util.Collection; import java.util.Iterator; +import java.util.List; import org.apache.avro.Schema; import org.apache.avro.generic.GenericArray; import org.apache.avro.generic.GenericData; import org.apache.avro.io.Decoder; +import org.apache.avro.io.Encoder; /** @@ -45,6 +47,8 @@ public class BufferBackedPrimitiveFloatList extends AbstractList private boolean isCached = false; private CompositeByteBuffer byteBuffer; + private boolean changed = false; + public BufferBackedPrimitiveFloatList(int capacity) { if (capacity != 0) { elements = new float[capacity]; @@ -61,6 +65,17 @@ public BufferBackedPrimitiveFloatList(Collection c) { byteBuffer = new CompositeByteBuffer(c != null); } + /** + * For testing purpose. + */ + public void copyInternalState(BufferBackedPrimitiveFloatList another) { + another.size = this.size; + another.elements = this.elements; + another.isCached = this.isCached; + another.byteBuffer = this.byteBuffer; + another.changed = this.changed; + } + /** * Instantiate (or re-use) and populate a {@link BufferBackedPrimitiveFloatList} from a {@link org.apache.avro.io.Decoder}. * @@ -132,9 +147,10 @@ private static Object newPrimitiveFloatArray(Object old) { oldFloatList.byteBuffer.clear(); oldFloatList.isCached = false; oldFloatList.size = 0; + oldFloatList.changed = false; return oldFloatList; } else { - // Just a place holder, will set up the elements later. + // Just a placeholder, will set up the elements later. return new BufferBackedPrimitiveFloatList(0); } } @@ -152,6 +168,7 @@ public int size() { @Override public void clear() { size = 0; + changed = true; } private int getCapacity() { @@ -217,6 +234,7 @@ public boolean addPrimitive(float o) { elements = newElements; } elements[size++] = o; + changed = true; return true; } @@ -239,11 +257,12 @@ public void add(int location, Float o) { System.arraycopy(elements, location, elements, location + 1, size - location); elements[location] = o; size++; + changed = true; } @Override public Float set(int i, Float o) { - return set(i, o); + return setPrimitive(i, o); } @Override @@ -254,6 +273,7 @@ public float setPrimitive(int i, float o) { cacheFromByteBuffer(); float response = elements[i]; elements[i] = o; + changed = true; return response; } @@ -268,6 +288,7 @@ public Float remove(int i) { --size; System.arraycopy(elements, i + 1, elements, i, (size - i)); elements[size] = 0; + changed = true; return result; } @@ -332,6 +353,33 @@ public void reverse() { left++; right--; } + changed = true; + } + + protected void writeFloatsByBackedBytes(Encoder encoder) throws IOException { + List byteBufferList = byteBuffer.getByteBuffers(); + for (int i = 0; i < byteBuffer.getByteBufferCount(); ++i) { + ByteBuffer bb = byteBufferList.get(i); + encoder.writeFixed(bb.array(), 0, bb.limit()); + } + } + + public void writeFloats(Encoder encoder) throws IOException { + if (changed) { + /** + * The backed {@link #byteBuffer} diverges from the current array, so this function will write float from + * {@link #elements}. + */ + for (int i = 0; i < size; ++i) { + encoder.startItem(); + encoder.writeFloat(elements[i]); + } + } else { + /** + * So we will write the original bytes directly. + */ + writeFloatsByBackedBytes(encoder); + } } @Override diff --git a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/CompositeByteBuffer.java b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/CompositeByteBuffer.java index e9a6d6ae8..6136465f9 100644 --- a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/CompositeByteBuffer.java +++ b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/CompositeByteBuffer.java @@ -43,6 +43,14 @@ public void setByteBufferCount(int count) { byteBufferCount = count; } + int getByteBufferCount() { + return this.byteBufferCount; + } + + List getByteBuffers() { + return this.byteBuffers; + } + public float getElement(int i) { int index = i * 4; // most common case: diff --git a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java index 2a0c344db..85c2ca3c9 100644 --- a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java +++ b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastDeserializerGenerator.java @@ -4,6 +4,7 @@ import com.linkedin.avro.fastserde.backport.ResolvingGrammarGenerator; import com.linkedin.avro.fastserde.backport.Symbol; import com.linkedin.avroutil1.compatibility.AvroCompatibilityHelper; +import com.linkedin.avroutil1.compatibility.AvroVersion; import com.sun.codemodel.JArray; import com.sun.codemodel.JBlock; import com.sun.codemodel.JCatchBlock; @@ -802,7 +803,19 @@ private void processArray(JVar arraySchemaVar, final String name, final Schema a /* N.B.: Need to use the erasure because instanceof does not support generic types */ ifCodeGen(parentBody, finalReuseSupplier.get()._instanceof(abstractErasedArrayClass), then2 -> { then2.assign(arrayVar, JExpr.cast(abstractErasedArrayClass, finalReuseSupplier.get())); - then2.invoke(arrayVar, "clear"); + + if (SchemaAssistant.isPrimitive(arraySchema.getElementType()) || + Utils.getRuntimeAvroVersion().earlierThan(AvroVersion.AVRO_1_9) ) { // GenericArray in Avro-1.9 or later supports 'reset' + then2.invoke(arrayVar, "clear"); + } else { + /** + * For {@link GenericArray}, 'reset' is more efficient than 'clear', since 'reset' won't + * clear the previous elements, but just set size to be 0. + */ + ifCodeGen(then2, arrayVar._instanceof(codeModel.ref(GenericArray.class)), then3 -> { + then3.invoke(JExpr.cast(codeModel.ref(GenericArray.class), arrayVar), "reset"); + }, else3 -> else3.invoke(arrayVar, "clear")); + } }, else2 -> { else2.assign(arrayVar, finalNewArrayExp); }); diff --git a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerializerGenerator.java b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerializerGenerator.java index 7ff49db98..00d994e1c 100644 --- a/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerializerGenerator.java +++ b/fastserde/avro-fastserde/src/main/java/com/linkedin/avro/fastserde/FastSerializerGenerator.java @@ -26,7 +26,6 @@ import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; import org.apache.avro.io.Encoder; -import org.apache.avro.specific.SpecificData; import org.apache.avro.util.Utf8; import org.apache.commons.lang3.StringUtils; @@ -204,8 +203,23 @@ private void processArray(final Schema arraySchema, JExpression arrayExpr, JBloc ifCodeGen(else1, primitiveListCondition, then2 -> { final JVar primitiveList = declareValueVar("primitiveList", arraySchema, then2, true, false, true) - .init(JExpr.cast(primitiveListInterface, arrayVar)); - processArrayElementLoop(arraySchema, arrayClass, primitiveList, then2, "getPrimitive"); + .init(JExpr.cast(primitiveListInterface, arrayVar)); + if (arraySchema.getElementType().getType().equals(Schema.Type.FLOAT)) { + /** + * Check whether it is an instance of {@link BufferBackedPrimitiveFloatList} or not. + */ + JClass bufferBackedPrimitiveFloatListClass = codeModel.ref(BufferBackedPrimitiveFloatList.class); + final JExpression bufferBackedPrimitiveFloatListCondition = primitiveList._instanceof(bufferBackedPrimitiveFloatListClass); + ifCodeGen(then2, bufferBackedPrimitiveFloatListCondition, then3 -> { + final JVar bufferBackedPrimitiveFloatList = then3.decl(bufferBackedPrimitiveFloatListClass, + "bufferBackedPrimitiveFloatList", JExpr.cast(bufferBackedPrimitiveFloatListClass, primitiveList)); + then3.invoke(bufferBackedPrimitiveFloatList, "writeFloats").arg(JExpr.direct(ENCODER)); + }, else3 -> { + processArrayElementLoop(arraySchema, arrayClass, primitiveList, else3, "getPrimitive"); + }); + } else { + processArrayElementLoop(arraySchema, arrayClass, primitiveList, then2, "getPrimitive"); + } }, else2 -> { processArrayElementLoop(arraySchema, arrayClass, arrayExpr, else2, "get"); }); diff --git a/fastserde/avro-fastserde/src/main/java11/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java b/fastserde/avro-fastserde/src/main/java11/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java index 98ee3b4ad..7082f3974 100644 --- a/fastserde/avro-fastserde/src/main/java11/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java +++ b/fastserde/avro-fastserde/src/main/java11/com/linkedin/avro/fastserde/BufferBackedPrimitiveFloatList.java @@ -2,6 +2,7 @@ import com.linkedin.avro.api.PrimitiveFloatList; import java.io.IOException; +import java.nio.ByteBuffer; import java.lang.invoke.MethodHandles; import java.lang.invoke.VarHandle; import java.nio.ByteOrder; @@ -12,6 +13,7 @@ import org.apache.avro.generic.GenericArray; import org.apache.avro.generic.GenericData; import org.apache.avro.io.Decoder; +import org.apache.avro.io.Encoder; import org.apache.commons.lang3.ArrayUtils; @@ -44,6 +46,7 @@ public class BufferBackedPrimitiveFloatList extends AbstractList private float[] elements = EMPTY; private boolean isCached = false; private byte[] byteBuffer; + private boolean changed = false; private static final VarHandle VH = MethodHandles.byteArrayViewVarHandle(float[].class, ByteOrder.LITTLE_ENDIAN); @@ -60,6 +63,17 @@ public BufferBackedPrimitiveFloatList(Collection c) { } } + /** + * For testing purpose. + */ + public void copyInternalState(BufferBackedPrimitiveFloatList another) { + another.size = this.size; + another.elements = this.elements; + another.isCached = this.isCached; + another.byteBuffer = this.byteBuffer; + another.changed = this.changed; + } + /** * Instantiate (or re-use) and populate a {@link BufferBackedPrimitiveFloatList} from a {@link Decoder}. @@ -143,6 +157,7 @@ private static Object newPrimitiveFloatArray(Object old) { oldFloatList.byteBuffer = null; oldFloatList.isCached = false; oldFloatList.size = 0; + oldFloatList.changed = false; return oldFloatList; } else { // Just a place holder, will set up the elements later. @@ -228,6 +243,7 @@ public boolean addPrimitive(float o) { elements = newElements; } elements[size++] = o; + changed = true; return true; } @@ -250,11 +266,12 @@ public void add(int location, Float o) { System.arraycopy(elements, location, elements, location + 1, size - location); elements[location] = o; size++; + changed = true; } @Override public Float set(int i, Float o) { - return set(i, o); + return setPrimitive(i, o); } @Override @@ -265,6 +282,7 @@ public float setPrimitive(int i, float o) { cacheFromByteBuffer(); float response = elements[i]; elements[i] = o; + changed = true; return response; } @@ -279,6 +297,7 @@ public Float remove(int i) { --size; System.arraycopy(elements, i + 1, elements, i, (size - i)); elements[size] = 0; + changed = true; return result; } @@ -345,6 +364,29 @@ public void reverse() { left++; right--; } + changed = true; + } + + protected void writeFloatsByBackedBytes(Encoder encoder) throws IOException { + encoder.writeFixed(byteBuffer); + } + + public void writeFloats(Encoder encoder) throws IOException { + if (changed) { + /** + * The backed {@link #byteBuffer} diverges from the current array, so this function will write float from + * {@link #elements}. + */ + for (int i = 0; i < size; ++i) { + encoder.startItem(); + encoder.writeFloat(elements[i]); + } + } else { + /** + * So we will write the original bytes directly. + */ + writeFloatsByBackedBytes(encoder); + } } @Override