Skip to content

Commit

Permalink
Add constant mapping using semantic model
Browse files Browse the repository at this point in the history
  • Loading branch information
lnash94 committed Nov 13, 2024
1 parent 9f9d3b6 commit 6d672f4
Show file tree
Hide file tree
Showing 19 changed files with 617 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.syntax.tree.ClassDefinitionNode;
import io.ballerina.compiler.syntax.tree.ConstantDeclarationNode;
import io.ballerina.compiler.syntax.tree.ListenerDeclarationNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeVisitor;
Expand All @@ -45,7 +44,6 @@ public class ModuleMemberVisitor extends NodeVisitor {
Set<ListenerDeclarationNode> listenerDeclarationNodes = new LinkedHashSet<>();
Set<ClassDefinitionNode> interceptorServiceClassNodes = new LinkedHashSet<>();
Set<ServiceContractType> serviceContractTypeNodes = new LinkedHashSet<>();
Set<ConstantDeclarationNode> constantDeclarationNodes = new LinkedHashSet<>();
SemanticModel semanticModel;

public ModuleMemberVisitor(SemanticModel semanticModel) {
Expand All @@ -72,11 +70,6 @@ public void visit(ClassDefinitionNode classDefinitionNode) {
interceptorServiceClassNodes.add(classDefinitionNode);
}

@Override
public void visit(ConstantDeclarationNode constantDeclarationNode) {
constantDeclarationNodes.add(constantDeclarationNode);
}

public Set<ListenerDeclarationNode> getListenerDeclarationNodes() {
return listenerDeclarationNodes;
}
Expand All @@ -99,16 +92,6 @@ public Optional<ClassDefinitionNode> getInterceptorServiceClassNode(String typeN
return Optional.empty();
}

public Optional<ConstantDeclarationNode> getConstantDeclarationNode(String constantName) {
for (ConstantDeclarationNode constantDeclarationNode : constantDeclarationNodes) {
if (MapperCommonUtils.unescapeIdentifier(constantDeclarationNode.variableName()
.text()).equals(constantName)) {
return Optional.of(constantDeclarationNode);
}
}
return Optional.empty();
}

public Set<ServiceContractType> getServiceContractTypeNodes() {
return serviceContractTypeNodes;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
package io.ballerina.openapi.service.mapper.parameter;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode;
import io.ballerina.openapi.service.mapper.model.ModuleMemberVisitor;
import io.ballerina.openapi.service.mapper.model.OperationInventory;
import io.ballerina.openapi.service.mapper.utils.MapperCommonUtils;
import io.swagger.v3.oas.models.parameters.Parameter;

import java.util.List;
import java.util.Objects;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getExpressionNodeForConstantDeclaration;
import java.util.Optional;

/**
* This {@link AbstractParameterMapper} class represents the abstract parameter mapper.
Expand Down Expand Up @@ -59,12 +59,14 @@ public void setParameter() throws ParameterMapperException {
parameterList.forEach(operationInventory::setParameter);
}

static Object getDefaultValue(DefaultableParameterNode parameterNode, ModuleMemberVisitor moduleMemberVisitor) {
static Object getDefaultValue(DefaultableParameterNode parameterNode, SemanticModel semanticModel) {
ExpressionNode defaultValueExpression = (ExpressionNode) parameterNode.expression();
if (defaultValueExpression instanceof SimpleNameReferenceNode reference) {
defaultValueExpression = getExpressionNodeForConstantDeclaration(
moduleMemberVisitor,
defaultValueExpression, reference);
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return value.value();
}
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public HeaderParameterMapper(ParameterNode parameterNode, Map<String, String> ap
this.treatNilableAsOptional = treatNilableAsOptional;
if (parameterNode instanceof DefaultableParameterNode defaultableHeaderParam) {
this.defaultValue = AbstractParameterMapper.getDefaultValue(defaultableHeaderParam,
additionalData.moduleMemberVisitor());
additionalData.semanticModel());
}
this.typeMapper = typeMapper;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ public QueryParameterMapper(ParameterNode parameterNode, Map<String, String> api
this.typeMapper = typeMapper;
if (parameterNode instanceof DefaultableParameterNode defaultableQueryParam) {
this.defaultValue = AbstractParameterMapper.getDefaultValue(defaultableQueryParam,
additionalData.moduleMemberVisitor());
additionalData.semanticModel());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,22 @@
*/
package io.ballerina.openapi.service.mapper.type;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.IntersectionTypeSymbol;
import io.ballerina.compiler.api.symbols.RecordFieldSymbol;
import io.ballerina.compiler.api.symbols.RecordTypeSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TypeDescKind;
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeList;
import io.ballerina.compiler.syntax.tree.RecordFieldWithDefaultValueNode;
import io.ballerina.compiler.syntax.tree.RecordTypeDescriptorNode;
import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode;
import io.ballerina.compiler.syntax.tree.TypeDefinitionNode;
import io.ballerina.openapi.service.mapper.diagnostic.DiagnosticMessages;
import io.ballerina.openapi.service.mapper.diagnostic.ExceptionDiagnostic;
Expand All @@ -49,7 +52,6 @@
import java.util.Optional;
import java.util.Set;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getExpressionNodeForConstantDeclaration;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getRecordFieldTypeDescription;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getTypeName;

Expand Down Expand Up @@ -122,17 +124,12 @@ static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components c
.typeDescriptor();
Map<String, RecordFieldSymbol> includedRecordFieldMap = includedRecordTypeSymbol.fieldDescriptors();
for (Map.Entry<String, RecordFieldSymbol> includedRecordField : includedRecordFieldMap.entrySet()) {
//logic comes here i guess
RecordFieldSymbol recordFieldSymbol = recordFieldMap.get(includedRecordField.getKey());
RecordFieldSymbol includedRecordFieldValue = includedRecordField.getValue();
if (recordFieldSymbol != null) {
//check for the types
if (includedRecordFieldValue.typeDescriptor().equals(recordFieldSymbol.typeDescriptor())) {
// check for the default values availability
if (!recordFieldSymbol.hasDefaultValue()) {
recordFieldMap.remove(includedRecordField.getKey());
}
}
boolean isRemovableField = recordFieldSymbol != null && includedRecordFieldValue.typeDescriptor()
.equals(recordFieldSymbol.typeDescriptor()) && !recordFieldSymbol.hasDefaultValue();
if (isRemovableField) {
recordFieldMap.remove(includedRecordField.getKey());
}
}
}
Expand Down Expand Up @@ -160,7 +157,7 @@ public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol>
}
if (recordFieldSymbol.hasDefaultValue()) {
Optional<Object> recordFieldDefaultValueOpt = getRecordFieldDefaultValue(recordName, recordFieldName,
additionalData.moduleMemberVisitor());
additionalData.moduleMemberVisitor(), additionalData.semanticModel());
if (recordFieldDefaultValueOpt.isPresent()) {
TypeMapper.setDefaultValue(recordFieldSchema, recordFieldDefaultValueOpt.get());
} else {
Expand All @@ -175,18 +172,19 @@ public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol>
}

public static Optional<Object> getRecordFieldDefaultValue(String recordName, String fieldName,
ModuleMemberVisitor moduleMemberVisitor) {
ModuleMemberVisitor moduleMemberVisitor,
SemanticModel semanticModel) {
Optional<TypeDefinitionNode> recordDefNodeOpt = moduleMemberVisitor.getTypeDefinitionNode(recordName);
if (recordDefNodeOpt.isPresent() &&
recordDefNodeOpt.get().typeDescriptor() instanceof RecordTypeDescriptorNode recordDefNode) {
return getRecordFieldDefaultValue(fieldName, recordDefNode, moduleMemberVisitor);
return getRecordFieldDefaultValue(fieldName, recordDefNode, semanticModel);
}
return Optional.empty();
}

private static Optional<Object> getRecordFieldDefaultValue(String fieldName,
RecordTypeDescriptorNode recordDefNode,
ModuleMemberVisitor moduleMemberVisitor) {
SemanticModel semanticModel) {
NodeList<Node> recordFields = recordDefNode.fields();
RecordFieldWithDefaultValueNode defaultValueNode = recordFields.stream()
.filter(field -> field instanceof RecordFieldWithDefaultValueNode)
Expand All @@ -197,10 +195,12 @@ private static Optional<Object> getRecordFieldDefaultValue(String fieldName,
return Optional.empty();
}
ExpressionNode defaultValueExpression = defaultValueNode.expression();
// ConstantDeclaration
if (defaultValueExpression instanceof SimpleNameReferenceNode reference) {
defaultValueExpression = getExpressionNodeForConstantDeclaration(moduleMemberVisitor,
defaultValueExpression, reference);
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return Optional.of(value.value());
}
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.BasicLiteralNode;
import io.ballerina.compiler.syntax.tree.ConstantDeclarationNode;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
import io.ballerina.compiler.syntax.tree.DistinctTypeDescriptorNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
Expand All @@ -55,15 +54,13 @@
import io.ballerina.compiler.syntax.tree.ResourcePathParameterNode;
import io.ballerina.compiler.syntax.tree.SeparatedNodeList;
import io.ballerina.compiler.syntax.tree.ServiceDeclarationNode;
import io.ballerina.compiler.syntax.tree.SimpleNameReferenceNode;
import io.ballerina.compiler.syntax.tree.SpecificFieldNode;
import io.ballerina.compiler.syntax.tree.SyntaxKind;
import io.ballerina.compiler.syntax.tree.TypeDefinitionNode;
import io.ballerina.openapi.service.mapper.Constants;
import io.ballerina.openapi.service.mapper.diagnostic.DiagnosticMessages;
import io.ballerina.openapi.service.mapper.diagnostic.ExceptionDiagnostic;
import io.ballerina.openapi.service.mapper.diagnostic.OpenAPIMapperDiagnostic;
import io.ballerina.openapi.service.mapper.model.ModuleMemberVisitor;
import io.ballerina.openapi.service.mapper.model.OASResult;
import io.ballerina.openapi.service.mapper.model.ResourceFunction;
import io.ballerina.openapi.service.mapper.model.ResourceFunctionDeclaration;
Expand Down Expand Up @@ -565,16 +562,4 @@ public static Node getTypeDescriptor(TypeDefinitionNode typeDefinitionNode) {
}
return node;
}

public static ExpressionNode getExpressionNodeForConstantDeclaration(ModuleMemberVisitor moduleMemberVisitor,
ExpressionNode defaultValueExpression,
SimpleNameReferenceNode reference) {
Optional<ConstantDeclarationNode> constantDeclarationNode = moduleMemberVisitor
.getConstantDeclarationNode(reference.name().text());
if (constantDeclarationNode.isPresent()) {
ConstantDeclarationNode constantNode = constantDeclarationNode.get();
defaultValueExpression = (ExpressionNode) constantNode.initializer();
}
return defaultValueExpression;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ public void testRestFieldInRecord() throws IOException {
TestUtils.compareWithGeneratedFile(ballerinaFilePath, "record/record_rest_param.yaml");
}

@Test(description = "Test for record has included record with same fields")
public void testIncludedRecordWithSameFields() throws IOException {
Path ballerinaFilePath = RES_DIR.resolve("record/included_record.bal");
TestUtils.compareWithGeneratedFile(ballerinaFilePath, "record/included_record.yaml");
}

@AfterMethod
public void cleanUp() {
TestUtils.deleteDirectory(this.tempDir);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ paths:
default: Pod
- name: header31
in: header
required: true
schema:
type: string
default: Pod
- name: header32
in: header
required: true
Expand All @@ -152,6 +152,10 @@ paths:
items:
type: integer
format: int64
default:
- 1
- 2
- 3
responses:
"200":
description: Ok
Expand All @@ -164,7 +168,7 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorPayload'
$ref: "#/components/schemas/ErrorPayload"
components:
schemas:
ErrorPayload:
Expand Down
Loading

0 comments on commit 6d672f4

Please sign in to comment.