Skip to content

Commit

Permalink
Merge pull request #43363 from heshanpadmasiri/fix/43359
Browse files Browse the repository at this point in the history
Fix NPE when accessing a field with a default value
  • Loading branch information
heshanpadmasiri authored Sep 10, 2024
2 parents 49240e6 + 39c7a3d commit da4eb4e
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,10 @@
import java.util.Set;

import static org.ballerinalang.model.symbols.SymbolOrigin.VIRTUAL;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isInParameterList;
import static org.wso2.ballerinalang.compiler.util.Constants.DOLLAR;
import static org.wso2.ballerinalang.compiler.util.Constants.RECORD_DELIMITER;
import static org.wso2.ballerinalang.compiler.util.Constants.UNDERSCORE;
import static org.wso2.ballerinalang.compiler.util.CompilerUtils.isInParameterList;

/**
* ClosureGenerator for creating closures for default values.
Expand Down Expand Up @@ -421,36 +421,19 @@ public void visit(BLangRecordTypeNode recordTypeNode) {
rewrite(field, recordTypeNode.typeDefEnv);
}
recordTypeNode.restFieldType = rewrite(recordTypeNode.restFieldType, env);
// In the current implementation, closures generated for default values in inclusions defined in a
// separate module are unidentifiable.
// Due to that, if the inclusions are in different modules, we generate closures again.
// Will be fixed with #41949 issue.
generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule(recordTypeNode);
generateClosuresForNonOverriddenFields(recordTypeNode);
result = recordTypeNode;
}

private List<String> getFieldNames(List<BLangSimpleVariable> fields) {
List<String> fieldNames = new ArrayList<>();
for (BLangSimpleVariable field : fields) {
fieldNames.add(field.name.getValue());
}
return fieldNames;
}

private void generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule(
BLangRecordTypeNode recordTypeNode) {
private void generateClosuresForNonOverriddenFields(BLangRecordTypeNode recordTypeNode) {
if (recordTypeNode.typeRefs.isEmpty()) {
return;
}
List<String> fieldNames = getFieldNames(recordTypeNode.fields);
BTypeSymbol typeSymbol = recordTypeNode.getBType().tsymbol;
String typeName = recordTypeNode.symbol.name.value;
PackageID packageID = typeSymbol.pkgID;
for (BLangType type : recordTypeNode.typeRefs) {
BType bType = type.getBType();
if (packageID.equals(bType.tsymbol.pkgID)) {
continue;
}
BRecordType recordType = (BRecordType) Types.getReferredType(bType);
Map<String, BInvokableSymbol> defaultValuesOfTypeRef =
((BRecordTypeSymbol) recordType.tsymbol).defaultValues;
Expand All @@ -467,6 +450,14 @@ private void generateClosuresForDefaultValuesInTypeInclusionsFromDifferentModule
}
}

private List<String> getFieldNames(List<BLangSimpleVariable> fields) {
List<String> fieldNames = new ArrayList<>();
for (BLangSimpleVariable field : fields) {
fieldNames.add(field.name.getValue());
}
return fieldNames;
}

@Override
public void visit(BLangTupleTypeNode tupleTypeNode) {
BTypeSymbol typeSymbol = tupleTypeNode.getBType().tsymbol;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6190,9 +6190,12 @@ private BLangRecordLiteral.BLangRecordKeyValueField createRecordKeyValueField(Lo
}

public void generateFieldsForUserUnspecifiedRecordFields(BLangRecordLiteral recordLiteral,
List<RecordLiteralNode.RecordField> userSpecifiedFields) {
List<RecordLiteralNode.RecordField> userSpecifiedFields) {
BType type = Types.getImpliedType(recordLiteral.getBType());
if (type.getKind() != TypeKind.RECORD) {
// If we are spreading an open record at compile time we can't determine which fields may be missing. Instead,
// {@code MapValueImpl.populateInitialValues} should fill in any missing fields by calling the default
// closures.
if (type.getKind() != TypeKind.RECORD || isSpreadingAnOpenRecord(userSpecifiedFields)) {
return;
}
List<String> fieldNames = getNamesOfUserSpecifiedRecordFields(userSpecifiedFields);
Expand All @@ -6202,6 +6205,23 @@ public void generateFieldsForUserUnspecifiedRecordFields(BLangRecordLiteral reco
generateFieldsForUserUnspecifiedRecordFields(recordType, userSpecifiedFields, fieldNames, pos, isReadonly);
}

private boolean isSpreadingAnOpenRecord(List<RecordLiteralNode.RecordField> userSpecifiedFields) {
for (RecordLiteralNode.RecordField field : userSpecifiedFields) {
if (!(field instanceof BLangRecordLiteral.BLangRecordSpreadOperatorField spreadOperatorField)) {
continue;
}
BType type = Types.getReferredType(spreadOperatorField.expr.getBType());
if (!(type instanceof BRecordType recordType)) {
return true;
}
if (recordType.restFieldType != null &&
!types.isNeverTypeOrStructureTypeWithARequiredNeverMember(recordType.restFieldType)) {
return true;
}
}
return false;
}

private void generateFieldsForUserUnspecifiedRecordFields(BRecordType recordType,
List<RecordLiteralNode.RecordField> fields,
List<String> fieldNames, Location pos,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ public Object[] testFunctions() {
"testRestTypeOverriding",
"testOutOfOrderFieldOverridingFieldFromTypeInclusion",
"testTypeInclusionWithFiniteField",
"testDefaultValueFromInclusion"
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ public Object[] testFunctions() {
"testCyclicRecord",
"testOutOfOrderFieldOverridingFieldFromTypeInclusion",
"testCreatingRecordWithOverriddenFields",
"testDefaultValuesOfRecordFieldsWithTypeInclusion"
"testDefaultValuesOfRecordFieldsWithTypeInclusion",
"testDefaultValueFromInclusion",
"testSpreadOverrideDefault"
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,111 @@ function testTypeInclusionWithFiniteField() {
assertEquality(true, expr is UnaryExpr);
}

isolated int count1 = 0;
isolated int count2 = 0;

isolated function getDefaultVal1() returns Val1 {
lock {
count1 = count1 + 1;
}
return {val: 10};
}

isolated function getDefaultVal2() returns Val2 {
lock {
count2 = count2 + 1;
}
return {val: 3.3};
}

type Val1 record {|
int val;
|};

type Val2 record {|
float val;
|};

type Base1 record {|
Val1 val1 = getDefaultVal1();
|};

type BO1 record {|
Val1 val1?;
|};

type Base2 record {|
Val2 val2 = getDefaultVal2();
|};

type BO2 record {|
Val2 val2?;
|};

type Derived record {|
*Base1;
*Base2;
|};

isolated function testDefaultValueFromInclusion() {
BO1 bo1 = {};
BO2 bo2 = {};
Derived d = {...bo1, ...bo2};
assertEquality(10, d.val1.val);
assertEquality(3.3, d.val2.val);
lock {
assertEquality(1, count1);
}
lock {
assertEquality(1, count2);
}

BO1 bo3 = {val1: {val: 20}};
Derived d1 = {...bo3, ...bo2};
assertEquality(20, d1.val1.val);
assertEquality(3.3, d1.val2.val);
lock {
assertEquality(1, count1);
}
lock {
assertEquality(2, count2);
}

BO2 bo4 = {val2: {val: 30.3}};
Derived d2 = {...bo3, ...bo4};
assertEquality(20, d2.val1.val);
assertEquality(30.3, d2.val2.val);
lock {
assertEquality(1, count1);
}
lock {
assertEquality(2, count2);
}

Derived d3 = {};
assertEquality(10, d3.val1.val);
assertEquality(3.3, d3.val2.val);
lock {
assertEquality(2, count1);
}
lock {
assertEquality(3, count2);
}

Derived d4 = {val1: {val: 40}, val2: {val: 50.5}};
assertEquality(40, d4.val1.val);
assertEquality(50.5, d4.val2.val);
lock {
assertEquality(2, count1);
}
lock {
assertEquality(3, count2);
}
}

const ASSERTION_ERROR_REASON = "AssertionError";

function assertEquality(any|error expected, any|error actual) {
isolated function assertEquality(any|error expected, any|error actual) {
if expected is anydata && actual is anydata && expected == actual {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,158 @@ function testDefaultValuesOfRecordFieldsWithTypeInclusion() {
assertEquality(30, info1.age);
}

type Inner record {|
int foo;
|};

type Outer record {|
Inner inner?;
|};

isolated int count = 0;

isolated function getDefaultInner() returns Inner {
lock {
count += 1;
}
return {foo: 10};
}

type OuterXBase record {
Inner inner = getDefaultInner();
};

type OuterX record {|
*OuterXBase;
|};

type OuterXOpenRecord record {|
*OuterXBase;
Inner...;
|};

type InnerOpenRec record {|
Inner...;
|};

type OuterXAlsoClosed record {|
*OuterXBase;
never...;
|};

type EffectivelyCloseRecord record {|
*OuterXBase;
record {|
never bar;
|}...;
|};

isolated function testDefaultValueFromInclusion() {
Outer o = {};
OuterX ox = {...o};
assertEquality(ox.inner.foo, 10);
lock {
assertEquality(1, count);
}
Outer o1 = {inner: {foo: 20}};
OuterX ox1 = {...o1};
assertEquality(20, ox1.inner.foo);
lock {
assertEquality(1, count);
}
map<Inner> innerMap = {};
OuterX ox2 = {...innerMap};
assertEquality(10, ox2.inner.foo);
lock {
assertEquality(2, count);
}

map<Inner> innerMap1 = {inner: {foo: 20}};
OuterX ox3 = {...innerMap1};
assertEquality(20, ox3.inner.foo);
lock {
assertEquality(2, count);
}

InnerOpenRec innerMap2 = {"inner": {foo: 20}};

OuterX ox4 = {...innerMap2};
assertEquality(20, ox4.inner.foo);
lock {
assertEquality(2, count);
}

OuterXOpenRecord ox5 = {...innerMap2};
assertEquality(20, ox5.inner.foo);
lock {
assertEquality(2, count);
}

OuterXOpenRecord ox6 = {...o};
assertEquality(10, ox6.inner.foo);
lock {
assertEquality(3, count);
}

OuterXAlsoClosed oxx = {...o};
assertEquality(oxx.inner.foo, 10);
lock {
assertEquality(4, count);
}
OuterXAlsoClosed oxx1 = {...o1};
assertEquality(20, oxx1.inner.foo);
lock {
assertEquality(4, count);
}

OuterX ox7 = {};
assertEquality(ox7.inner.foo, 10);
lock {
assertEquality(5, count);
}

OuterX ox8 = {inner: {foo: 20}};
assertEquality(ox8.inner.foo, 20);
lock {
assertEquality(5, count);
}

EffectivelyCloseRecord ecr = {...o};
assertEquality(ecr.inner.foo, 10);
lock {
assertEquality(6, count);
}

EffectivelyCloseRecord ecr1 = {...o1};
assertEquality(ecr1.inner.foo, 20);
lock {
assertEquality(6, count);
}
}

type Data record {
string id = fn();
string name;
};

type OpenData record {|
string name;
string...;
|};

isolated function fn() returns string {
panic error("shouldn't be called");
}

public function testSpreadOverrideDefault() {
OpenData or = {name: "May", "id": "A1234"};
Data emp = {...or};
assertEquality("A1234", emp.id);
}

const ASSERTION_ERROR_REASON = "AssertionError";

function assertEquality(any|error expected, any|error actual) {
isolated function assertEquality(any|error expected, any|error actual) {
if expected is anydata && actual is anydata && expected == actual {
return;
}
Expand Down

0 comments on commit da4eb4e

Please sign in to comment.