diff --git a/ballerina/Ballerina.toml b/ballerina/Ballerina.toml index 80d8934..b5347f1 100644 --- a/ballerina/Ballerina.toml +++ b/ballerina/Ballerina.toml @@ -1,7 +1,7 @@ [package] org = "ballerina" name = "avro" -version = "1.0.1" +version = "1.0.2" authors = ["Ballerina"] export=["avro"] keywords = ["avro", "serialization", "deserialization", "serdes"] @@ -18,8 +18,8 @@ graalvmCompatible = true [[platform.java17.dependency]] groupId = "io.ballerina.lib" artifactId = "avro-native" -version = "1.0.1" -path = "../native/build/libs/avro-native-1.0.1.jar" +version = "1.0.2" +path = "../native/build/libs/avro-native-1.0.2-SNAPSHOT.jar" [[platform.java17.dependency]] groupId = "org.apache.avro" diff --git a/ballerina/Dependencies.toml b/ballerina/Dependencies.toml index 64fe075..398d69e 100644 --- a/ballerina/Dependencies.toml +++ b/ballerina/Dependencies.toml @@ -10,7 +10,7 @@ distribution-version = "2201.9.0" [[package]] org = "ballerina" name = "avro" -version = "1.0.1" +version = "1.0.2" dependencies = [ {org = "ballerina", name = "io"}, {org = "ballerina", name = "jballerina.java"}, @@ -23,7 +23,7 @@ modules = [ [[package]] org = "ballerina" name = "io" -version = "1.6.0" +version = "1.6.1" scope = "testOnly" dependencies = [ {org = "ballerina", name = "jballerina.java"}, diff --git a/ballerina/tests/primitive_tests.bal b/ballerina/tests/primitive_tests.bal index 0b7add8..5d3ebaa 100644 --- a/ballerina/tests/primitive_tests.bal +++ b/ballerina/tests/primitive_tests.bal @@ -134,3 +134,67 @@ public isolated function testNullValuesWithNonNullData() returns error? { byte[]|error serializedValue = avro.toAvro("string"); test:assertTrue(serializedValue is error); } + +@test:Config { + groups: ["primitive"] +} +public isolated function testUnsignedShortValue() returns error? { + string schema = string ` + { + "type": "int", + "name" : "shortValue", + "namespace": "data" + }`; + + int:Unsigned8 value = 255; + return verifyOperation(int:Unsigned8, value, schema); +} + +@test:Config { + groups: ["primitive", "check"] +} +public isolated function testSignedShortValue() returns error? { + string schema = string ` + { + "type": "int", + "name" : "shortValue", + "namespace": "data" + }`; + + int:Signed32 value = 555950000; + return verifyOperation(int:Signed32, value, schema); +} + +@test:Config { + groups: ["primitive", "check"] +} +public isolated function testSignedMinusShortValue() returns error? { + string schema = string ` + { + "type": "double", + "name" : "shortValue", + "namespace": "data" + }`; + + int:Signed32 value = -2147483; + Schema avro = check new (schema); + byte[] serializedValue = check avro.toAvro(value); + float deserializedValue = check avro.fromAvro(serializedValue); + test:assertEquals(deserializedValue, -2147483.0); +} + +@test:Config { + groups: ["primitive", "check"] +} +public isolated function testByteValue() returns error? { + string schema = string ` + { + "type": "bytes", + "name" : "byteValue", + "namespace": "data" + }`; + byte[] data = []; + byte value = 2; + data.push(value); + return verifyOperation(ByteArray, data, schema); +} diff --git a/ballerina/tests/record_tests.bal b/ballerina/tests/record_tests.bal index bc26942..2335abc 100644 --- a/ballerina/tests/record_tests.bal +++ b/ballerina/tests/record_tests.bal @@ -370,3 +370,99 @@ public isolated function testOptionalMultipleFieldsInRecords() returns error? { }; return verifyOperation(Lecturer6, lecturer6, schema); } + +@test:Config { + groups: ["record", "union"] +} +public isolated function testTypeCastingInRecords() returns error? { + string jsonFileName = string `tests/resources/schema_complex.json`; + string schema = (check io:fileReadJson(jsonFileName)).toString(); + Record recordValue = { + hrdCasinoGgrLt: 4999.01, + hrdAccountCreationTimestamp: "2024-02-25T20:50:37.891782", + hrdSportsBetCountLifetime: 77, + hrdSportsFreeBetAmountLifetime: 0, + hriUnityId: "807612199", + hrdLastRealMoneySportsbookBetTs: "2024-04-24T20:15:51.932", + hrdLoyaltyTier: "MEMBER_TIER", + rowInsertTimestampEst: "2024-09-23T03:29:43.431", + ltvSports365Total: 0, + hrdFirstDepositTimestamp: "2024-02-25T20:56:24.109021", + hrdVipStatus: true, + hrdLastRealMoneyCasinoWagerTs: null, + hrdSportsCashBetAmountLifetime: 731, + hrdAccountStatus: "ACTIVE", + hrdAccountId: "1362347172559585332", + ltvAllVerticals365Total: 0, + hrdOptInSms: false, + ltvCasino365Total: null, + hrdAccountSubStatus: null, + hrdCasinoTotalWagerLifetime: null, + hrdSportsGgrLt: 4999.01, + currentGeoSegment: "HRD_FLORIDA", + kycStatus: "VERIFIED", + signupGeoSegment: "HRD_FLORIDA", + hrdSportsBetAmountLifetime: 731, + hrdOptInEmail: false, + hrdOptInPush: false + }; + Schema avro = check new (schema); + byte[] serializedValue = check avro.toAvro(recordValue); + Record deserializedValue = check avro.fromAvro(serializedValue); + + json expected = { + "hrdCasinoGgrLt":4999.01, + "hrdAccountCreationTimestamp":"2024-02-25T20:50:37.891782", + "hrdSportsFreeBetAmountLifetime":0.0, + "hriUnityId":"807612199", + "hrdLastRealMoneySportsbookBetTs":"2024-04-24T20:15:51.932", + "hrdLoyaltyTier":"MEMBER_TIER", + "rowInsertTimestampEst":"2024-09-23T03:29:43.431", + "ltvSports365Total":0.0, + "hrdVipStatus":true, + "hrdSportsBetCountLifetime":77.0, + "hrdLastRealMoneyCasinoWagerTs":null, + "hrdSportsCashBetAmountLifetime":731.0, + "hrdAccountStatus":"ACTIVE", + "hrdAccountId":"1362347172559585332", + "ltvAllVerticals365Total":0.0, + "ltvCasino365Total":null, + "hrdAccountSubStatus":null, + "hrdCasinoTotalWagerLifetime":null, + "hrdSportsGgrLt":4999.01, + "currentGeoSegment":"HRD_FLORIDA", + "kycStatus":"VERIFIED", + "hrdOptInSms":false, + "hrdFirstDepositTimestamp":"2024-02-25T20:56:24.109021", + "signupGeoSegment":"HRD_FLORIDA", + "hrdSportsBetAmountLifetime":731.0, + "hrdOptInEmail":false, + "hrdOptInPush":false + }; + test:assertEquals(deserializedValue.toJson(), expected.toJson()); +} + +@test:Config { + groups: ["record", "union"] +} +public isolated function testUnionsWithRecords() returns error? { + string schema = string `{ + "type": "record", + "name": "DataRecord", + "namespace": "data", + "fields": [ + { + "name": "shortValue", + "type": ["null", "int", "double"] + } + ] + }`; + int:Signed16 short = 5555; + json value = { + "shortValue": short + }; + Schema avro = check new (schema); + byte[] serializedValue = check avro.toAvro(value); + DataRecord deserializedValue = check avro.fromAvro(serializedValue); + test:assertEquals(deserializedValue, value); +} diff --git a/ballerina/tests/resources/schema_complex.json b/ballerina/tests/resources/schema_complex.json new file mode 100644 index 0000000..145f401 --- /dev/null +++ b/ballerina/tests/resources/schema_complex.json @@ -0,0 +1,196 @@ +{ + "connect.name": "io.confluent.ksql.avro_schemas.KsqlDataSourceSchema", + "fields": [ + { + "name": "hrdCasinoGgrLt", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdAccountCreationTimestamp", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdSportsBetCountLifetime", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdSportsFreeBetAmountLifetime", + "type": [ + "null", + "double" + ] + }, + { + "name": "hriUnityId", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdLastRealMoneySportsbookBetTs", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdLoyaltyTier", + "type": [ + "null", + "string" + ] + }, + { + "name": "rowInsertTimestampEst", + "type": [ + "null", + "string" + ] + }, + { + "name": "ltvSports365Total", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdFirstDepositTimestamp", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdVipStatus", + "type": [ + "null", + "boolean" + ] + }, + { + "name": "hrdLastRealMoneyCasinoWagerTs", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdSportsCashBetAmountLifetime", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdAccountStatus", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdAccountId", + "type": [ + "null", + "string" + ] + }, + { + "name": "ltvAllVerticals365Total", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdOptInSms", + "type": [ + "null", + "boolean" + ] + }, + { + "name": "ltvCasino365Total", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdAccountSubStatus", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdCasinoTotalWagerLifetime", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdSportsGgrLt", + "type": [ + "null", + "double" + ] + }, + { + "name": "currentGeoSegment", + "type": [ + "null", + "string" + ] + }, + { + "name": "kycStatus", + "type": [ + "null", + "string" + ] + }, + { + "name": "signupGeoSegment", + "type": [ + "null", + "string" + ] + }, + { + "name": "hrdSportsBetAmountLifetime", + "type": [ + "null", + "double" + ] + }, + { + "name": "hrdOptInEmail", + "type": [ + "null", + "boolean" + ] + }, + { + "name": "hrdOptInPush", + "type": [ + "null", + "boolean" + ] + } + ], + "name": "Record", + "type": "record" +} diff --git a/ballerina/tests/types.bal b/ballerina/tests/types.bal index 0c613c9..f68869b 100644 --- a/ballerina/tests/types.bal +++ b/ballerina/tests/types.bal @@ -271,7 +271,6 @@ public type Envelope2 record { string? MessageSource; }; - public type UnionEnumRecord record { string|Numbers? field1; }; @@ -288,6 +287,36 @@ public type ReadOnlyRec readonly & record { string|UnionEnumRecord? & readonly field1; }; +type Record record { + decimal? hrdCasinoGgrLt; + string? hrdAccountCreationTimestamp; + float? hrdSportsBetCountLifetime; + float? hrdSportsFreeBetAmountLifetime; + string? hriUnityId; + string? hrdLastRealMoneySportsbookBetTs; + string? hrdLoyaltyTier; + string? rowInsertTimestampEst; + float? ltvSports365Total; + string? hrdFirstDepositTimestamp; + boolean? hrdVipStatus; + string? hrdLastRealMoneyCasinoWagerTs; + float? hrdSportsCashBetAmountLifetime; + string? hrdAccountStatus; + string? hrdAccountId; + float? ltvAllVerticals365Total; + boolean? hrdOptInSms; + float? ltvCasino365Total; + string? hrdAccountSubStatus; + float? hrdCasinoTotalWagerLifetime; + float? hrdSportsGgrLt; + string? currentGeoSegment; + string? kycStatus; + string? signupGeoSegment; + float? hrdSportsBetAmountLifetime; + boolean? hrdOptInEmail; + boolean? hrdOptInPush; +}; + type ReadOnlyUnionFixed UnionFixedRecord & readonly; type ByteArray byte[]; type ReadOnlyIntArray int[] & readonly; @@ -334,3 +363,4 @@ type FloatArray float[]; type EnumArray Numbers[]; type Enum2DArray Numbers[][]; type ReadOnlyString2DArray string[][] & readonly; +type DataRecord record{}; diff --git a/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/UnionRecordUtils.java b/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/UnionRecordUtils.java index b56484f..c631945 100644 --- a/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/UnionRecordUtils.java +++ b/native/src/main/java/io/ballerina/lib/avro/deserialize/visitor/UnionRecordUtils.java @@ -39,41 +39,26 @@ public class UnionRecordUtils { public static void visitUnionRecords(Type type, BMap ballerinaRecord, Schema.Field field, Object fieldData) throws Exception { + int size = ballerinaRecord.size(); for (Schema schemaType : field.schema().getTypes()) { if (fieldData == null) { ballerinaRecord.put(StringUtils.fromString(field.name()), null); break; } switch (schemaType.getType()) { - case BYTES: - handleBytesField(field, fieldData, ballerinaRecord); - break; - case FIXED: - handleFixedField(field, fieldData, ballerinaRecord); - break; - case ARRAY: - handleArrayField(field, fieldData, ballerinaRecord, schemaType); - break; - case MAP: - handleMapField(field, fieldData, ballerinaRecord); - break; - case RECORD: - handleRecordField(type, field, fieldData, ballerinaRecord, schemaType); - break; - case STRING: - handleStringField(field, fieldData, ballerinaRecord); - break; - case INT, LONG: - handleIntegerField(field, fieldData, ballerinaRecord); - break; - case FLOAT, DOUBLE: - handleFloatField(field, fieldData, ballerinaRecord); - break; - case ENUM: - handleEnumField(field, fieldData, ballerinaRecord); - break; - default: - handleDefaultField(field, fieldData, ballerinaRecord); + case BYTES -> handleBytesField(field, fieldData, ballerinaRecord); + case FIXED -> handleFixedField(field, fieldData, ballerinaRecord); + case ARRAY -> handleArrayField(field, fieldData, ballerinaRecord, schemaType); + case MAP -> handleMapField(field, fieldData, ballerinaRecord); + case RECORD -> handleRecordField(type, field, fieldData, ballerinaRecord, schemaType); + case STRING -> handleStringField(field, fieldData, ballerinaRecord); + case INT, LONG -> handleIntegerField(field, fieldData, ballerinaRecord); + case FLOAT, DOUBLE -> handleFloatField(field, fieldData, ballerinaRecord); + case ENUM -> handleEnumField(field, fieldData, ballerinaRecord); + default -> handleDefaultField(field, fieldData, ballerinaRecord); + } + if (ballerinaRecord.size() != size) { + break; } } } diff --git a/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java b/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java index 4c52977..527655e 100644 --- a/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java +++ b/native/src/main/java/io/ballerina/lib/avro/serialize/visitor/SerializeVisitor.java @@ -22,6 +22,7 @@ import io.ballerina.lib.avro.serialize.EnumSerializer; import io.ballerina.lib.avro.serialize.FixedSerializer; import io.ballerina.lib.avro.serialize.MapSerializer; +import io.ballerina.lib.avro.serialize.MessageFactory; import io.ballerina.lib.avro.serialize.PrimitiveSerializer; import io.ballerina.lib.avro.serialize.RecordSerializer; import io.ballerina.lib.avro.serialize.Serializer; @@ -33,12 +34,14 @@ import io.ballerina.runtime.api.utils.StringUtils; import io.ballerina.runtime.api.utils.TypeUtils; import io.ballerina.runtime.api.values.BArray; +import io.ballerina.runtime.api.values.BDecimal; import io.ballerina.runtime.api.values.BMap; import org.apache.avro.Schema; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericRecord; import java.nio.ByteBuffer; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -95,32 +98,42 @@ private Object serializeField(Schema schema, Object fieldData) throws Exception @Override public Object visit(PrimitiveSerializer primitiveSerializer, Object data) throws Exception { - switch (primitiveSerializer.getSchema().getType()) { + return switch (primitiveSerializer.getSchema().getType()) { case INT -> { - return ((Long) data).intValue(); + if (data instanceof Long longValue) { + yield longValue.intValue(); + } + yield data; } case FLOAT -> { - return ((Double) data).floatValue(); + if (data instanceof Double doubleValue) { + yield doubleValue.floatValue(); + } + yield data; + } + case DOUBLE -> { + if (data instanceof Long longValue) { + yield longValue.doubleValue(); + } else if (data instanceof BDecimal decimalValue) { + yield decimalValue.floatValue(); + } + yield data; } case BYTES -> { ByteBuffer byteBuffer = ByteBuffer.allocate(((BArray) data).getByteArray().length); byteBuffer.put(((BArray) data).getByteArray()); byteBuffer.position(0); - return byteBuffer; - } - case STRING -> { - return data.toString(); + yield byteBuffer; } + case STRING -> data.toString(); case NULL -> { if (data != null) { throw new Exception("The value does not match with the null schema"); } - return null; + yield null; } - default -> { - return data; - } - } + default -> data; + }; } public Map visit(MapSerializer mapSerializer, BMap data) throws Exception { @@ -156,96 +169,42 @@ public GenericData.Array visit(ArraySerializer arraySerializer, BArray d return Objects.requireNonNull(visitor).visit(data, arraySerializer.getSchema(), array); } - public Object visit(UnionSerializer unionSerializer, Object data) throws Exception { - Schema fieldSchema = unionSerializer.getSchema(); - Type typeName = TypeUtils.getType(data); - switch (typeName.getTag()) { - case TypeTags.STRING_TAG -> { - return visitUnionStrings(data, fieldSchema); - } - case TypeTags.ARRAY_TAG -> { - return visitUnionArrays(data, fieldSchema); - } - case TypeTags.MAP_TAG -> { - return new MapSerializer(fieldSchema).convert(this, data); - } - case TypeTags.RECORD_TYPE_TAG -> { - Schema schema = getRecordSchema(Schema.Type.RECORD, fieldSchema.getTypes()); - return new RecordSerializer(schema).convert(this, data); - } - case TypeTags.INT_TAG -> { - return visitUnionIntegers(data, fieldSchema); - } - case TypeTags.FLOAT_TAG -> { - return visitUnionFloats(data, fieldSchema); - } - default -> { - return data; - } - } - } - - private Object visitUnionFloats(Object data, Schema fieldSchema) { - return fieldSchema.getTypes().stream() - .filter(schema -> schema.getType().equals(Schema.Type.FLOAT)) - .findFirst() - .map(schema -> { - try { - return new PrimitiveSerializer(schema).convert(this, data); - } catch (Exception e) { - throw new RuntimeException(e); - } - }) - .orElse(data); - } - - private Object visitUnionIntegers(Object data, Schema fieldSchema) { - return fieldSchema.getTypes().stream() - .filter(schema -> schema.getType().equals(Schema.Type.INT)) - .findFirst() - .map(schema -> { - try { - return new PrimitiveSerializer(schema).convert(this, data); - } catch (Exception e) { - throw new RuntimeException(e); - } - }) - .orElse(data); - } - - private Object visitUnionStrings(Object data, Schema fieldSchema) throws Exception { - return fieldSchema.getTypes().stream() - .filter(type -> type.getType().equals(Schema.Type.ENUM)) - .findFirst() - .map(type -> visit(new EnumSerializer(type), data)) - .orElse(visit(new PrimitiveSerializer(fieldSchema), data.toString())); - } - - private Object visitUnionArrays(Object data, Schema fieldSchema) throws Exception { - for (Schema schema : fieldSchema.getTypes()) { - switch (schema.getType()) { - case BYTES -> { - return new PrimitiveSerializer(schema).convert(this, data); - } - case FIXED -> { - return new FixedSerializer(schema).convert(this, data); - } - case ARRAY -> { - return new ArraySerializer(schema).convert(this, data); - } - } + public ArrayList deriveBallerinaTag(Schema schema) { + ArrayList tags = new ArrayList<>(); + switch (schema.getType()) { + case STRING, ENUM -> tags.add(TypeTags.STRING_TAG); + case FLOAT, DOUBLE -> { + tags.add(TypeTags.FLOAT_TAG); + tags.add(TypeTags.DECIMAL_TAG); + tags.add(TypeTags.INT_TAG); + } + case LONG, INT -> tags.add(TypeTags.INT_TAG); + case BOOLEAN -> tags.add(TypeTags.BOOLEAN_TAG); + case NULL -> tags.add(TypeTags.NULL_TAG); + case RECORD -> tags.add(TypeTags.RECORD_TYPE_TAG); + case ARRAY -> tags.add(TypeTags.ARRAY_TAG); + case MAP -> tags.add(TypeTags.MAP_TAG); + case BYTES, FIXED -> { + tags.add(TypeTags.BYTE_TAG); + tags.add(TypeTags.BYTE_ARRAY_TAG); + tags.add(TypeTags.ARRAY_TAG); + } + default -> tags.add(TypeTags.ANYDATA_TAG); } - return new ArraySerializer(fieldSchema).convert(this, data); + return tags; } - public static Schema getRecordSchema(Schema.Type givenType, List schemas) { - for (Schema schema: schemas) { - if (schema.getType().equals(Schema.Type.UNION)) { - getRecordSchema(givenType, schema.getTypes()); - } else if (schema.getType().equals(givenType)) { - return schema; + public Object visit(UnionSerializer unionSerializer, Object data) throws Exception { + Schema fieldSchema = unionSerializer.getSchema(); + Type typeName = TypeUtils.getType(data); + List types = fieldSchema.getTypes(); + for (Schema type : types) { + ArrayList tags = deriveBallerinaTag(type); + if (tags.contains(typeName.getTag())) { + Serializer serializer = MessageFactory.createMessage(type); + return Objects.requireNonNull(serializer).convert(this, data); } } - return null; + throw new Exception("Value does not match with the Avro union types"); } }