diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/ClosureGenerator.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/ClosureGenerator.java index d5fc03ba9651..8dddd9b2e014 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/ClosureGenerator.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/ClosureGenerator.java @@ -209,10 +209,10 @@ import java.util.Set; import static org.ballerinalang.model.symbols.SymbolOrigin.VIRTUAL; +import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isInParameterList; import static org.wso2.ballerinalang.compiler.util.Constants.DOLLAR; import static org.wso2.ballerinalang.compiler.util.Constants.RECORD_DELIMITER; import static org.wso2.ballerinalang.compiler.util.Constants.UNDERSCORE; -import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isInParameterList; /** * ClosureGenerator for creating closures for default values. @@ -421,36 +421,19 @@ public void visit(BLangRecordTypeNode recordTypeNode) { rewrite(field, recordTypeNode.typeDefEnv); } recordTypeNode.restFieldType = rewrite(recordTypeNode.restFieldType, env); - // In the current implementation, closures generated for default values in inclusions defined in a - // separate module are unidentifiable. - // Due to that, if the inclusions are in different modules, we generate closures again. - // Will be fixed with #41949 issue. - generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule(recordTypeNode); + generateClosuresForNonOverriddenFields(recordTypeNode); result = recordTypeNode; } - private List getFieldNames(List fields) { - List fieldNames = new ArrayList<>(); - for (BLangSimpleVariable field : fields) { - fieldNames.add(field.name.getValue()); - } - return fieldNames; - } - - private void generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule( - BLangRecordTypeNode recordTypeNode) { + private void generateClosuresForNonOverriddenFields(BLangRecordTypeNode recordTypeNode) { if (recordTypeNode.typeRefs.isEmpty()) { return; } List fieldNames = getFieldNames(recordTypeNode.fields); BTypeSymbol typeSymbol = recordTypeNode.getBType().tsymbol; String typeName = recordTypeNode.symbol.name.value; - PackageID packageID = typeSymbol.pkgID; for (BLangType type : recordTypeNode.typeRefs) { BType bType = type.getBType(); - if (packageID.equals(bType.tsymbol.pkgID)) { - continue; - } BRecordType recordType = (BRecordType) Types.getReferredType(bType); Map defaultValuesOfTypeRef = ((BRecordTypeSymbol) recordType.tsymbol).defaultValues; @@ -467,6 +450,14 @@ private void generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule } } + private List getFieldNames(List fields) { + List fieldNames = new ArrayList<>(); + for (BLangSimpleVariable field : fields) { + fieldNames.add(field.name.getValue()); + } + return fieldNames; + } + @Override public void visit(BLangTupleTypeNode tupleTypeNode) { BTypeSymbol typeSymbol = tupleTypeNode.getBType().tsymbol; diff --git a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java index 738a97288800..1e9308685081 100644 --- a/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java +++ b/compiler/ballerina-lang/src/main/java/org/wso2/ballerinalang/compiler/desugar/Desugar.java @@ -6190,9 +6190,12 @@ private BLangRecordLiteral.BLangRecordKeyValueField createRecordKeyValueField(Lo } public void generateFieldsForUserUnspecifiedRecordFields(BLangRecordLiteral recordLiteral, - List userSpecifiedFields) { + List userSpecifiedFields) { BType type = Types.getImpliedType(recordLiteral.getBType()); - if (type.getKind() != TypeKind.RECORD) { + // If we are spreading an open record at compile time we can't determine which fields may be missing. Instead, + // {@code MapValueImpl.populateInitialValues} should fill in any missing fields by calling the default + // closures. + if (type.getKind() != TypeKind.RECORD || isSpreadingAnOpenRecord(userSpecifiedFields)) { return; } List fieldNames = getNamesOfUserSpecifiedRecordFields(userSpecifiedFields); @@ -6202,6 +6205,23 @@ public void generateFieldsForUserUnspecifiedRecordFields(BLangRecordLiteral reco generateFieldsForUserUnspecifiedRecordFields(recordType, userSpecifiedFields, fieldNames, pos, isReadonly); } + private boolean isSpreadingAnOpenRecord(List userSpecifiedFields) { + for (RecordLiteralNode.RecordField field : userSpecifiedFields) { + if (!(field instanceof BLangRecordLiteral.BLangRecordSpreadOperatorField spreadOperatorField)) { + continue; + } + BType type = Types.getReferredType(spreadOperatorField.expr.getBType()); + if (!(type instanceof BRecordType recordType)) { + return true; + } + if (recordType.restFieldType != null && + !types.isNeverTypeOrStructureTypeWithARequiredNeverMember(recordType.restFieldType)) { + return true; + } + } + return false; + } + private void generateFieldsForUserUnspecifiedRecordFields(BRecordType recordType, List fields, List fieldNames, Location pos, diff --git a/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/ClosedRecordTypeInclusionTest.java b/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/ClosedRecordTypeInclusionTest.java index 4be4ecfc9ee9..5740c8877f9a 100644 --- a/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/ClosedRecordTypeInclusionTest.java +++ b/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/ClosedRecordTypeInclusionTest.java @@ -201,6 +201,7 @@ public Object[] testFunctions() { "testRestTypeOverriding", "testOutOfOrderFieldOverridingFieldFromTypeInclusion", "testTypeInclusionWithFiniteField", + "testDefaultValueFromInclusion" }; } diff --git a/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/OpenRecordTypeInclusionTest.java b/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/OpenRecordTypeInclusionTest.java index 122965a4088d..1bdc0a54b570 100644 --- a/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/OpenRecordTypeInclusionTest.java +++ b/tests/jballerina-unit-test/src/test/java/org/ballerinalang/test/bala/record/OpenRecordTypeInclusionTest.java @@ -193,7 +193,9 @@ public Object[] testFunctions() { "testCyclicRecord", "testOutOfOrderFieldOverridingFieldFromTypeInclusion", "testCreatingRecordWithOverriddenFields", - "testDefaultValuesOfRecordFieldsWithTypeInclusion" + "testDefaultValuesOfRecordFieldsWithTypeInclusion", + "testDefaultValueFromInclusion", + "testSpreadOverrideDefault" }; } diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/record/closed_record_type_inclusion.bal b/tests/jballerina-unit-test/src/test/resources/test-src/record/closed_record_type_inclusion.bal index 04e896c57eee..ec43cd156b5a 100644 --- a/tests/jballerina-unit-test/src/test/resources/test-src/record/closed_record_type_inclusion.bal +++ b/tests/jballerina-unit-test/src/test/resources/test-src/record/closed_record_type_inclusion.bal @@ -291,9 +291,111 @@ function testTypeInclusionWithFiniteField() { assertEquality(true, expr is UnaryExpr); } +isolated int count1 = 0; +isolated int count2 = 0; + +isolated function getDefaultVal1() returns Val1 { + lock { + count1 = count1 + 1; + } + return {val: 10}; +} + +isolated function getDefaultVal2() returns Val2 { + lock { + count2 = count2 + 1; + } + return {val: 3.3}; +} + +type Val1 record {| + int val; +|}; + +type Val2 record {| + float val; +|}; + +type Base1 record {| + Val1 val1 = getDefaultVal1(); +|}; + +type BO1 record {| + Val1 val1?; +|}; + +type Base2 record {| + Val2 val2 = getDefaultVal2(); +|}; + +type BO2 record {| + Val2 val2?; +|}; + +type Derived record {| + *Base1; + *Base2; +|}; + +isolated function testDefaultValueFromInclusion() { + BO1 bo1 = {}; + BO2 bo2 = {}; + Derived d = {...bo1, ...bo2}; + assertEquality(10, d.val1.val); + assertEquality(3.3, d.val2.val); + lock { + assertEquality(1, count1); + } + lock { + assertEquality(1, count2); + } + + BO1 bo3 = {val1: {val: 20}}; + Derived d1 = {...bo3, ...bo2}; + assertEquality(20, d1.val1.val); + assertEquality(3.3, d1.val2.val); + lock { + assertEquality(1, count1); + } + lock { + assertEquality(2, count2); + } + + BO2 bo4 = {val2: {val: 30.3}}; + Derived d2 = {...bo3, ...bo4}; + assertEquality(20, d2.val1.val); + assertEquality(30.3, d2.val2.val); + lock { + assertEquality(1, count1); + } + lock { + assertEquality(2, count2); + } + + Derived d3 = {}; + assertEquality(10, d3.val1.val); + assertEquality(3.3, d3.val2.val); + lock { + assertEquality(2, count1); + } + lock { + assertEquality(3, count2); + } + + Derived d4 = {val1: {val: 40}, val2: {val: 50.5}}; + assertEquality(40, d4.val1.val); + assertEquality(50.5, d4.val2.val); + lock { + assertEquality(2, count1); + } + lock { + assertEquality(3, count2); + } +} + const ASSERTION_ERROR_REASON = "AssertionError"; -function assertEquality(any|error expected, any|error actual) { +isolated function assertEquality(any|error expected, any|error actual) { if expected is anydata && actual is anydata && expected == actual { return; } diff --git a/tests/jballerina-unit-test/src/test/resources/test-src/record/open_record_type_inclusion.bal b/tests/jballerina-unit-test/src/test/resources/test-src/record/open_record_type_inclusion.bal index dc96e34252e2..2c093b73eb54 100644 --- a/tests/jballerina-unit-test/src/test/resources/test-src/record/open_record_type_inclusion.bal +++ b/tests/jballerina-unit-test/src/test/resources/test-src/record/open_record_type_inclusion.bal @@ -289,9 +289,158 @@ function testDefaultValuesOfRecordFieldsWithTypeInclusion() { assertEquality(30, info1.age); } +type Inner record {| + int foo; +|}; + +type Outer record {| + Inner inner?; +|}; + +isolated int count = 0; + +isolated function getDefaultInner() returns Inner { + lock { + count += 1; + } + return {foo: 10}; +} + +type OuterXBase record { + Inner inner = getDefaultInner(); +}; + +type OuterX record {| + *OuterXBase; +|}; + +type OuterXOpenRecord record {| + *OuterXBase; + Inner...; +|}; + +type InnerOpenRec record {| + Inner...; +|}; + +type OuterXAlsoClosed record {| + *OuterXBase; + never...; +|}; + +type EffectivelyCloseRecord record {| + *OuterXBase; + record {| + never bar; + |}...; +|}; + +isolated function testDefaultValueFromInclusion() { + Outer o = {}; + OuterX ox = {...o}; + assertEquality(ox.inner.foo, 10); + lock { + assertEquality(1, count); + } + Outer o1 = {inner: {foo: 20}}; + OuterX ox1 = {...o1}; + assertEquality(20, ox1.inner.foo); + lock { + assertEquality(1, count); + } + map innerMap = {}; + OuterX ox2 = {...innerMap}; + assertEquality(10, ox2.inner.foo); + lock { + assertEquality(2, count); + } + + map innerMap1 = {inner: {foo: 20}}; + OuterX ox3 = {...innerMap1}; + assertEquality(20, ox3.inner.foo); + lock { + assertEquality(2, count); + } + + InnerOpenRec innerMap2 = {"inner": {foo: 20}}; + + OuterX ox4 = {...innerMap2}; + assertEquality(20, ox4.inner.foo); + lock { + assertEquality(2, count); + } + + OuterXOpenRecord ox5 = {...innerMap2}; + assertEquality(20, ox5.inner.foo); + lock { + assertEquality(2, count); + } + + OuterXOpenRecord ox6 = {...o}; + assertEquality(10, ox6.inner.foo); + lock { + assertEquality(3, count); + } + + OuterXAlsoClosed oxx = {...o}; + assertEquality(oxx.inner.foo, 10); + lock { + assertEquality(4, count); + } + OuterXAlsoClosed oxx1 = {...o1}; + assertEquality(20, oxx1.inner.foo); + lock { + assertEquality(4, count); + } + + OuterX ox7 = {}; + assertEquality(ox7.inner.foo, 10); + lock { + assertEquality(5, count); + } + + OuterX ox8 = {inner: {foo: 20}}; + assertEquality(ox8.inner.foo, 20); + lock { + assertEquality(5, count); + } + + EffectivelyCloseRecord ecr = {...o}; + assertEquality(ecr.inner.foo, 10); + lock { + assertEquality(6, count); + } + + EffectivelyCloseRecord ecr1 = {...o1}; + assertEquality(ecr1.inner.foo, 20); + lock { + assertEquality(6, count); + } +} + +type Data record { + string id = fn(); + string name; +}; + +type OpenData record {| + string name; + string...; +|}; + +isolated function fn() returns string { + panic error("shouldn't be called"); +} + +public function testSpreadOverrideDefault() { + OpenData or = {name: "May", "id": "A1234"}; + Data emp = {...or}; + assertEquality("A1234", emp.id); +} + const ASSERTION_ERROR_REASON = "AssertionError"; -function assertEquality(any|error expected, any|error actual) { +isolated function assertEquality(any|error expected, any|error actual) { if expected is anydata && actual is anydata && expected == actual { return; }