diff --git a/compiler/plc_ast/src/ast.rs b/compiler/plc_ast/src/ast.rs index 353fc510a7..3f80a0e053 100644 --- a/compiler/plc_ast/src/ast.rs +++ b/compiler/plc_ast/src/ast.rs @@ -11,7 +11,7 @@ use crate::{ control_statements::{ AstControlStatement, CaseStatement, ConditionalBlock, ForLoopStatement, IfStatement, LoopStatement, }, - literals::{Array, AstLiteral, StringValue}, + literals::{AstLiteral, StringValue}, pre_processor, provider::IdProvider, }; @@ -200,7 +200,7 @@ pub struct Implementation { pub type_name: String, pub linkage: LinkageType, pub pou_type: PouType, - pub statements: Vec, + pub statements: Vec, pub location: SourceLocation, pub name_location: SourceLocation, pub overriding: bool, @@ -349,8 +349,8 @@ impl Debug for VariableBlock { pub struct Variable { pub name: String, pub data_type_declaration: DataTypeDeclaration, - pub initializer: Option, - pub address: Option, + pub initializer: Option, + pub address: Option, pub location: SourceLocation, } @@ -383,7 +383,7 @@ pub trait DiagnosticInfo { fn get_location(&self) -> SourceLocation; } -impl DiagnosticInfo for AstStatement { +impl DiagnosticInfo for AstNode { fn get_description(&self) -> String { format!("{self:?}") } @@ -436,7 +436,7 @@ impl DataTypeDeclaration { #[derive(PartialEq)] pub struct UserTypeDeclaration { pub data_type: DataType, - pub initializer: Option, + pub initializer: Option, pub location: SourceLocation, /// stores the original scope for compiler-generated types pub scope: Option, @@ -461,16 +461,16 @@ pub enum DataType { EnumType { name: Option, //maybe empty for inline enums numeric_type: String, - elements: AstStatement, //a single Ref, or an ExpressionList with Refs + elements: AstNode, //a single Ref, or an ExpressionList with Refs }, SubRangeType { name: Option, referenced_type: String, - bounds: Option, + bounds: Option, }, ArrayType { name: Option, - bounds: AstStatement, + bounds: AstNode, referenced_type: Box, is_variable_length: bool, }, @@ -481,7 +481,7 @@ pub enum DataType { StringType { name: Option, is_wide: bool, //WSTRING - size: Option, + size: Option, }, VarArgs { referenced_type: Option>, @@ -557,15 +557,15 @@ pub enum ReferenceAccess { /** * a, a.b */ - Member(Box), + Member(Box), /** * a[3] */ - Index(Box), + Index(Box), /** * Color#Red */ - Cast(Box), + Cast(Box), /** * a^ */ @@ -576,176 +576,91 @@ pub enum ReferenceAccess { Address, } +#[derive(Clone, PartialEq)] +pub struct AstNode { + pub stmt: AstStatement, + pub id: AstId, + pub location: SourceLocation, +} + #[derive(Clone, PartialEq)] pub enum AstStatement { - EmptyStatement { - location: SourceLocation, - id: AstId, - }, + EmptyStatement(EmptyStatement), // a placeholder that indicates a default value of a datatype - DefaultValue { - location: SourceLocation, - id: AstId, - }, + DefaultValue(DefaultValue), // Literals - Literal { - kind: AstLiteral, - location: SourceLocation, - id: AstId, - }, - - CastStatement { - target: Box, - type_name: String, - location: SourceLocation, - id: AstId, - }, - MultipliedStatement { - multiplier: u32, - element: Box, - location: SourceLocation, - id: AstId, - }, + Literal(AstLiteral), + CastStatement(CastStatement), + MultipliedStatement(MultipliedStatement), // Expressions - ReferenceExpr { - access: ReferenceAccess, - base: Option>, - id: AstId, - location: SourceLocation, - }, - Identifier { - name: String, - location: SourceLocation, - id: AstId, - }, - DirectAccess { - access: DirectAccessType, - index: Box, - location: SourceLocation, - id: AstId, - }, - HardwareAccess { - direction: HardwareAccessType, - access: DirectAccessType, - address: Vec, - location: SourceLocation, - id: AstId, - }, - BinaryExpression { - operator: Operator, - left: Box, - right: Box, - id: AstId, - }, - UnaryExpression { - operator: Operator, - value: Box, - location: SourceLocation, - id: AstId, - }, - ExpressionList { - expressions: Vec, - id: AstId, - }, - RangeStatement { - id: AstId, - start: Box, - end: Box, - }, - VlaRangeStatement { - id: AstId, - }, + ReferenceExpr(ReferenceExpr), + Identifier(String), + DirectAccess(DirectAccess), + HardwareAccess(HardwareAccess), + BinaryExpression(BinaryExpression), + UnaryExpression(UnaryExpression), + ExpressionList(Vec), + RangeStatement(RangeStatement), + VlaRangeStatement, // Assignment - Assignment { - left: Box, - right: Box, - id: AstId, - }, + Assignment(Assignment), // OutputAssignment - OutputAssignment { - left: Box, - right: Box, - id: AstId, - }, + OutputAssignment(Assignment), //Call Statement - CallStatement { - operator: Box, - parameters: Box>, - location: SourceLocation, - id: AstId, - }, + CallStatement(CallStatement), // Control Statements - ControlStatement { - kind: AstControlStatement, - location: SourceLocation, - id: AstId, - }, + ControlStatement(AstControlStatement), - CaseCondition { - condition: Box, - id: AstId, - }, - ExitStatement { - location: SourceLocation, - id: AstId, - }, - ContinueStatement { - location: SourceLocation, - id: AstId, - }, - ReturnStatement { - location: SourceLocation, - id: AstId, - }, + CaseCondition(Box), + ExitStatement(()), + ContinueStatement(()), + ReturnStatement(()), } -impl Debug for AstStatement { +impl Debug for AstNode { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - AstStatement::EmptyStatement { .. } => f.debug_struct("EmptyStatement").finish(), - AstStatement::DefaultValue { .. } => f.debug_struct("DefaultValue").finish(), - AstStatement::Literal { kind, .. } => kind.fmt(f), - AstStatement::Identifier { name, .. } => { - f.debug_struct("Identifier").field("name", name).finish() - } - AstStatement::BinaryExpression { operator, left, right, .. } => f + match &self.stmt { + AstStatement::EmptyStatement(..) => f.debug_struct("EmptyStatement").finish(), + AstStatement::DefaultValue(..) => f.debug_struct("DefaultValue").finish(), + AstStatement::Literal(literal) => literal.fmt(f), + AstStatement::Identifier(name) => f.debug_struct("Identifier").field("name", name).finish(), + AstStatement::BinaryExpression(BinaryExpression { operator, left, right }) => f .debug_struct("BinaryExpression") .field("operator", operator) .field("left", left) .field("right", right) .finish(), - AstStatement::UnaryExpression { operator, value, .. } => { + AstStatement::UnaryExpression(UnaryExpression { operator, value }) => { f.debug_struct("UnaryExpression").field("operator", operator).field("value", value).finish() } - AstStatement::ExpressionList { expressions, .. } => { + AstStatement::ExpressionList(expressions) => { f.debug_struct("ExpressionList").field("expressions", expressions).finish() } - AstStatement::RangeStatement { start, end, .. } => { + AstStatement::RangeStatement(RangeStatement { start, end }) => { f.debug_struct("RangeStatement").field("start", start).field("end", end).finish() } - AstStatement::VlaRangeStatement { .. } => f.debug_struct("VlaRangeStatement").finish(), - AstStatement::Assignment { left, right, .. } => { + AstStatement::VlaRangeStatement => f.debug_struct("VlaRangeStatement").finish(), + AstStatement::Assignment(Assignment { left, right }) => { f.debug_struct("Assignment").field("left", left).field("right", right).finish() } - AstStatement::OutputAssignment { left, right, .. } => { + AstStatement::OutputAssignment(Assignment { left, right }) => { f.debug_struct("OutputAssignment").field("left", left).field("right", right).finish() } - AstStatement::CallStatement { operator, parameters, .. } => f + AstStatement::CallStatement(CallStatement { operator, parameters }) => f .debug_struct("CallStatement") .field("operator", operator) .field("parameters", parameters) .finish(), - AstStatement::ControlStatement { - kind: AstControlStatement::If(IfStatement { blocks, else_block, .. }), - .. - } => { + AstStatement::ControlStatement( + AstControlStatement::If(IfStatement { blocks, else_block, .. }), + .., + ) => { f.debug_struct("IfStatement").field("blocks", blocks).field("else_block", else_block).finish() } - AstStatement::ControlStatement { - kind: - AstControlStatement::ForLoop(ForLoopStatement { counter, start, end, by_step, body, .. }), - .. - } => f + AstStatement::ControlStatement( + AstControlStatement::ForLoop(ForLoopStatement { counter, start, end, by_step, body, .. }), + .., + ) => f .debug_struct("ForLoopStatement") .field("counter", counter) .field("start", start) @@ -753,155 +668,106 @@ impl Debug for AstStatement { .field("by_step", by_step) .field("body", body) .finish(), - AstStatement::ControlStatement { - kind: AstControlStatement::WhileLoop(LoopStatement { condition, body, .. }), - .. - } => f + AstStatement::ControlStatement( + AstControlStatement::WhileLoop(LoopStatement { condition, body, .. }), + .., + ) => f .debug_struct("WhileLoopStatement") .field("condition", condition) .field("body", body) .finish(), - AstStatement::ControlStatement { - kind: AstControlStatement::RepeatLoop(LoopStatement { condition, body, .. }), + AstStatement::ControlStatement(AstControlStatement::RepeatLoop(LoopStatement { + condition, + body, .. - } => f + })) => f .debug_struct("RepeatLoopStatement") .field("condition", condition) .field("body", body) .finish(), - AstStatement::ControlStatement { - kind: AstControlStatement::Case(CaseStatement { selector, case_blocks, else_block, .. }), + AstStatement::ControlStatement(AstControlStatement::Case(CaseStatement { + selector, + case_blocks, + else_block, .. - } => f + })) => f .debug_struct("CaseStatement") .field("selector", selector) .field("case_blocks", case_blocks) .field("else_block", else_block) .finish(), - AstStatement::DirectAccess { access, index, .. } => { + AstStatement::DirectAccess(DirectAccess { access, index }) => { f.debug_struct("DirectAccess").field("access", access).field("index", index).finish() } - AstStatement::HardwareAccess { direction, access, address, location, .. } => f + AstStatement::HardwareAccess(HardwareAccess { direction, access, address }) => f .debug_struct("HardwareAccess") .field("direction", direction) .field("access", access) .field("address", address) - .field("location", location) + .field("location", &self.location) .finish(), - AstStatement::MultipliedStatement { multiplier, element, .. } => f + AstStatement::MultipliedStatement(MultipliedStatement { multiplier, element }, ..) => f .debug_struct("MultipliedStatement") .field("multiplier", multiplier) .field("element", element) .finish(), - AstStatement::CaseCondition { condition, .. } => { + AstStatement::CaseCondition(condition) => { f.debug_struct("CaseCondition").field("condition", condition).finish() } - AstStatement::ReturnStatement { .. } => f.debug_struct("ReturnStatement").finish(), - AstStatement::ContinueStatement { .. } => f.debug_struct("ContinueStatement").finish(), - AstStatement::ExitStatement { .. } => f.debug_struct("ExitStatement").finish(), - AstStatement::CastStatement { target, type_name, .. } => { + AstStatement::ReturnStatement(..) => f.debug_struct("ReturnStatement").finish(), + AstStatement::ContinueStatement(..) => f.debug_struct("ContinueStatement").finish(), + AstStatement::ExitStatement(..) => f.debug_struct("ExitStatement").finish(), + AstStatement::CastStatement(CastStatement { target, type_name }) => { f.debug_struct("CastStatement").field("type_name", type_name).field("target", target).finish() } - AstStatement::ReferenceExpr { access, base, .. } => { + AstStatement::ReferenceExpr(ReferenceExpr { access, base }) => { f.debug_struct("ReferenceExpr").field("kind", access).field("base", base).finish() } } } } -impl AstStatement { +impl AstNode { ///Returns the statement in a singleton list, or the contained statements if the statement is already a list - pub fn get_as_list(&self) -> Vec<&AstStatement> { - if let AstStatement::ExpressionList { expressions, .. } = self { - expressions.iter().collect::>() + pub fn get_as_list(&self) -> Vec<&AstNode> { + if let AstStatement::ExpressionList(expressions) = &self.stmt { + expressions.iter().collect::>() } else { vec![self] } } + pub fn get_location(&self) -> SourceLocation { - match self { - AstStatement::EmptyStatement { location, .. } => location.clone(), - AstStatement::DefaultValue { location, .. } => location.clone(), - AstStatement::Literal { location, .. } => location.clone(), - AstStatement::Identifier { location, .. } => location.clone(), - AstStatement::BinaryExpression { left, right, .. } => { - let left_loc = left.get_location(); - let right_loc = right.get_location(); - left_loc.span(&right_loc) - } - AstStatement::UnaryExpression { location, .. } => location.clone(), - AstStatement::ExpressionList { expressions, .. } => { - let first = - expressions.first().map_or_else(SourceLocation::undefined, |it| it.get_location()); - let last = expressions.last().map_or_else(SourceLocation::undefined, |it| it.get_location()); - first.span(&last) - } - AstStatement::RangeStatement { start, end, .. } => { - let start_loc = start.get_location(); - let end_loc = end.get_location(); - start_loc.span(&end_loc) - } - AstStatement::VlaRangeStatement { .. } => SourceLocation::undefined(), // internal type only - AstStatement::Assignment { left, right, .. } => { - let left_loc = left.get_location(); - let right_loc = right.get_location(); - left_loc.span(&right_loc) - } - AstStatement::OutputAssignment { left, right, .. } => { - let left_loc = left.get_location(); - let right_loc = right.get_location(); - left_loc.span(&right_loc) - } - AstStatement::CallStatement { location, .. } => location.clone(), - AstStatement::ControlStatement { location, .. } => location.clone(), - AstStatement::DirectAccess { location, .. } => location.clone(), - AstStatement::HardwareAccess { location, .. } => location.clone(), - AstStatement::MultipliedStatement { location, .. } => location.clone(), - AstStatement::CaseCondition { condition, .. } => condition.get_location(), - AstStatement::ReturnStatement { location, .. } => location.clone(), - AstStatement::ContinueStatement { location, .. } => location.clone(), - AstStatement::ExitStatement { location, .. } => location.clone(), - AstStatement::CastStatement { location, .. } => location.clone(), - AstStatement::ReferenceExpr { location, .. } => location.clone(), - } + self.location.clone() + } + + pub fn set_location(&mut self, location: SourceLocation) { + self.location = location; } pub fn get_id(&self) -> AstId { - match self { - AstStatement::EmptyStatement { id, .. } => *id, - AstStatement::DefaultValue { id, .. } => *id, - AstStatement::Literal { id, .. } => *id, - AstStatement::MultipliedStatement { id, .. } => *id, - AstStatement::Identifier { id, .. } => *id, - AstStatement::DirectAccess { id, .. } => *id, - AstStatement::HardwareAccess { id, .. } => *id, - AstStatement::BinaryExpression { id, .. } => *id, - AstStatement::UnaryExpression { id, .. } => *id, - AstStatement::ExpressionList { id, .. } => *id, - AstStatement::RangeStatement { id, .. } => *id, - AstStatement::VlaRangeStatement { id, .. } => *id, - AstStatement::Assignment { id, .. } => *id, - AstStatement::OutputAssignment { id, .. } => *id, - AstStatement::CallStatement { id, .. } => *id, - AstStatement::ControlStatement { id, .. } => *id, - AstStatement::CaseCondition { id, .. } => *id, - AstStatement::ReturnStatement { id, .. } => *id, - AstStatement::ContinueStatement { id, .. } => *id, - AstStatement::ExitStatement { id, .. } => *id, - AstStatement::CastStatement { id, .. } => *id, - AstStatement::ReferenceExpr { id, .. } => *id, - } + self.id + } + + pub fn get_stmt(&self) -> &AstStatement { + &self.stmt } /// Returns true if the current statement has a direct access. pub fn has_direct_access(&self) -> bool { - match self { - AstStatement::ReferenceExpr { access: ReferenceAccess::Member(reference), base, .. } - | AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(reference), base, .. } => { + match &self.stmt { + AstStatement::ReferenceExpr( + ReferenceExpr { access: ReferenceAccess::Member(reference), base }, + .., + ) + | AstStatement::ReferenceExpr( + ReferenceExpr { access: ReferenceAccess::Cast(reference), base }, + .., + ) => { reference.has_direct_access() || base.as_ref().map(|it| it.has_direct_access()).unwrap_or(false) } - AstStatement::DirectAccess { .. } => true, + AstStatement::DirectAccess(..) => true, _ => false, } } @@ -910,23 +776,22 @@ impl AstStatement { /// prefixed with a type-cast (e.g. INT#23) pub fn is_cast_prefix_eligible(&self) -> bool { // TODO: figure out a better name for this... - match self { - AstStatement::Literal { kind, .. } => kind.is_cast_prefix_eligible(), - AstStatement::Identifier { .. } => true, + match &self.stmt { + AstStatement::Literal(kind, ..) => kind.is_cast_prefix_eligible(), + AstStatement::Identifier(..) => true, _ => false, } } /// Returns true if the current statement is a flat reference (e.g. `a`) pub fn is_flat_reference(&self) -> bool { - matches!(self, AstStatement::Identifier { .. }) || { - if let AstStatement::ReferenceExpr { - access: ReferenceAccess::Member(reference), - base: None, - .. - } = self + matches!(self.stmt, AstStatement::Identifier(..)) || { + if let AstStatement::ReferenceExpr( + ReferenceExpr { access: ReferenceAccess::Member(reference), base: None }, + .., + ) = &self.stmt { - matches!(reference.as_ref(), AstStatement::Identifier { .. }) + matches!(reference.as_ref().stmt, AstStatement::Identifier(..)) } else { false } @@ -935,37 +800,50 @@ impl AstStatement { /// Returns the reference-name if this is a flat reference like `a`, or None if this is no flat reference pub fn get_flat_reference_name(&self) -> Option<&str> { - match self { - AstStatement::ReferenceExpr { access: ReferenceAccess::Member(reference), .. } => { - if let AstStatement::Identifier { name, .. } = reference.as_ref() { + match &self.stmt { + AstStatement::ReferenceExpr( + ReferenceExpr { access: ReferenceAccess::Member(reference), .. }, + .., + ) => { + if let AstStatement::Identifier(name, ..) = &reference.as_ref().stmt { Some(name) } else { None } } - AstStatement::Identifier { name, .. } => Some(name), + AstStatement::Identifier(name, ..) => Some(name), _ => None, } } + pub fn is_empty_statement(&self) -> bool { + matches!(self.stmt, AstStatement::EmptyStatement(..)) + } + pub fn is_reference(&self) -> bool { - matches!(self, AstStatement::ReferenceExpr { .. }) + matches!(self.stmt, AstStatement::ReferenceExpr(..)) } pub fn is_hardware_access(&self) -> bool { - matches!(self, AstStatement::HardwareAccess { .. }) + matches!(self.stmt, AstStatement::HardwareAccess(..)) } pub fn is_array_access(&self) -> bool { - matches!(self, AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), .. }) + matches!( + self.stmt, + AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Index(_), .. }, ..) + ) } pub fn is_pointer_access(&self) -> bool { - matches!(self, AstStatement::ReferenceExpr { access: ReferenceAccess::Deref, .. }) + matches!( + self.stmt, + AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Deref, .. }, ..) + ) } pub fn is_expression_list(&self) -> bool { - matches!(self, AstStatement::ExpressionList { .. }) + matches!(self.stmt, AstStatement::ExpressionList { .. }) } pub fn can_be_assigned_to(&self) -> bool { @@ -977,30 +855,39 @@ impl AstStatement { || self.is_hardware_access() } - pub fn new_literal(kind: AstLiteral, id: AstId, location: SourceLocation) -> Self { - AstStatement::Literal { kind, id, location } + pub fn new(stmt: AstStatement, id: AstId, location: SourceLocation) -> AstNode { + AstNode { stmt, id, location } } - pub fn new_integer(value: i128, id: AstId, location: SourceLocation) -> Self { - AstStatement::Literal { kind: AstLiteral::Integer(value), location, id } + pub fn new_literal(kind: AstLiteral, id: AstId, location: SourceLocation) -> AstNode { + AstNode::new(AstStatement::Literal(kind), id, location) } - pub fn new_real(value: String, id: AstId, location: SourceLocation) -> Self { - AstStatement::Literal { kind: AstLiteral::Real(value), location, id } + pub fn new_integer(value: i128, id: AstId, location: SourceLocation) -> AstNode { + AstNode::new(AstStatement::Literal(AstLiteral::Integer(value)), id, location) } - pub fn new_string(value: impl Into, is_wide: bool, id: AstId, location: SourceLocation) -> Self { - AstStatement::Literal { - kind: AstLiteral::String(StringValue { value: value.into(), is_wide }), - location, + pub fn new_real(value: String, id: AstId, location: SourceLocation) -> AstNode { + AstNode::new(AstStatement::Literal(AstLiteral::Real(value)), id, location) + } + + pub fn new_string( + value: impl Into, + is_wide: bool, + id: AstId, + location: SourceLocation, + ) -> AstNode { + AstNode::new( + AstStatement::Literal(AstLiteral::String(StringValue { value: value.into(), is_wide })), id, - } + location, + ) } /// Returns true if the given token is an integer or float and zero. pub fn is_zero(&self) -> bool { - match self { - AstStatement::Literal { kind, .. } => match kind { + match &self.stmt { + AstStatement::Literal(kind, ..) => match kind { AstLiteral::Integer(0) => true, AstLiteral::Real(val) => val == "0" || val == "0.0", _ => false, @@ -1011,74 +898,23 @@ impl AstStatement { } pub fn is_binary_expression(&self) -> bool { - matches!(self, AstStatement::BinaryExpression { .. }) + matches!(self.stmt, AstStatement::BinaryExpression(..)) } pub fn is_literal_array(&self) -> bool { - matches!(self, AstStatement::Literal { kind: AstLiteral::Array(..), .. }) + matches!(self.stmt, AstStatement::Literal(AstLiteral::Array(..), ..)) } pub fn is_literal(&self) -> bool { - matches!(self, AstStatement::Literal { .. }) + matches!(self.stmt, AstStatement::Literal(..)) } - pub fn set_location(self, new_location: SourceLocation) -> Self { - match self { - AstStatement::EmptyStatement { location: _, id } => { - AstStatement::EmptyStatement { location: new_location, id } - } - AstStatement::DefaultValue { location: _, id } => { - AstStatement::DefaultValue { location: new_location, id } - } - AstStatement::Literal { kind, location: _, id } => { - AstStatement::Literal { kind, location: new_location, id } - } - AstStatement::CastStatement { target, type_name, location: _, id } => { - AstStatement::CastStatement { target, type_name, location: new_location, id } - } - AstStatement::MultipliedStatement { multiplier, element, location: _, id } => { - AstStatement::MultipliedStatement { multiplier, element, location: new_location, id } - } - AstStatement::DirectAccess { access, index, location: _, id } => { - AstStatement::DirectAccess { access, index, location: new_location, id } - } - AstStatement::HardwareAccess { direction, access, address, location: _, id } => { - AstStatement::HardwareAccess { direction, access, address, location: new_location, id } - } - AstStatement::UnaryExpression { operator, value, location: _, id } => { - AstStatement::UnaryExpression { operator, value, location: new_location, id } - } - AstStatement::CallStatement { operator, parameters, location: _, id } => { - AstStatement::CallStatement { operator, parameters, location: new_location, id } - } - AstStatement::ControlStatement { kind, location: _, id } => { - AstStatement::ControlStatement { kind, location: new_location, id } - } - AstStatement::ExitStatement { location: _, id } => { - AstStatement::ExitStatement { location: new_location, id } - } - AstStatement::ContinueStatement { location: _, id } => { - AstStatement::ContinueStatement { location: new_location, id } - } - AstStatement::ReturnStatement { location: _, id } => { - AstStatement::ReturnStatement { location: new_location, id } - } - AstStatement::ReferenceExpr { access, base, id, location: _ } => { - Self::ReferenceExpr { access, base, id, location: new_location } - } - AstStatement::Identifier { name, location: _, id } => { - AstStatement::Identifier { name, location: new_location, id } - } - _ => self, - } + pub fn is_identifier(&self) -> bool { + matches!(self.stmt, AstStatement::Identifier(..)) } - pub fn get_literal_array(&self) -> Option<&Array> { - if let AstStatement::Literal { kind: AstLiteral::Array(array), .. } = self { - return Some(array); - } - - None + pub fn is_default_value(&self) -> bool { + matches!(self.stmt, AstStatement::DefaultValue { .. }) } } @@ -1120,19 +956,19 @@ impl Display for Operator { /// enum_elements should be the statement between then enum's brackets ( ) /// e.g. x : ( this, that, etc) -pub fn get_enum_element_names(enum_elements: &AstStatement) -> Vec { +pub fn get_enum_element_names(enum_elements: &AstNode) -> Vec { flatten_expression_list(enum_elements) .into_iter() - .filter(|it| matches!(it, AstStatement::Identifier { .. } | AstStatement::Assignment { .. })) + .filter(|it| matches!(it.stmt, AstStatement::Identifier(..) | AstStatement::Assignment(..))) .map(get_enum_element_name) .collect() } /// expects a Reference or an Assignment -pub fn get_enum_element_name(enum_element: &AstStatement) -> String { - match enum_element { - AstStatement::Identifier { name, .. } => name.to_string(), - AstStatement::Assignment { left, .. } => left +pub fn get_enum_element_name(enum_element: &AstNode) -> String { + match &enum_element.stmt { + AstStatement::Identifier(name, ..) => name.to_string(), + AstStatement::Assignment(Assignment { left, .. }, ..) => left .get_flat_reference_name() .map(|it| it.to_string()) .expect("left of assignment not a reference"), @@ -1144,12 +980,12 @@ pub fn get_enum_element_name(enum_element: &AstStatement) -> String { /// flattens expression-lists and MultipliedStatements into a vec of statements. /// It can also handle nested structures like 2(3(4,5)) -pub fn flatten_expression_list(list: &AstStatement) -> Vec<&AstStatement> { - match list { - AstStatement::ExpressionList { expressions, .. } => { +pub fn flatten_expression_list(list: &AstNode) -> Vec<&AstNode> { + match &list.stmt { + AstStatement::ExpressionList(expressions, ..) => { expressions.iter().by_ref().flat_map(flatten_expression_list).collect() } - AstStatement::MultipliedStatement { multiplier, element, .. } => { + AstStatement::MultipliedStatement(MultipliedStatement { multiplier, element }, ..) => { std::iter::repeat(flatten_expression_list(element)).take(*multiplier as usize).flatten().collect() } _ => vec![list], @@ -1217,19 +1053,66 @@ mod tests { pub struct AstFactory {} impl AstFactory { - pub fn empty_statement(location: SourceLocation, id: AstId) -> AstStatement { - AstStatement::EmptyStatement { location, id } + pub fn create_empty_statement(location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::EmptyStatement(EmptyStatement {}), location, id } + // AstStatement::EmptyStatement ( EmptyStatement {}, location, id } + } + + pub fn create_return_statement(location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::ReturnStatement(()), location, id } + } + + pub fn create_exit_statement(location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::ExitStatement(()), location, id } + } + + pub fn create_continue_statement(location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::ContinueStatement(()), location, id } + } + + pub fn create_case_condition(result: AstNode, location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::CaseCondition(Box::new(result)), id, location } + } + + pub fn create_vla_range_statement(location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::VlaRangeStatement, id, location } + } + + pub fn create_literal(kind: AstLiteral, location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::Literal(kind), id, location } + } + + pub fn create_hardware_access( + access: DirectAccessType, + direction: HardwareAccessType, + address: Vec, + location: SourceLocation, + id: usize, + ) -> AstNode { + AstNode { + stmt: AstStatement::HardwareAccess(HardwareAccess { access, direction, address }), + location, + id, + } + } + + pub fn create_default_value(location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::DefaultValue(DefaultValue {}), location, id } + } + + pub fn create_expression_list(expressions: Vec, location: SourceLocation, id: AstId) -> AstNode { + AstNode { stmt: AstStatement::ExpressionList(expressions), location, id } } /// creates a new if-statement pub fn create_if_statement( blocks: Vec, - else_block: Vec, + else_block: Vec, location: SourceLocation, id: AstId, - ) -> AstStatement { - AstStatement::ControlStatement { - kind: AstControlStatement::If(IfStatement { blocks, else_block }), + ) -> AstNode { + AstNode { + stmt: AstStatement::ControlStatement(AstControlStatement::If(IfStatement { blocks, else_block })), location, id, } @@ -1237,22 +1120,22 @@ impl AstFactory { /// creates a new for loop statement pub fn create_for_loop( - counter: AstStatement, - start: AstStatement, - end: AstStatement, - by_step: Option, - body: Vec, + counter: AstNode, + start: AstNode, + end: AstNode, + by_step: Option, + body: Vec, location: SourceLocation, id: AstId, - ) -> AstStatement { - AstStatement::ControlStatement { - kind: AstControlStatement::ForLoop(ForLoopStatement { + ) -> AstNode { + AstNode { + stmt: AstStatement::ControlStatement(AstControlStatement::ForLoop(ForLoopStatement { counter: Box::new(counter), start: Box::new(start), end: Box::new(end), by_step: by_step.map(Box::new), body, - }), + })), location, id, } @@ -1260,13 +1143,16 @@ impl AstFactory { /// creates a new while statement pub fn create_while_statement( - condition: AstStatement, - body: Vec, + condition: AstNode, + body: Vec, location: SourceLocation, id: AstId, - ) -> AstStatement { - AstStatement::ControlStatement { - kind: AstControlStatement::WhileLoop(LoopStatement { condition: Box::new(condition), body }), + ) -> AstNode { + AstNode { + stmt: AstStatement::ControlStatement(AstControlStatement::WhileLoop(LoopStatement { + condition: Box::new(condition), + body, + })), id, location, } @@ -1274,13 +1160,16 @@ impl AstFactory { /// creates a new repeat-statement pub fn create_repeat_statement( - condition: AstStatement, - body: Vec, + condition: AstNode, + body: Vec, location: SourceLocation, id: AstId, - ) -> AstStatement { - AstStatement::ControlStatement { - kind: AstControlStatement::RepeatLoop(LoopStatement { condition: Box::new(condition), body }), + ) -> AstNode { + AstNode { + stmt: AstStatement::ControlStatement(AstControlStatement::RepeatLoop(LoopStatement { + condition: Box::new(condition), + body, + })), id, location, } @@ -1288,96 +1177,135 @@ impl AstFactory { /// creates a new case-statement pub fn create_case_statement( - selector: AstStatement, + selector: AstNode, case_blocks: Vec, - else_block: Vec, + else_block: Vec, location: SourceLocation, id: AstId, - ) -> AstStatement { - AstStatement::ControlStatement { - kind: AstControlStatement::Case(CaseStatement { + ) -> AstNode { + AstNode { + stmt: AstStatement::ControlStatement(AstControlStatement::Case(CaseStatement { selector: Box::new(selector), case_blocks, else_block, - }), + })), id, location, } } /// creates an or-expression - pub fn create_or_expression(left: AstStatement, right: AstStatement) -> AstStatement { - AstStatement::BinaryExpression { - id: left.get_id(), - left: Box::new(left), - right: Box::new(right), - operator: Operator::Or, + pub fn create_or_expression(left: AstNode, right: AstNode) -> AstNode { + let id = left.get_id(); + let location = left.get_location().span(&right.get_location()); + AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { + left: Box::new(left), + right: Box::new(right), + operator: Operator::Or, + }), + id, + location, } } /// creates a not-expression - pub fn create_not_expression(operator: AstStatement, location: SourceLocation) -> AstStatement { - AstStatement::UnaryExpression { - id: operator.get_id(), - value: Box::new(operator), + pub fn create_not_expression(operator: AstNode, location: SourceLocation) -> AstNode { + let id = operator.get_id(); + AstNode { + stmt: AstStatement::UnaryExpression(UnaryExpression { + value: Box::new(operator), + operator: Operator::Not, + }), + id, location, - operator: Operator::Not, } } /// creates a new Identifier - pub fn create_identifier(name: &str, location: &SourceLocation, id: AstId) -> AstStatement { - AstStatement::Identifier { id, location: location.clone(), name: name.to_string() } + pub fn create_identifier(name: &str, location: &SourceLocation, id: AstId) -> AstNode { + AstNode::new(AstStatement::Identifier(name.to_string()), id, location.clone()) } - pub fn create_member_reference( - member: AstStatement, - base: Option, + pub fn create_unary_expression( + operator: Operator, + value: AstNode, + location: SourceLocation, id: AstId, - ) -> AstStatement { + ) -> AstNode { + AstNode { + stmt: AstStatement::UnaryExpression(UnaryExpression { operator, value: Box::new(value) }), + location, + id, + } + } + + pub fn create_assignment(left: AstNode, right: AstNode, id: AstId) -> AstNode { + let location = left.location.span(&right.location); + AstNode { + stmt: AstStatement::Assignment(Assignment { left: Box::new(left), right: Box::new(right) }), + id, + location, + } + } + + pub fn create_output_assignment(left: AstNode, right: AstNode, id: AstId) -> AstNode { + let location = left.location.span(&right.location); + AstNode::new( + AstStatement::OutputAssignment(Assignment { left: Box::new(left), right: Box::new(right) }), + id, + location, + ) + } + + pub fn create_member_reference(member: AstNode, base: Option, id: AstId) -> AstNode { let location = base .as_ref() .map(|it| it.get_location().span(&member.get_location())) .unwrap_or_else(|| member.get_location()); - AstStatement::ReferenceExpr { - access: ReferenceAccess::Member(Box::new(member)), - base: base.map(Box::new), + AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Member(Box::new(member)), + base: base.map(Box::new), + }), id, location, } } pub fn create_index_reference( - index: AstStatement, - base: Option, + index: AstNode, + base: Option, id: AstId, location: SourceLocation, - ) -> AstStatement { - AstStatement::ReferenceExpr { - access: ReferenceAccess::Index(Box::new(index)), - base: base.map(Box::new), + ) -> AstNode { + AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(Box::new(index)), + base: base.map(Box::new), + }), id, location, } } - pub fn create_address_of_reference( - base: AstStatement, - id: AstId, - location: SourceLocation, - ) -> AstStatement { - AstStatement::ReferenceExpr { - access: ReferenceAccess::Address, - base: Some(Box::new(base)), + pub fn create_address_of_reference(base: AstNode, id: AstId, location: SourceLocation) -> AstNode { + AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Address, + base: Some(Box::new(base)), + }), id, location, } } - pub fn create_deref_reference(base: AstStatement, id: AstId, location: SourceLocation) -> AstStatement { - AstStatement::ReferenceExpr { - access: ReferenceAccess::Deref, - base: Some(Box::new(base)), + pub fn create_deref_reference(base: AstNode, id: AstId, location: SourceLocation) -> AstNode { + AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Deref, + base: Some(Box::new(base)), + }), id, location, } @@ -1385,78 +1313,132 @@ impl AstFactory { pub fn create_direct_access( access: DirectAccessType, - index: AstStatement, + index: AstNode, id: AstId, location: SourceLocation, - ) -> AstStatement { - AstStatement::DirectAccess { access, index: Box::new(index), location, id } + ) -> AstNode { + AstNode { + stmt: AstStatement::DirectAccess(DirectAccess { access, index: Box::new(index) }), + location, + id, + } } /// creates a new binary statement - pub fn create_binary_expression( - left: AstStatement, - operator: Operator, - right: AstStatement, - id: AstId, - ) -> AstStatement { - AstStatement::BinaryExpression { id, left: Box::new(left), operator, right: Box::new(right) } + pub fn create_binary_expression(left: AstNode, operator: Operator, right: AstNode, id: AstId) -> AstNode { + let location = left.location.span(&right.location); + AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { + left: Box::new(left), + operator, + right: Box::new(right), + }), + id, + location, + } } /// creates a new cast statement pub fn create_cast_statement( - type_name: AstStatement, - stmt: AstStatement, + type_name: AstNode, + stmt: AstNode, location: &SourceLocation, id: AstId, - ) -> AstStatement { + ) -> AstNode { let new_location = location.span(&stmt.get_location()); - AstStatement::ReferenceExpr { - access: ReferenceAccess::Cast(Box::new(stmt)), - base: Some(Box::new(type_name)), + AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Cast(Box::new(stmt)), + base: Some(Box::new(type_name)), + }), id, location: new_location, } } + pub fn create_call_statement( + operator: AstNode, + parameters: Option, + id: usize, + location: SourceLocation, + ) -> AstNode { + AstNode { + stmt: AstStatement::CallStatement(CallStatement { + operator: Box::new(operator), + parameters: parameters.map(Box::new), + }), + location, + id, + } + } + /// creates a new call statement to the given function and parameters pub fn create_call_to( function_name: String, - parameters: Vec, + parameters: Vec, id: usize, parameter_list_id: usize, location: &SourceLocation, - ) -> AstStatement { - AstStatement::CallStatement { - operator: Box::new(AstFactory::create_member_reference( - AstFactory::create_identifier(&function_name, location, id), - None, - id, - )), - parameters: Box::new(Some(AstStatement::ExpressionList { - expressions: parameters, - id: parameter_list_id, - })), + ) -> AstNode { + AstNode { + stmt: AstStatement::CallStatement(CallStatement { + operator: Box::new(AstFactory::create_member_reference( + AstFactory::create_identifier(&function_name, location, id), + None, + id, + )), + parameters: Some(Box::new(AstNode::new( + AstStatement::ExpressionList(parameters), + parameter_list_id, + SourceLocation::undefined(), //TODO: get real location + ))), + }), location: location.clone(), id, } } + pub fn create_multiplied_statement( + multiplier: u32, + element: AstNode, + location: SourceLocation, + id: AstId, + ) -> AstNode { + AstNode { + stmt: AstStatement::MultipliedStatement(MultipliedStatement { + multiplier, + element: Box::new(element), + }), + location, + id, + } + } + + pub fn create_range_statement(start: AstNode, end: AstNode, id: AstId) -> AstNode { + let location = start.location.span(&end.location); + let data = RangeStatement { start: Box::new(start), end: Box::new(end) }; + AstNode { stmt: AstStatement::RangeStatement(data), id, location } + } + pub fn create_call_to_with_ids( function_name: &str, - parameters: Vec, + parameters: Vec, location: &SourceLocation, mut id_provider: IdProvider, - ) -> AstStatement { - AstStatement::CallStatement { - operator: Box::new(AstFactory::create_member_reference( - AstFactory::create_identifier(function_name, location, id_provider.next_id()), - None, - id_provider.next_id(), - )), - parameters: Box::new(Some(AstStatement::ExpressionList { - expressions: parameters, - id: id_provider.next_id(), - })), + ) -> AstNode { + AstNode { + stmt: AstStatement::CallStatement(CallStatement { + operator: Box::new(AstFactory::create_member_reference( + AstFactory::create_identifier(function_name, location, id_provider.next_id()), + None, + id_provider.next_id(), + )), + parameters: Some(Box::new(AstFactory::create_expression_list( + parameters, + SourceLocation::undefined(), + id_provider.next_id(), + ))), + }), location: location.clone(), id: id_provider.next_id(), } @@ -1464,11 +1446,11 @@ impl AstFactory { pub fn create_call_to_check_function_ast( check_function_name: &str, - parameter: AstStatement, - sub_range: Range, + parameter: AstNode, + sub_range: Range, location: &SourceLocation, id_provider: IdProvider, - ) -> AstStatement { + ) -> AstNode { AstFactory::create_call_to_with_ids( check_function_name, vec![parameter, sub_range.start, sub_range.end], @@ -1477,3 +1459,69 @@ impl AstFactory { ) } } +#[derive(Clone, PartialEq)] +pub struct EmptyStatement {} + +#[derive(Clone, PartialEq)] +pub struct DefaultValue {} + +#[derive(Clone, PartialEq)] +pub struct CastStatement { + pub target: Box, + pub type_name: String, +} + +#[derive(Clone, PartialEq)] +pub struct MultipliedStatement { + pub multiplier: u32, + pub element: Box, +} +#[derive(Clone, PartialEq)] +pub struct ReferenceExpr { + pub access: ReferenceAccess, + pub base: Option>, +} + +#[derive(Clone, PartialEq)] +pub struct DirectAccess { + pub access: DirectAccessType, + pub index: Box, +} + +#[derive(Clone, PartialEq)] +pub struct HardwareAccess { + pub direction: HardwareAccessType, + pub access: DirectAccessType, + pub address: Vec, +} + +#[derive(Clone, PartialEq)] +pub struct BinaryExpression { + pub operator: Operator, + pub left: Box, + pub right: Box, +} + +#[derive(Clone, PartialEq)] +pub struct UnaryExpression { + pub operator: Operator, + pub value: Box, +} + +#[derive(Clone, PartialEq)] +pub struct RangeStatement { + pub start: Box, + pub end: Box, +} + +#[derive(Clone, PartialEq)] +pub struct Assignment { + pub left: Box, + pub right: Box, +} + +#[derive(Clone, PartialEq)] +pub struct CallStatement { + pub operator: Box, + pub parameters: Option>, +} diff --git a/compiler/plc_ast/src/control_statements.rs b/compiler/plc_ast/src/control_statements.rs index c13b0b5602..357b8a72d0 100644 --- a/compiler/plc_ast/src/control_statements.rs +++ b/compiler/plc_ast/src/control_statements.rs @@ -1,34 +1,34 @@ use std::fmt::{Debug, Formatter}; -use crate::ast::AstStatement; +use crate::ast::AstNode; #[derive(Clone, PartialEq)] pub struct IfStatement { pub blocks: Vec, - pub else_block: Vec, + pub else_block: Vec, } #[derive(Clone, PartialEq)] pub struct ForLoopStatement { - pub counter: Box, - pub start: Box, - pub end: Box, - pub by_step: Option>, - pub body: Vec, + pub counter: Box, + pub start: Box, + pub end: Box, + pub by_step: Option>, + pub body: Vec, } #[derive(Clone, PartialEq)] /// used for While and Repeat loops pub struct LoopStatement { - pub condition: Box, - pub body: Vec, + pub condition: Box, + pub body: Vec, } #[derive(Clone, PartialEq)] pub struct CaseStatement { - pub selector: Box, + pub selector: Box, pub case_blocks: Vec, - pub else_block: Vec, + pub else_block: Vec, } #[derive(Clone, PartialEq)] @@ -42,8 +42,8 @@ pub enum AstControlStatement { #[derive(Clone, PartialEq)] pub struct ConditionalBlock { - pub condition: Box, - pub body: Vec, + pub condition: Box, + pub body: Vec, } impl Debug for ConditionalBlock { diff --git a/compiler/plc_ast/src/literals.rs b/compiler/plc_ast/src/literals.rs index 857316c5fb..3cc8942a5f 100644 --- a/compiler/plc_ast/src/literals.rs +++ b/compiler/plc_ast/src/literals.rs @@ -2,7 +2,7 @@ use std::fmt::{Debug, Formatter}; use chrono::NaiveDate; -use crate::ast::AstStatement; +use crate::ast::AstNode; macro_rules! impl_getters { ($type:ty, [$($name:ident),+], [$($out:ty),+]) => { @@ -84,7 +84,7 @@ pub struct StringValue { #[derive(Clone, PartialEq)] pub struct Array { - pub elements: Option>, // expression-list + pub elements: Option>, // expression-list } /// calculates the nanoseconds since 1970-01-01-00:00:00 for the given @@ -170,14 +170,14 @@ impl Time { } impl Array { - pub fn elements(&self) -> Option<&AstStatement> { + pub fn elements(&self) -> Option<&AstNode> { self.elements.as_ref().map(|it| it.as_ref()) } } impl AstLiteral { /// Creates a new literal array - pub fn new_array(elements: Option>) -> Self { + pub fn new_array(elements: Option>) -> Self { AstLiteral::Array(Array { elements }) } /// Creates a new literal integer diff --git a/compiler/plc_ast/src/pre_processor.rs b/compiler/plc_ast/src/pre_processor.rs index dcd9560edf..a0a7e6d20b 100644 --- a/compiler/plc_ast/src/pre_processor.rs +++ b/compiler/plc_ast/src/pre_processor.rs @@ -6,8 +6,8 @@ use plc_util::convention::internal_type_name; use crate::{ ast::{ - flatten_expression_list, AstFactory, AstStatement, CompilationUnit, DataType, DataTypeDeclaration, - Operator, Pou, UserTypeDeclaration, Variable, + flatten_expression_list, Assignment, AstFactory, AstNode, AstStatement, CompilationUnit, DataType, + DataTypeDeclaration, Operator, Pou, UserTypeDeclaration, Variable, }, literals::AstLiteral, provider::IdProvider, @@ -80,27 +80,24 @@ pub fn pre_process(unit: &mut CompilationUnit, mut id_provider: IdProvider) { } } DataType::EnumType { elements, .. } - if matches!(elements, AstStatement::EmptyStatement { .. }) => + if matches!(elements.stmt, AstStatement::EmptyStatement { .. }) => { //avoid empty statements, just use an empty expression list to make it easier to work with - let _ = std::mem::replace( - elements, - AstStatement::ExpressionList { expressions: vec![], id: id_provider.next_id() }, - ); + let _ = std::mem::replace(&mut elements.stmt, AstStatement::ExpressionList(vec![])); } DataType::EnumType { elements: original_elements, name: Some(enum_name), .. } - if !matches!(original_elements, AstStatement::EmptyStatement { .. }) => + if !matches!(original_elements.stmt, AstStatement::EmptyStatement { .. }) => { let mut last_name: Option = None; - fn extract_flat_ref_name(statement: &AstStatement) -> &str { + fn extract_flat_ref_name(statement: &AstNode) -> &str { statement.get_flat_reference_name().expect("expected assignment") } let initialized_enum_elements = flatten_expression_list(original_elements) .iter() - .map(|it| match it { - AstStatement::Assignment { left, right, .. } => { + .map(|it| match &it.stmt { + AstStatement::Assignment(Assignment { left, right }) => { // ( extract_flat_ref_name(left.as_ref()), @@ -115,9 +112,8 @@ pub fn pre_process(unit: &mut CompilationUnit, mut id_provider: IdProvider) { build_enum_initializer(&last_name, &location, &mut id_provider, enum_name) }); last_name = Some(element_name.to_string()); - AstStatement::Assignment { - id: id_provider.next_id(), - left: Box::new(AstFactory::create_member_reference( + AstFactory::create_assignment( + AstFactory::create_member_reference( AstFactory::create_identifier( element_name, &location, @@ -125,18 +121,25 @@ pub fn pre_process(unit: &mut CompilationUnit, mut id_provider: IdProvider) { ), None, id_provider.next_id(), - )), - right: Box::new(enum_literal), - } + ), + enum_literal, + id_provider.next_id(), + ) }) - .collect::>(); + .collect::>(); // if the enum is empty, we dont change anything if !initialized_enum_elements.is_empty() { + // we can safely unwrap because we checked the vec + let start_loc = + initialized_enum_elements.first().expect("non empty vec").get_location(); + let end_loc = + initialized_enum_elements.iter().last().expect("non empty vec").get_location(); //swap the expression list with our new Assignments - let expression = AstStatement::ExpressionList { - expressions: initialized_enum_elements, - id: id_provider.next_id(), - }; + let expression = AstFactory::create_expression_list( + initialized_enum_elements, + start_loc.span(&end_loc), + id_provider.next_id(), + ); let _ = std::mem::replace(original_elements, expression); } } @@ -152,7 +155,7 @@ fn build_enum_initializer( location: &SourceLocation, id_provider: &mut IdProvider, enum_name: &mut str, -) -> AstStatement { +) -> AstNode { if let Some(last_element) = last_name.as_ref() { // generate a `enum#last + 1` statement let enum_ref = AstFactory::create_identifier(last_element, location, id_provider.next_id()); @@ -164,11 +167,11 @@ fn build_enum_initializer( AstFactory::create_binary_expression( AstFactory::create_cast_statement(type_element, enum_ref, location, id_provider.next_id()), Operator::Plus, - AstStatement::new_literal(AstLiteral::new_integer(1), id_provider.next_id(), location.clone()), + AstNode::new_literal(AstLiteral::new_integer(1), id_provider.next_id(), location.clone()), id_provider.next_id(), ) } else { - AstStatement::new_literal(AstLiteral::new_integer(0), id_provider.next_id(), location.clone()) + AstNode::new_literal(AstLiteral::new_integer(0), id_provider.next_id(), location.clone()) } } diff --git a/compiler/plc_diagnostics/src/diagnostics.rs b/compiler/plc_diagnostics/src/diagnostics.rs index 01094b7d10..c87b31e0b5 100644 --- a/compiler/plc_diagnostics/src/diagnostics.rs +++ b/compiler/plc_diagnostics/src/diagnostics.rs @@ -1,6 +1,6 @@ use std::{error::Error, ops::Range}; -use plc_ast::ast::{AstStatement, DataTypeDeclaration, DiagnosticInfo, PouType}; +use plc_ast::ast::{AstNode, DataTypeDeclaration, DiagnosticInfo, PouType}; use plc_source::source_location::SourceLocation; use crate::errno::ErrNo; @@ -669,7 +669,7 @@ impl Diagnostic { } } - pub fn invalid_range_statement(entity: &AstStatement, range: SourceLocation) -> Diagnostic { + pub fn invalid_range_statement(entity: &AstNode, range: SourceLocation) -> Diagnostic { Diagnostic::SyntaxError { message: format!("Expected a range statement, got {entity:?} instead"), range: vec![range], diff --git a/compiler/plc_driver/src/pipelines.rs b/compiler/plc_driver/src/pipelines.rs index 8a627a40db..db987b8b79 100644 --- a/compiler/plc_driver/src/pipelines.rs +++ b/compiler/plc_driver/src/pipelines.rs @@ -182,7 +182,7 @@ impl IndexedProject { /// A project that has been annotated with information about different types and used units pub struct AnnotatedProject { - units: Vec<(CompilationUnit, IndexSet, StringLiterals)>, + pub units: Vec<(CompilationUnit, IndexSet, StringLiterals)>, index: Index, annotations: AstAnnotations, } diff --git a/compiler/plc_driver/src/runner.rs b/compiler/plc_driver/src/runner.rs index 799a399643..9850013cc9 100644 --- a/compiler/plc_driver/src/runner.rs +++ b/compiler/plc_driver/src/runner.rs @@ -38,6 +38,7 @@ pub fn compile(context: &CodegenContext, source: T) -> GeneratedM ..Default::default() }; + dbg!(&annotated_project.units[0].0); annotated_project.generate_single_module(context, &compile_options).unwrap().unwrap() } diff --git a/compiler/plc_xml/src/xml_parser.rs b/compiler/plc_xml/src/xml_parser.rs index b50064bd93..bfd3a9a932 100644 --- a/compiler/plc_xml/src/xml_parser.rs +++ b/compiler/plc_xml/src/xml_parser.rs @@ -1,5 +1,5 @@ use ast::{ - ast::{AstId, AstStatement, CompilationUnit, Implementation, LinkageType, PouType as AstPouType}, + ast::{AstId, AstNode, CompilationUnit, Implementation, LinkageType, PouType as AstPouType}, provider::IdProvider, }; use plc::{lexer, parser::expressions_parser::parse_expression}; @@ -119,14 +119,15 @@ impl<'parse> ParseSession<'parse> { )) } - fn parse_expression(&self, expr: &str, local_id: usize, execution_order: Option) -> AstStatement { - let exp = parse_expression(&mut lexer::lex_with_ids( + fn parse_expression(&self, expr: &str, local_id: usize, execution_order: Option) -> AstNode { + let mut exp = parse_expression(&mut lexer::lex_with_ids( html_escape::decode_html_entities_to_string(expr, &mut String::new()), self.id_provider.clone(), self.range_factory.clone(), )); let loc = exp.get_location(); - exp.set_location(self.range_factory.create_block_location(local_id, execution_order).span(&loc)) + exp.set_location(self.range_factory.create_block_location(local_id, execution_order).span(&loc)); + exp } fn parse_model(&self) -> Vec { diff --git a/compiler/plc_xml/src/xml_parser/action.rs b/compiler/plc_xml/src/xml_parser/action.rs index a10d7a6d2e..15ec334e73 100644 --- a/compiler/plc_xml/src/xml_parser/action.rs +++ b/compiler/plc_xml/src/xml_parser/action.rs @@ -1,11 +1,11 @@ -use ast::ast::{AstStatement, Implementation, PouType as AstPouType}; +use ast::ast::{AstNode, Implementation, PouType as AstPouType}; use crate::model::action::Action; use super::ParseSession; impl Action { - pub(crate) fn transform(&self, _session: &ParseSession) -> Vec { + pub(crate) fn transform(&self, _session: &ParseSession) -> Vec { todo!() } diff --git a/compiler/plc_xml/src/xml_parser/block.rs b/compiler/plc_xml/src/xml_parser/block.rs index d22eb4389a..02d212956a 100644 --- a/compiler/plc_xml/src/xml_parser/block.rs +++ b/compiler/plc_xml/src/xml_parser/block.rs @@ -1,11 +1,11 @@ -use ast::ast::{AstFactory, AstStatement}; +use ast::ast::{AstFactory, AstNode}; use crate::model::{block::Block, fbd::NodeIndex}; use super::ParseSession; impl Block { - pub(crate) fn transform(&self, session: &ParseSession, index: &NodeIndex) -> AstStatement { + pub(crate) fn transform(&self, session: &ParseSession, index: &NodeIndex) -> AstNode { let parameters = self .variables .iter() diff --git a/compiler/plc_xml/src/xml_parser/fbd.rs b/compiler/plc_xml/src/xml_parser/fbd.rs index 98745af836..a2e9cb9d43 100644 --- a/compiler/plc_xml/src/xml_parser/fbd.rs +++ b/compiler/plc_xml/src/xml_parser/fbd.rs @@ -1,4 +1,4 @@ -use ast::ast::AstStatement; +use ast::ast::{AstFactory, AstNode, AstStatement}; use indexmap::IndexMap; use crate::model::fbd::{FunctionBlockDiagram, Node, NodeId}; @@ -8,7 +8,7 @@ use super::ParseSession; impl FunctionBlockDiagram { /// Transforms the body of a function block diagram to their AST-equivalent, in order of execution. /// Only statements that are necessary for execution logic will be selected. - pub(crate) fn transform(&self, session: &ParseSession) -> Vec { + pub(crate) fn transform(&self, session: &ParseSession) -> Vec { let mut ast_association = IndexMap::new(); // transform each node to an ast-statement. since we might see and transform a node multiple times, we use an // ast-association map to keep track of the latest statement for each id @@ -34,8 +34,8 @@ impl FunctionBlockDiagram { &self, id: NodeId, session: &ParseSession, - ast_association: &IndexMap, - ) -> (AstStatement, Option) { + ast_association: &IndexMap, + ) -> (AstNode, Option) { let Some(current_node) = self.nodes.get(&id) else { unreachable!() }; match current_node { @@ -51,7 +51,7 @@ impl FunctionBlockDiagram { let (rhs, remove_id) = ast_association .get(&ref_id) .map(|stmt| { - if matches!(stmt, AstStatement::CallStatement { .. }) { + if matches!(stmt.get_stmt(), AstStatement::CallStatement(..)) { (stmt.clone(), Some(ref_id)) } else { self.transform_node(ref_id, session, ast_association) @@ -59,14 +59,7 @@ impl FunctionBlockDiagram { }) .expect("Expected AST statement, found None"); - ( - AstStatement::Assignment { - left: Box::new(lhs), - right: Box::new(rhs), - id: session.next_id(), - }, - remove_id, - ) + (AstFactory::create_assignment(lhs, rhs, session.next_id()), remove_id) } Node::Control(_) => todo!(), Node::Connector(_) => todo!(), diff --git a/compiler/plc_xml/src/xml_parser/pou.rs b/compiler/plc_xml/src/xml_parser/pou.rs index ed94e7f74f..aae016ba90 100644 --- a/compiler/plc_xml/src/xml_parser/pou.rs +++ b/compiler/plc_xml/src/xml_parser/pou.rs @@ -1,11 +1,11 @@ -use ast::ast::{AstStatement, Implementation}; +use ast::ast::{AstNode, Implementation}; use crate::model::pou::Pou; use super::ParseSession; impl Pou { - fn transform(&self, session: &ParseSession) -> Vec { + fn transform(&self, session: &ParseSession) -> Vec { let Some(fbd) = &self.body.function_block_diagram else { // empty body return vec![]; diff --git a/compiler/plc_xml/src/xml_parser/tests.rs b/compiler/plc_xml/src/xml_parser/tests.rs index dae613787c..151d90259e 100644 --- a/compiler/plc_xml/src/xml_parser/tests.rs +++ b/compiler/plc_xml/src/xml_parser/tests.rs @@ -1,5 +1,8 @@ use ast::{ - ast::{flatten_expression_list, AstStatement, CompilationUnit, LinkageType}, + ast::{ + flatten_expression_list, Assignment, AstNode, AstStatement, CallStatement, CompilationUnit, + LinkageType, + }, provider::IdProvider, }; use insta::assert_debug_snapshot; @@ -129,18 +132,18 @@ fn ast_generates_locations() { let (units, diagnostics) = xml_parser::parse(&source_code, LinkageType::Internal, IdProvider::default()); let impl1 = &units.implementations[0]; //Deconstruct assignment and get locations - let AstStatement::Assignment { left, right, .. } = &impl1.statements[0] else { - panic!("Not an assignment"); - }; + let AstStatement::Assignment (Assignment{ left, right, .. })= &impl1.statements[0].get_stmt() else { + panic!("Not an assignment"); + }; assert_debug_snapshot!(left.get_location()); assert_debug_snapshot!(right.get_location()); //Deconstruct call statement and get locations - let AstStatement::CallStatement { operator, parameters, location, .. } = &impl1.statements[1] else { - panic!("Not a call statement"); - }; + let AstNode { stmt: AstStatement::CallStatement (CallStatement{ operator, parameters, .. }), location, ..} = &impl1.statements[1] else { + panic!("Not a call statement"); + }; assert_debug_snapshot!(location); assert_debug_snapshot!(operator.get_location()); - let parameters = parameters.as_ref().as_ref().unwrap(); + let parameters = parameters.as_deref().unwrap(); let parameters = flatten_expression_list(parameters); for param in parameters { assert_debug_snapshot!(param.get_location()); diff --git a/compiler/plc_xml/src/xml_parser/variables.rs b/compiler/plc_xml/src/xml_parser/variables.rs index 5e3bf28944..7840f5cf1d 100644 --- a/compiler/plc_xml/src/xml_parser/variables.rs +++ b/compiler/plc_xml/src/xml_parser/variables.rs @@ -1,4 +1,4 @@ -use ast::ast::{AstStatement, Operator}; +use ast::ast::{AstFactory, AstNode, Operator}; use crate::model::{ fbd::{Node, NodeIndex}, @@ -8,7 +8,7 @@ use crate::model::{ use super::ParseSession; impl BlockVariable { - pub(crate) fn transform(&self, session: &ParseSession, index: &NodeIndex) -> Option { + pub(crate) fn transform(&self, session: &ParseSession, index: &NodeIndex) -> Option { let Some(ref_id) = &self.ref_local_id else { // param not provided/passed return None; @@ -27,16 +27,16 @@ impl BlockVariable { // variables, parameters -> more readable names? impl FunctionBlockVariable { - pub(crate) fn transform(&self, session: &ParseSession) -> AstStatement { + pub(crate) fn transform(&self, session: &ParseSession) -> AstNode { if self.negated { let ident = session.parse_expression(&self.expression, self.local_id, self.execution_order_id); - AstStatement::UnaryExpression { - operator: Operator::Not, - value: Box::new(ident), - location: session.create_block_location(self.local_id, self.execution_order_id), - id: session.next_id(), - } + AstFactory::create_unary_expression( + Operator::Not, + ident, + session.create_block_location(self.local_id, self.execution_order_id), + session.next_id(), + ) } else { session.parse_expression(&self.expression, self.local_id, self.execution_order_id) } diff --git a/libs/stdlib/tests/arithmetic_functions_tests.rs b/libs/stdlib/tests/arithmetic_functions_tests.rs index 8ee55d0b92..87411b1914 100644 --- a/libs/stdlib/tests/arithmetic_functions_tests.rs +++ b/libs/stdlib/tests/arithmetic_functions_tests.rs @@ -486,6 +486,7 @@ fn expt_called_with_operator() { a := 2**7; END_PROGRAM "#; + let sources = add_std!(src, "arithmetic_functions.st"); let mut maintype = MainType::::default(); let _: i32 = compile_and_run(sources, &mut maintype); diff --git a/src/builtins.rs b/src/builtins.rs index 1adde40253..5b9cce3854 100644 --- a/src/builtins.rs +++ b/src/builtins.rs @@ -8,7 +8,7 @@ use inkwell::{ use lazy_static::lazy_static; use plc_ast::{ ast::{ - self, flatten_expression_list, pre_process, AstStatement, CompilationUnit, GenericBinding, + self, flatten_expression_list, pre_process, AstNode, AstStatement, CompilationUnit, GenericBinding, LinkageType, TypeNature, }, literals::AstLiteral, @@ -328,7 +328,7 @@ lazy_static! { fn annotate_variable_length_array_bound_function( annotator: &mut TypeAnnotator, - parameters: Option<&AstStatement>, + parameters: Option<&AstNode>, ) { let Some(parameters) = parameters else { // caught during validation @@ -354,8 +354,8 @@ fn annotate_variable_length_array_bound_function( fn validate_variable_length_array_bound_function( validator: &mut Validator, - operator: &AstStatement, - parameters: &Option, + operator: &AstNode, + parameters: Option<&AstNode>, annotations: &dyn AnnotationMap, index: &Index, ) { @@ -388,7 +388,7 @@ fn validate_variable_length_array_bound_function( } // TODO: consider adding validation for consts and enums once https://github.com/PLC-lang/rusty/issues/847 has been implemented - if let AstStatement::Literal { kind: AstLiteral::Integer(dimension_idx), .. } = idx { + if let AstStatement::Literal(AstLiteral::Integer(dimension_idx)) = idx.get_stmt() { let dimension_idx = *dimension_idx as usize; let Some(n_dimensions) = @@ -414,7 +414,7 @@ fn validate_variable_length_array_bound_function( /// arguments are incorrect. fn generate_variable_length_array_bound_function<'ink>( generator: &ExpressionCodeGenerator<'ink, '_>, - params: &[&AstStatement], + params: &[&AstNode], is_lower: bool, location: SourceLocation, ) -> Result, Diagnostic> { @@ -435,9 +435,9 @@ fn generate_variable_length_array_bound_function<'ink>( let vla = generator.generate_lvalue(params[0]).unwrap(); let dim = builder.build_struct_gep(vla, 1, "dim").unwrap(); - let accessor = match params[1] { + let accessor = match params[1].get_stmt() { // e.g. LOWER_BOUND(arr, 1) - AstStatement::Literal { kind, .. } => { + AstStatement::Literal(kind) => { let AstLiteral::Integer(value) = kind else { let Some(type_name) = get_literal_actual_signed_type_name(kind, false) else { unreachable!("type cannot be VOID") @@ -452,8 +452,8 @@ fn generate_variable_length_array_bound_function<'ink>( let offset = if is_lower { (value - 1) as u64 * 2 } else { (value - 1) as u64 * 2 + 1 }; llvm.i32_type().const_int(offset, false) } - AstStatement::CastStatement { target, .. } => { - let ExpressionValue::RValue(value) = generator.generate_expression_value(target)? else { + AstStatement::CastStatement(data) => { + let ExpressionValue::RValue(value) = generator.generate_expression_value(&data.target)? else { unreachable!() }; @@ -497,15 +497,14 @@ fn generate_variable_length_array_bound_function<'ink>( Ok(ExpressionValue::RValue(bound)) } -type AnnotationFunction = fn(&mut TypeAnnotator, &AstStatement, Option<&AstStatement>, VisitorContext); +type AnnotationFunction = fn(&mut TypeAnnotator, &AstNode, Option<&AstNode>, VisitorContext); type GenericNameResolver = fn(&str, &[GenericBinding], &HashMap) -> String; type CodegenFunction = for<'ink, 'b> fn( &'b ExpressionCodeGenerator<'ink, 'b>, - &[&AstStatement], + &[&AstNode], SourceLocation, ) -> Result, Diagnostic>; -type ValidationFunction = - fn(&mut Validator, &AstStatement, &Option, &dyn AnnotationMap, &Index); +type ValidationFunction = fn(&mut Validator, &AstNode, Option<&AstNode>, &dyn AnnotationMap, &Index); pub struct BuiltIn { decl: &'static str, @@ -519,7 +518,7 @@ impl BuiltIn { pub fn codegen<'ink, 'b>( &self, generator: &'b ExpressionCodeGenerator<'ink, 'b>, - params: &[&AstStatement], + params: &[&AstNode], location: SourceLocation, ) -> Result, Diagnostic> { (self.code)(generator, params, location) diff --git a/src/codegen/generators/data_type_generator.rs b/src/codegen/generators/data_type_generator.rs index 63d18edc29..09a3786565 100644 --- a/src/codegen/generators/data_type_generator.rs +++ b/src/codegen/generators/data_type_generator.rs @@ -17,7 +17,7 @@ use inkwell::{ values::{BasicValue, BasicValueEnum}, AddressSpace, }; -use plc_ast::ast::AstStatement; +use plc_ast::ast::{AstNode, AstStatement}; use plc_ast::literals::AstLiteral; use plc_diagnostics::diagnostics::Diagnostic; use plc_diagnostics::errno::ErrNo; @@ -310,12 +310,12 @@ impl<'ink, 'b> DataTypeGenerator<'ink, 'b> { } DataTypeInformation::Array { .. } => self.generate_array_initializer( data_type, - |stmt| matches!(stmt, AstStatement::Literal { kind: AstLiteral::Array { .. }, .. }), + |stmt| matches!(stmt.stmt, AstStatement::Literal(AstLiteral::Array { .. })), "LiteralArray", ), DataTypeInformation::String { .. } => self.generate_array_initializer( data_type, - |stmt| matches!(stmt, AstStatement::Literal { kind: AstLiteral::String { .. }, .. }), + |stmt| matches!(stmt.stmt, AstStatement::Literal(AstLiteral::String { .. })), "LiteralString", ), DataTypeInformation::SubRange { referenced_type, .. } => { @@ -364,7 +364,7 @@ impl<'ink, 'b> DataTypeGenerator<'ink, 'b> { fn generate_initializer( &mut self, qualified_name: &str, - initializer: Option<&AstStatement>, + initializer: Option<&AstNode>, data_type_name: &str, ) -> Result>, Diagnostic> { if let Some(initializer) = initializer { @@ -389,7 +389,7 @@ impl<'ink, 'b> DataTypeGenerator<'ink, 'b> { fn generate_array_initializer( &self, data_type: &DataType, - predicate: fn(&AstStatement) -> bool, + predicate: fn(&AstNode) -> bool, expected_ast: &str, ) -> Result>, Diagnostic> { if let Some(initializer) = diff --git a/src/codegen/generators/expression_generator.rs b/src/codegen/generators/expression_generator.rs index 6f54ed8f65..00943b7aad 100644 --- a/src/codegen/generators/expression_generator.rs +++ b/src/codegen/generators/expression_generator.rs @@ -24,10 +24,13 @@ use inkwell::{ }, AddressSpace, FloatPredicate, IntPredicate, }; -use plc_ast::ast::{ - flatten_expression_list, AstFactory, AstStatement, DirectAccessType, Operator, ReferenceAccess, +use plc_ast::{ + ast::{ + flatten_expression_list, AstFactory, AstNode, AstStatement, DirectAccessType, Operator, + ReferenceAccess, ReferenceExpr, + }, + literals::AstLiteral, }; -use plc_ast::literals::AstLiteral; use plc_diagnostics::diagnostics::{Diagnostic, INTERNAL_LLVM_ERROR}; use plc_source::source_location::SourceLocation; use plc_util::convention::qualified_name; @@ -58,7 +61,7 @@ pub struct ExpressionCodeGenerator<'a, 'b> { #[derive(Debug)] struct CallParameterAssignment<'a, 'b> { /// the assignmentstatement in the call-argument list (a:=3) - assignment_statement: &'b AstStatement, + assignment_statement: &'b AstNode, /// the name of the function we're calling function_name: &'b str, /// the position of the argument in the POU's argument's list @@ -152,14 +155,14 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// returns the function context or returns a Compile-Error pub fn get_function_context( &self, - statement: &AstStatement, + statement: &AstNode, ) -> Result<&'b FunctionContext<'ink, 'b>, Diagnostic> { self.function_context.ok_or_else(|| Diagnostic::missing_function(statement.get_location())) } /// entry point into the expression generator. /// generates the given expression and returns the resulting BasicValueEnum - pub fn generate_expression(&self, expression: &AstStatement) -> Result, Diagnostic> { + pub fn generate_expression(&self, expression: &AstNode) -> Result, Diagnostic> { // If the expression was replaced by the resolver, generate the replacement if let Some(StatementAnnotation::ReplacementAst { statement }) = self.annotations.get(expression) { // we trust that the validator only passed us valid parameters (so left & right should be same type) @@ -178,7 +181,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { Ok(cast_if_needed!(self, target_type, actual_type, v, self.annotations.get(expression))) } - fn register_debug_location(&self, statement: &AstStatement) { + fn register_debug_location(&self, statement: &AstNode) { let function_context = self.function_context.expect("Cannot generate debug info without function context"); let line = statement.get_location().get_line(); @@ -188,7 +191,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pub fn generate_expression_value( &self, - expression: &AstStatement, + expression: &AstNode, ) -> Result, Diagnostic> { //see if this is a constant - maybe we can short curcuit this codegen if let Some(StatementAnnotation::Variable { @@ -201,9 +204,10 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } } // generate the expression - match expression { - AstStatement::ReferenceExpr { access, base, .. } => { - let res = self.generate_reference_expression(access, base.as_deref(), expression)?; + match expression.get_stmt() { + AstStatement::ReferenceExpr(data) => { + let res = + self.generate_reference_expression(&data.access, data.base.as_deref(), expression)?; let val = match res { ExpressionValue::LValue(val) => { ExpressionValue::LValue(self.auto_deref_if_necessary(val, expression)) @@ -220,14 +224,14 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { }; Ok(val) } - AstStatement::BinaryExpression { left, right, operator, .. } => self - .generate_binary_expression(left, right, operator, expression) + AstStatement::BinaryExpression(data) => self + .generate_binary_expression(&data.left, &data.right, &data.operator, expression) .map(ExpressionValue::RValue), - AstStatement::CallStatement { operator, parameters, .. } => { - self.generate_call_statement(operator, parameters) + AstStatement::CallStatement(data) => { + self.generate_call_statement(&data.operator, data.parameters.as_deref()) } - AstStatement::UnaryExpression { operator, value, .. } => { - self.generate_unary_expression(operator, value).map(ExpressionValue::RValue) + AstStatement::UnaryExpression(data) => { + self.generate_unary_expression(&data.operator, &data.value).map(ExpressionValue::RValue) } // TODO: Hardware access needs to be evaluated, see #648 AstStatement::HardwareAccess { .. } => { @@ -244,7 +248,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_constant_expression( &self, qualified_name: &str, - expression: &AstStatement, + expression: &AstNode, ) -> Result, Diagnostic> { let const_expression = self .index @@ -277,10 +281,10 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// - `expression` the whole expression for diagnostic reasons fn generate_binary_expression( &self, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, operator: &Operator, - expression: &AstStatement, + expression: &AstNode, ) -> Result, Diagnostic> { let l_type_hint = self.get_type_hint_for(left)?; let ltype = self.index.get_intrinsic_type_by_name(l_type_hint.get_name()).get_type_information(); @@ -314,7 +318,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pub fn generate_direct_access_index( &self, access: &DirectAccessType, - index: &AstStatement, + index: &AstNode, access_type: &DataTypeInformation, target_type: &DataType, ) -> Result, Diagnostic> { @@ -347,7 +351,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_unary_expression( &self, unary_operator: &Operator, - expression: &AstStatement, + expression: &AstNode, ) -> Result, Diagnostic> { let value = match unary_operator { Operator::Not => { @@ -397,8 +401,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// - `parameters` - an optional StatementList of parameters pub fn generate_call_statement( &self, - operator: &AstStatement, - parameters: &Option, + operator: &AstNode, + parameters: Option<&AstNode>, ) -> Result, Diagnostic> { // find the pou we're calling let pou = self.annotations.get_call_name(operator).zip(self.annotations.get_qualified_name(operator)) @@ -419,7 +423,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { .find_implementation(self.index) .ok_or_else(|| Diagnostic::cannot_generate_call_statement(operator))?; - let parameters_list = parameters.as_ref().map(flatten_expression_list).unwrap_or_default(); + let parameters_list = parameters.map(flatten_expression_list).unwrap_or_default(); let implementation_name = implementation.get_call_name(); // if the function is builtin, generate a basic value enum for it if let Some(builtin) = self.index.get_builtin_function(implementation_name) { @@ -507,7 +511,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { &self, parameter_struct: PointerValue<'ink>, function_name: &str, - parameters: Vec<&AstStatement>, + parameters: Vec<&AstNode>, ) -> Result<(), Diagnostic> { for (index, assignment_statement) in parameters.into_iter().enumerate() { self.assign_output_value(&CallParameterAssignment { @@ -521,14 +525,14 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } fn assign_output_value(&self, param_context: &CallParameterAssignment) -> Result<(), Diagnostic> { - match param_context.assignment_statement { - AstStatement::OutputAssignment { left, right, .. } - | AstStatement::Assignment { left, right, .. } => self.generate_explicit_output_assignment( - param_context.parameter_struct, - param_context.function_name, - left, - right, - ), + match param_context.assignment_statement.get_stmt() { + AstStatement::OutputAssignment(data) | AstStatement::Assignment(data) => self + .generate_explicit_output_assignment( + param_context.parameter_struct, + param_context.function_name, + &data.left, + &data.right, + ), _ => self.generate_output_assignment(param_context), } } @@ -541,7 +545,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { let index = param_context.index; if let Some(parameter) = self.index.get_declared_parameter(function_name, index) { if matches!(parameter.get_variable_type(), VariableType::Output) - && !matches!(expression, AstStatement::EmptyStatement { .. }) + && !matches!(expression.get_stmt(), AstStatement::EmptyStatement { .. }) { { let assigned_output = self.generate_lvalue(expression)?; @@ -586,8 +590,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { &self, parameter_struct: PointerValue<'ink>, function_name: &str, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, ) -> Result<(), Diagnostic> { if let Some(StatementAnnotation::Variable { qualified_name, .. }) = self.annotations.get(left) { let parameter = self @@ -611,9 +615,9 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_pou_call_arguments_list( &self, pou: &PouIndexEntry, - passed_parameters: &[&AstStatement], + passed_parameters: &[&AstNode], implementation: &ImplementationIndexEntry, - operator: &AstStatement, + operator: &AstNode, function_context: &'b FunctionContext<'ink, 'b>, ) -> Result>, Diagnostic> { let arguments_list = if matches!(pou, PouIndexEntry::Function { .. }) { @@ -631,7 +635,10 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } // TODO: find a more reliable way to make sure if this is a call into a local action!! PouIndexEntry::Action { .. } - if matches!(operator, AstStatement::ReferenceExpr { base: None, .. }) => + if matches!( + operator.get_stmt(), + AstStatement::ReferenceExpr(ReferenceExpr { base: None, .. }) + ) => { // special handling for local actions, get the parameter from the function context function_context @@ -660,7 +667,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_function_arguments( &self, pou: &PouIndexEntry, - passed_parameters: &[&AstStatement], + passed_parameters: &[&AstNode], declared_parameters: Vec<&VariableIndexEntry>, ) -> Result>, Diagnostic> { let mut result = Vec::new(); @@ -704,7 +711,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { self.generate_argument_by_ref(parameter, type_name, declared_parameter.copied())? } else { // by val - if !matches!(parameter, AstStatement::EmptyStatement { .. }) { + if !parameter.is_empty_statement() { self.generate_argument_by_val(type_name, parameter)? } else if let Some(param) = declared_parameters.get(i) { self.generate_empty_expression(param)? @@ -745,7 +752,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_argument_by_val( &self, type_name: &str, - param_statement: &AstStatement, + param_statement: &AstNode, ) -> Result, Diagnostic> { Ok(match self.index.find_effective_type_by_name(type_name) { Some(type_info) if type_info.information.is_string() => { @@ -760,7 +767,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_string_argument( &self, type_info: &DataType, - argument: &AstStatement, + argument: &AstNode, ) -> Result, Diagnostic> { // allocate a temporary string of correct size and pass it let llvm_type = self @@ -788,11 +795,11 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// pointing to the given `argument` fn generate_argument_by_ref( &self, - argument: &AstStatement, + argument: &AstNode, type_name: &str, declared_parameter: Option<&VariableIndexEntry>, ) -> Result, Diagnostic> { - if matches!(argument, AstStatement::EmptyStatement { .. }) { + if argument.is_empty_statement() { // Uninitialized var_output / var_in_out let v_type = self .llvm_index @@ -880,7 +887,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pub fn generate_variadic_arguments_list( &self, pou: &PouIndexEntry, - variadic_params: &[&AstStatement], + variadic_params: &[&AstNode], ) -> Result>, Diagnostic> { // get the real varargs from the index if let Some((var_args, argument_type)) = self @@ -963,7 +970,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn allocate_function_struct_instance( &self, function_name: &str, - context: &AstStatement, + context: &AstNode, ) -> Result, Diagnostic> { let instance_name = format!("{function_name}_instance"); // TODO: Naming convention (see plc_util/src/convention.rs) let function_type = self @@ -990,7 +997,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pou_name: &str, class_struct: Option>, parameter_struct: PointerValue<'ink>, - passed_parameters: &[&AstStatement], + passed_parameters: &[&AstNode], ) -> Result>, Diagnostic> { let mut result = class_struct .map(|class_struct| { @@ -1061,7 +1068,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// first try to find an initial value for the given id /// /// if there is none try to find an initial value for the given type - fn get_initial_value(&self, id: &Option, type_name: &str) -> Option<&AstStatement> { + fn get_initial_value(&self, id: &Option, type_name: &str) -> Option<&AstNode> { self.index.get_initial_value(id).or_else(|| self.index.get_initial_value_for_type(type_name)) } @@ -1076,11 +1083,10 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { &self, param_context: &CallParameterAssignment, ) -> Result>, Diagnostic> { - let parameter_value = match param_context.assignment_statement { + let parameter_value = match param_context.assignment_statement.get_stmt() { // explicit call parameter: foo(param := value) - AstStatement::OutputAssignment { left, right, .. } - | AstStatement::Assignment { left, right, .. } => { - self.generate_formal_parameter(param_context, left, right)?; + AstStatement::OutputAssignment(data) | AstStatement::Assignment(data) => { + self.generate_formal_parameter(param_context, &data.left, &data.right)?; None } // foo(x) @@ -1127,7 +1133,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { //this is VAR_IN_OUT assignemt, so don't load the value, assign the pointer //expression may be empty -> generate a local variable for it - let generated_exp = if matches!(expression, AstStatement::EmptyStatement { .. }) { + let generated_exp = if expression.is_empty_statement() { let temp_type = self.llvm_index.find_associated_type(inner_type_name).ok_or_else(|| { Diagnostic::unknown_type(parameter.get_name(), expression.get_location()) @@ -1153,8 +1159,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_formal_parameter( &self, param_context: &CallParameterAssignment, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, ) -> Result<(), Diagnostic> { let function_name = param_context.function_name; let parameter_struct = param_context.parameter_struct; @@ -1175,7 +1181,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { .unwrap_or_else(|| self.index.get_void_type().get_type_information()), DataTypeInformation::Pointer { auto_deref: true, .. } ); - if !matches!(right, AstStatement::EmptyStatement { .. }) || is_auto_deref { + if !right.is_empty_statement() || is_auto_deref { self.generate_call_struct_argument_assignment(&CallParameterAssignment { assignment_statement: right, function_name, @@ -1190,10 +1196,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// generates an gep-statement and returns the resulting pointer /// /// - `reference_statement` - the statement to get an lvalue from - pub fn generate_lvalue( - &self, - reference_statement: &AstStatement, - ) -> Result, Diagnostic> { + pub fn generate_lvalue(&self, reference_statement: &AstNode) -> Result, Diagnostic> { self.generate_expression_value(reference_statement).and_then(|it| { let v: Result = it.get_basic_value_enum().try_into(); v.map_err(|err| { @@ -1211,7 +1214,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { &self, qualifier: Option<&PointerValue<'ink>>, name: &str, - context: &AstStatement, + context: &AstNode, ) -> Result, Diagnostic> { let offset = &context.get_location(); if let Some(qualifier) = qualifier { @@ -1291,7 +1294,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn auto_deref_if_necessary( &self, accessor_ptr: PointerValue<'ink>, - statement: &AstStatement, + statement: &AstNode, ) -> PointerValue<'ink> { if let Some(StatementAnnotation::Variable { is_auto_deref: true, .. }) = self.annotations.get(statement) @@ -1310,7 +1313,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_access_for_dimension( &self, dimension: &Dimension, - access_expression: &AstStatement, + access_expression: &AstNode, ) -> Result, Diagnostic> { let start_offset = dimension .start_offset @@ -1347,8 +1350,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// - `access` the accessor expression (the expression between the brackets: reference[access]) fn generate_element_pointer_for_array( &self, - reference: &AstStatement, - access: &AstStatement, + reference: &AstNode, + access: &AstNode, ) -> Result, Diagnostic> { //Load the reference self.generate_expression_value(reference) @@ -1468,11 +1471,11 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pub fn create_llvm_binary_expression_for_pointer( &self, operator: &Operator, - left: &AstStatement, + left: &AstNode, left_type: &DataTypeInformation, - right: &AstStatement, + right: &AstNode, right_type: &DataTypeInformation, - expression: &AstStatement, + expression: &AstNode, ) -> Result, Diagnostic> { let left_expr = self.generate_expression(left)?; let right_expr = self.generate_expression(right)?; @@ -1705,7 +1708,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_numeric_literal( &self, - stmt: &AstStatement, + stmt: &AstNode, number: &str, ) -> Result, Diagnostic> { let type_hint = self.get_type_hint_for(stmt)?; @@ -1726,10 +1729,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// generates the literal statement and returns the resulting value /// /// - `literal_statement` one of LiteralBool, LiteralInteger, LiteralReal, LiteralString - pub fn generate_literal( - &self, - literal_statement: &AstStatement, - ) -> Result, Diagnostic> { + pub fn generate_literal(&self, literal_statement: &AstNode) -> Result, Diagnostic> { let cannot_generate_literal = || { Diagnostic::codegen_error( &format!("Cannot generate Literal for {literal_statement:?}"), @@ -1737,8 +1737,9 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { ) }; - match literal_statement { - AstStatement::Literal { kind, location, .. } => match kind { + let location = &literal_statement.get_location(); + match literal_statement.get_stmt() { + AstStatement::Literal(kind) => match kind { AstLiteral::Bool(b) => self.llvm.create_const_bool(*b).map(ExpressionValue::RValue), AstLiteral::Integer(i, ..) => self .generate_numeric_literal(literal_statement, i.to_string().as_str()) @@ -1784,7 +1785,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } // if there is just one assignment, this may be an struct-initialization (TODO this is not very elegant :-/ ) AstStatement::Assignment { .. } => self.generate_literal_struct(literal_statement), - AstStatement::CastStatement { target, .. } => self.generate_expression_value(target), + AstStatement::CastStatement(data) => self.generate_expression_value(&data.target), _ => Err(cannot_generate_literal()), } } @@ -1792,7 +1793,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// generates the string-literal `value` represented by `literal_statement` fn generate_string_literal( &self, - literal_statement: &AstStatement, + literal_statement: &AstNode, value: &str, location: &SourceLocation, ) -> Result, Diagnostic> { @@ -1870,10 +1871,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// - 1st try: fetch the type associated via the `self.annotations` /// - 2nd try: fetch the type associated with the given `default_type_name` /// - else return an `Err` - pub fn get_type_hint_info_for( - &self, - statement: &AstStatement, - ) -> Result<&DataTypeInformation, Diagnostic> { + pub fn get_type_hint_info_for(&self, statement: &AstNode) -> Result<&DataTypeInformation, Diagnostic> { self.get_type_hint_for(statement).map(DataType::get_type_information) } @@ -1881,7 +1879,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// - 1st try: fetch the type associated via the `self.annotations` /// - 2nd try: fetch the type associated with the given `default_type_name` /// - else return an `Err` - pub fn get_type_hint_for(&self, statement: &AstStatement) -> Result<&DataType, Diagnostic> { + pub fn get_type_hint_for(&self, statement: &AstNode) -> Result<&DataType, Diagnostic> { self.annotations .get_type_hint(statement, self.index) .or_else(|| self.annotations.get_type(statement, self.index)) @@ -1894,34 +1892,31 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } /// generates a struct literal value with the given value assignments (ExpressionList) - fn generate_literal_struct( - &self, - assignments: &AstStatement, - ) -> Result, Diagnostic> { + fn generate_literal_struct(&self, assignments: &AstNode) -> Result, Diagnostic> { if let DataTypeInformation::Struct { name: struct_name, members, .. } = self.get_type_hint_info_for(assignments)? { let mut uninitialized_members: HashSet<&VariableIndexEntry> = HashSet::from_iter(members); let mut member_values: Vec<(u32, BasicValueEnum<'ink>)> = Vec::new(); for assignment in flatten_expression_list(assignments) { - if let AstStatement::Assignment { left, right, .. } = assignment { + if let AstStatement::Assignment(data) = assignment.get_stmt() { if let Some(StatementAnnotation::Variable { qualified_name, .. }) = - self.annotations.get(left.as_ref()) + self.annotations.get(data.left.as_ref()) { let member: &VariableIndexEntry = self.index.find_fully_qualified_variable(qualified_name).ok_or_else(|| { - Diagnostic::unresolved_reference(qualified_name, left.get_location()) + Diagnostic::unresolved_reference(qualified_name, data.left.get_location()) })?; let index_in_parent = member.get_location_in_parent(); - let value = self.generate_expression(right)?; + let value = self.generate_expression(data.right.as_ref())?; uninitialized_members.remove(member); member_values.push((index_in_parent, value)); } else { return Err(Diagnostic::codegen_error( "struct member lvalue required as left operand of assignment", - left.get_location(), + data.left.get_location(), )); } } else { @@ -1977,10 +1972,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } /// generates an array literal with the given optional elements (represented as an ExpressionList) - pub fn generate_literal_array( - &self, - initializer: &AstStatement, - ) -> Result, Diagnostic> { + pub fn generate_literal_array(&self, initializer: &AstNode) -> Result, Diagnostic> { let array_value = self.generate_literal_array_value( initializer, self.get_type_hint_info_for(initializer)?, @@ -1996,7 +1988,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// i16-array-value fn generate_literal_array_value( &self, - elements: &AstStatement, + elements: &AstNode, data_type: &DataTypeInformation, location: &SourceLocation, ) -> Result, Diagnostic> { @@ -2023,8 +2015,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { // flatten_expression_list will will return a vec of only assignments let elements = if self.index.get_effective_type_or_void_by_name(inner_type.get_name()).information.is_struct() { - match elements { - AstStatement::ExpressionList { expressions, .. } => expressions.iter().collect(), + match elements.get_stmt() { + AstStatement::ExpressionList(expressions) => expressions.iter().collect(), _ => unreachable!("This should always be an expression list"), } } else { @@ -2082,8 +2074,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pub fn generate_bool_binary_expression( &self, operator: &Operator, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, ) -> Result, Diagnostic> { match operator { Operator::And | Operator::Or => { @@ -2133,8 +2125,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { pub fn generate_bool_short_circuit_expression( &self, operator: &Operator, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, ) -> Result, Diagnostic> { let builder = &self.llvm.builder; let lhs = to_i1(self.generate_expression(left)?.into_int_value(), builder); @@ -2187,9 +2179,9 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// expressions fn create_llvm_generic_binary_expression( &self, - left: &AstStatement, - right: &AstStatement, - binary_statement: &AstStatement, + left: &AstNode, + right: &AstNode, + binary_statement: &AstNode, ) -> Result, Diagnostic> { if let Some(StatementAnnotation::ReplacementAst { statement }) = self.annotations.get(binary_statement) @@ -2213,7 +2205,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { &self, left: inkwell::values::PointerValue, left_type: &DataTypeInformation, - right_statement: &AstStatement, + right_statement: &AstNode, ) -> Result<(), Diagnostic> { let right_type = self.annotations.get_type_or_void(right_statement, self.index).get_type_information(); @@ -2305,17 +2297,17 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { } /// returns an optional name used for a temporary variable when loading a pointer represented by `expression` - fn get_load_name(&self, expression: &AstStatement) -> Option { - match expression { - AstStatement::ReferenceExpr { access: ReferenceAccess::Deref, .. } - | AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), .. } => { + fn get_load_name(&self, expression: &AstNode) -> Option { + match expression.get_stmt() { + AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Deref, .. }) + | AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Index(_), .. }) => { Some("load_tmpVar".to_string()) } AstStatement::ReferenceExpr { .. } => expression .get_flat_reference_name() .map(|name| format!("{}{}{}", self.temp_variable_prefix, name, self.temp_variable_suffix)) .or_else(|| Some(self.temp_variable_prefix.clone())), - AstStatement::Identifier { name, .. } => Some(format!("{}{}", name, self.temp_variable_suffix)), + AstStatement::Identifier(name, ..) => Some(format!("{}{}", name, self.temp_variable_suffix)), _ => None, } } @@ -2325,7 +2317,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { &self, reference: ExpressionValue<'ink>, reference_annotation: &StatementAnnotation, - access: &AstStatement, + access: &AstNode, ) -> Result, ()> { let builder = &self.llvm.builder; @@ -2410,8 +2402,8 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { fn generate_reference_expression( &self, access: &ReferenceAccess, - base: Option<&AstStatement>, - original_expression: &AstStatement, + base: Option<&AstNode>, + original_expression: &AstNode, ) -> Result, Diagnostic> { match (access, base) { @@ -2419,11 +2411,11 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { (ReferenceAccess::Member(member), base) => { let base_value = base.map(|it| self.generate_expression_value(it)).transpose()?; - if let AstStatement::DirectAccess { access, index, .. } = member.as_ref() { + if let AstStatement::DirectAccess (data) = member.as_ref().get_stmt() { let (Some(base), Some(base_value)) = (base, base_value) else { return Err(Diagnostic::codegen_error("Cannot generate DirectAccess without base value.", original_expression.get_location())); }; - self.generate_direct_access_expression(base, &base_value, member, access, index) + self.generate_direct_access_expression(base, &base_value, member, &data.access, &data.index) } else { let member_name = member.get_flat_reference_name().unwrap_or("unknown"); self.create_llvm_pointer_value_for_reference( @@ -2454,7 +2446,7 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { // INT#target (INT = base) (ReferenceAccess::Cast(target), Some(_base)) => { - if matches!(target.as_ref(), AstStatement::Identifier { .. }) { + if target.as_ref().is_identifier() { let mr = AstFactory::create_member_reference(target.as_ref().clone(), None, target.get_id()); self.generate_expression_value(&mr) @@ -2497,11 +2489,11 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// - `access` the type of access (see `B` above) fn generate_direct_access_expression( &self, - qualifier: &AstStatement, + qualifier: &AstNode, qualifier_value: &ExpressionValue<'ink>, - member: &AstStatement, + member: &AstNode, access: &DirectAccessType, - index: &AstStatement, + index: &AstNode, ) -> Result, Diagnostic> { let loaded_base_value = qualifier_value.as_r_value(self.llvm, self.get_load_name(qualifier)); let datatype = self.get_type_hint_info_for(member)?; @@ -2532,21 +2524,21 @@ impl<'ink, 'b> ExpressionCodeGenerator<'ink, 'b> { /// as well as the parameter value (right side) ´param := value´ => ´value´ /// and `true` for implicit / `false` for explicit parameters pub fn get_implicit_call_parameter<'a>( - param_statement: &'a AstStatement, + param_statement: &'a AstNode, declared_parameters: &[&VariableIndexEntry], idx: usize, -) -> Result<(usize, &'a AstStatement, bool), Diagnostic> { - let (location, param_statement, is_implicit) = match param_statement { - AstStatement::Assignment { left, right, .. } | AstStatement::OutputAssignment { left, right, .. } => { +) -> Result<(usize, &'a AstNode, bool), Diagnostic> { + let (location, param_statement, is_implicit) = match param_statement.get_stmt() { + AstStatement::Assignment(data) | AstStatement::OutputAssignment(data) => { //explicit - let Some(left_name) = left.as_ref().get_flat_reference_name() else { + let Some(left_name) = data.left.as_ref().get_flat_reference_name() else { return Err(Diagnostic::reference_expected(param_statement.get_location())); }; let loc = declared_parameters .iter() .position(|p| p.get_name() == left_name) - .ok_or_else(|| Diagnostic::unresolved_reference(left_name, left.get_location()))?; - (loc, right.as_ref(), false) + .ok_or_else(|| Diagnostic::unresolved_reference(left_name, data.left.get_location()))?; + (loc, data.right.as_ref(), false) } _ => { //implicit diff --git a/src/codegen/generators/pou_generator.rs b/src/codegen/generators/pou_generator.rs index a8cb2f01fa..a59d86bb14 100644 --- a/src/codegen/generators/pou_generator.rs +++ b/src/codegen/generators/pou_generator.rs @@ -35,7 +35,7 @@ use inkwell::{ types::{BasicType, StructType}, values::PointerValue, }; -use plc_ast::ast::{AstStatement, Implementation, PouType}; +use plc_ast::ast::{AstNode, Implementation, PouType}; use plc_diagnostics::diagnostics::{Diagnostic, INTERNAL_LLVM_ERROR}; use plc_source::source_location::SourceLocation; @@ -616,7 +616,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { &self, variable: &&VariableIndexEntry, variable_to_initialize: PointerValue, - initializer_statement: Option<&AstStatement>, + initializer_statement: Option<&AstNode>, exp_gen: &ExpressionCodeGenerator, ) -> Result<(), Diagnostic> { let variable_llvm_type = self diff --git a/src/codegen/generators/statement_generator.rs b/src/codegen/generators/statement_generator.rs index b820cf344b..311697fb12 100644 --- a/src/codegen/generators/statement_generator.rs +++ b/src/codegen/generators/statement_generator.rs @@ -18,7 +18,9 @@ use inkwell::{ values::{BasicValueEnum, FunctionValue, PointerValue}, }; use plc_ast::{ - ast::{flatten_expression_list, AstFactory, AstStatement, Operator, ReferenceAccess}, + ast::{ + flatten_expression_list, AstFactory, AstNode, AstStatement, Operator, ReferenceAccess, ReferenceExpr, + }, control_statements::{AstControlStatement, ConditionalBlock}, }; use plc_diagnostics::diagnostics::{Diagnostic, INTERNAL_LLVM_ERROR}; @@ -91,7 +93,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { } /// generates a list of statements - pub fn generate_body(&self, statements: &[AstStatement]) -> Result<(), Diagnostic> { + pub fn generate_body(&self, statements: &[AstNode]) -> Result<(), Diagnostic> { for s in statements { self.generate_statement(s)?; } @@ -112,16 +114,16 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// genertes a single statement /// /// - `statement` the statement to be generated - pub fn generate_statement(&self, statement: &AstStatement) -> Result<(), Diagnostic> { - match statement { - AstStatement::EmptyStatement { .. } => { + pub fn generate_statement(&self, statement: &AstNode) -> Result<(), Diagnostic> { + match statement.get_stmt() { + AstStatement::EmptyStatement(..) => { //nothing to generate } - AstStatement::Assignment { left, right, .. } => { - self.generate_assignment_statement(left, right)?; + AstStatement::Assignment(data, ..) => { + self.generate_assignment_statement(&data.left, &data.right)?; } - AstStatement::ControlStatement { kind: ctl_statement, .. } => { + AstStatement::ControlStatement(ctl_statement, ..) => { self.generate_control_statement(ctl_statement)? } AstStatement::ReturnStatement { .. } => { @@ -129,7 +131,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { self.pou_generator.generate_return_statement(self.function_context, self.llvm_index)?; self.generate_buffer_block(); } - AstStatement::ExitStatement { location, .. } => { + AstStatement::ExitStatement(_) => { if let Some(exit_block) = &self.current_loop_exit { self.register_debug_location(statement); self.llvm.builder.build_unconditional_branch(*exit_block); @@ -137,18 +139,18 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { } else { return Err(Diagnostic::codegen_error( "Cannot break out of loop when not inside a loop", - location.clone(), + statement.get_location(), )); } } - AstStatement::ContinueStatement { location, .. } => { + AstStatement::ContinueStatement(_) => { if let Some(cont_block) = &self.current_loop_continue { self.llvm.builder.build_unconditional_branch(*cont_block); self.generate_buffer_block(); } else { return Err(Diagnostic::codegen_error( "Cannot continue loop when not inside a loop", - location.clone(), + statement.get_location(), )); } } @@ -190,8 +192,8 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// `right_statement` the right side of the assignment pub fn generate_assignment_statement( &self, - left_statement: &AstStatement, - right_statement: &AstStatement, + left_statement: &AstNode, + right_statement: &AstNode, ) -> Result<(), Diagnostic> { //Register any debug info for the store self.register_debug_location(left_statement); @@ -200,7 +202,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { return self.generate_direct_access_assignment(left_statement, right_statement); } //TODO: Also hacky but for now we cannot generate assignments for hardware access - if matches!(left_statement, AstStatement::HardwareAccess { .. }) { + if matches!(left_statement.get_stmt(), AstStatement::HardwareAccess { .. }) { return Ok(()); } let exp_gen = self.create_expr_generator(); @@ -226,7 +228,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { Ok(()) } - fn register_debug_location(&self, statement: &AstStatement) { + fn register_debug_location(&self, statement: &AstNode) { let line = statement.get_location().get_line(); let column = statement.get_location().get_column(); self.debug.set_debug_location(self.llvm, &self.function_context.function, line, column); @@ -234,8 +236,8 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { fn generate_direct_access_assignment( &self, - left_statement: &AstStatement, - right_statement: &AstStatement, + left_statement: &AstNode, + right_statement: &AstNode, ) -> Result<(), Diagnostic> { //TODO : Validation let exp_gen = self.create_expr_generator(); @@ -266,10 +268,10 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { let left = left_expression_value.get_basic_value_enum().into_pointer_value(); //Build index if let Some((element, direct_access)) = access_sequence.split_first() { - let mut rhs = if let AstStatement::DirectAccess { access, index, .. } = element { + let mut rhs = if let AstStatement::DirectAccess(data, ..) = element.get_stmt() { exp_gen.generate_direct_access_index( - access, - index, + &data.access, + &data.index, right_type.get_type_information(), left_type, ) @@ -280,10 +282,10 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { )) }?; for element in direct_access { - let rhs_next = if let AstStatement::DirectAccess { access, index, .. } = element { + let rhs_next = if let AstStatement::DirectAccess(data, ..) = element.get_stmt() { exp_gen.generate_direct_access_index( - access, - index, + &data.access, + &data.index, right_type.get_type_information(), left_type, ) @@ -335,11 +337,11 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// - `body` the statements inside the for-loop fn generate_for_statement( &self, - counter: &AstStatement, - start: &AstStatement, - end: &AstStatement, - by_step: &Option>, - body: &[AstStatement], + counter: &AstNode, + start: &AstNode, + end: &AstNode, + by_step: &Option>, + body: &[AstNode], ) -> Result<(), Diagnostic> { let (builder, current_function, context) = self.get_llvm_deps(); self.generate_assignment_statement(counter, start)?; @@ -412,53 +414,42 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { fn generate_compare_expression( &'a self, - counter: &AstStatement, - end: &AstStatement, - start: &AstStatement, + counter: &AstNode, + end: &AstNode, + start: &AstNode, exp_gen: &'a ExpressionCodeGenerator, ) -> Result, Diagnostic> { - let counter_end_ge = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::GreaterOrEqual, - left: Box::new(counter.to_owned()), - right: Box::new(end.to_owned()), - }; - let counter_start_ge = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::GreaterOrEqual, - left: Box::new(counter.to_owned()), - right: Box::new(start.to_owned()), - }; - let counter_end_le = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::LessOrEqual, - left: Box::new(counter.to_owned()), - right: Box::new(end.to_owned()), - }; - let counter_start_le = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::LessOrEqual, - left: Box::new(counter.to_owned()), - right: Box::new(start.to_owned()), - }; - let and_1 = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::And, - left: Box::new(counter_end_le), - right: Box::new(counter_start_ge), - }; - let and_2 = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::And, - left: Box::new(counter_end_ge), - right: Box::new(counter_start_le), - }; - let or = AstStatement::BinaryExpression { - id: self.annotations.get_bool_id(), - operator: Operator::Or, - left: Box::new(and_1), - right: Box::new(and_2), - }; + let bool_id = self.annotations.get_bool_id(); + let counter_end_ge = AstFactory::create_binary_expression( + counter.clone(), + Operator::GreaterOrEqual, + end.clone(), + bool_id, + ); + let counter_start_ge = AstFactory::create_binary_expression( + counter.clone(), + Operator::GreaterOrEqual, + start.clone(), + bool_id, + ); + let counter_end_le = AstFactory::create_binary_expression( + counter.clone(), + Operator::LessOrEqual, + end.clone(), + bool_id, + ); + let counter_start_le = AstFactory::create_binary_expression( + counter.clone(), + Operator::LessOrEqual, + start.clone(), + bool_id, + ); + let and_1 = + AstFactory::create_binary_expression(counter_end_le, Operator::And, counter_start_ge, bool_id); + let and_2 = + AstFactory::create_binary_expression(counter_end_ge, Operator::And, counter_start_le, bool_id); + let or = AstFactory::create_binary_expression(and_1, Operator::Or, and_2, bool_id); + self.register_debug_location(&or); let or_eval = exp_gen.generate_expression(&or)?; Ok(or_eval) @@ -476,9 +467,9 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// - `else_body` the statements in the else-block fn generate_case_statement( &self, - selector: &AstStatement, + selector: &AstNode, conditional_blocks: &[ConditionalBlock], - else_body: &[AstStatement], + else_body: &[AstNode], ) -> Result<(), Diagnostic> { let (builder, current_function, context) = self.get_llvm_deps(); //Continue @@ -500,14 +491,14 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { //flatten the expression list into a vector of expressions let expressions = flatten_expression_list(&conditional_block.condition); for s in expressions { - if let AstStatement::RangeStatement { start, end, .. } = s { + if let AstStatement::RangeStatement(data, ..) = s.get_stmt() { //if this is a range statement, we generate an if (x >= start && x <= end) then the else-section builder.position_at_end(current_else_block); // since the if's generate additional blocks, we use the last one as the else-section current_else_block = self.generate_case_range_condition( selector, - start.as_ref(), - end.as_ref(), + data.start.as_ref(), + data.end.as_ref(), case_block, )?; } else { @@ -543,9 +534,9 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// fn generate_case_range_condition( &self, - selector: &AstStatement, - start: &AstStatement, - end: &AstStatement, + selector: &AstNode, + start: &AstNode, + end: &AstNode, match_block: BasicBlock, ) -> Result { let (builder, _, context) = self.get_llvm_deps(); @@ -592,11 +583,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// /// - `condition` the while's condition /// - `body` the while's body statements - fn generate_while_statement( - &self, - condition: &AstStatement, - body: &[AstStatement], - ) -> Result<(), Diagnostic> { + fn generate_while_statement(&self, condition: &AstNode, body: &[AstNode]) -> Result<(), Diagnostic> { let builder = &self.llvm.builder; let basic_block = builder.get_insert_block().expect(INTERNAL_LLVM_ERROR); let (condition_block, _) = self.generate_base_while_statement(condition, body)?; @@ -619,11 +606,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// /// - `condition` the repeat's condition /// - `body` the repeat's body statements - fn generate_repeat_statement( - &self, - condition: &AstStatement, - body: &[AstStatement], - ) -> Result<(), Diagnostic> { + fn generate_repeat_statement(&self, condition: &AstNode, body: &[AstNode]) -> Result<(), Diagnostic> { let builder = &self.llvm.builder; let basic_block = builder.get_insert_block().expect(INTERNAL_LLVM_ERROR); @@ -643,8 +626,8 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// utility method for while and repeat loops fn generate_base_while_statement( &self, - condition: &AstStatement, - body: &[AstStatement], + condition: &AstNode, + body: &[AstNode], ) -> Result<(BasicBlock, BasicBlock), Diagnostic> { let (builder, current_function, context) = self.get_llvm_deps(); let condition_check = context.append_basic_block(current_function, "condition_check"); @@ -686,7 +669,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { fn generate_if_statement( &self, conditional_blocks: &[ConditionalBlock], - else_body: &[AstStatement], + else_body: &[AstNode], ) -> Result<(), Diagnostic> { let (builder, current_function, context) = self.get_llvm_deps(); let mut blocks = vec![builder.get_insert_block().expect(INTERNAL_LLVM_ERROR)]; @@ -749,12 +732,14 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { /// we want to deconstruct the sequence into the base-statement (a.b.c) and the sequence /// of direct-access commands (vec![%W3, %X2]) fn collect_base_and_direct_access_for_assignment( - left_statement: &AstStatement, -) -> Option<(&AstStatement, Vec<&AstStatement>)> { + left_statement: &AstNode, +) -> Option<(&AstNode, Vec<&AstNode>)> { let mut current = Some(left_statement); let mut access_sequence = Vec::new(); - while let Some(AstStatement::ReferenceExpr { access: ReferenceAccess::Member(m), base, .. }) = current { - if matches!(m.as_ref(), AstStatement::DirectAccess { .. }) { + while let Some(AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(m), base })) = + current.map(|it| it.get_stmt()) + { + if matches!(m.get_stmt(), AstStatement::DirectAccess { .. }) { access_sequence.insert(0, m.as_ref()); current = base.as_deref(); } else { diff --git a/src/index.rs b/src/index.rs index 8af273423b..df9b17b1ae 100644 --- a/src/index.rs +++ b/src/index.rs @@ -7,7 +7,8 @@ use crate::{ use indexmap::IndexMap; use itertools::Itertools; use plc_ast::ast::{ - AstStatement, DirectAccessType, GenericBinding, HardwareAccessType, LinkageType, PouType, TypeNature, + AstNode, AstStatement, DirectAccessType, GenericBinding, HardwareAccessType, LinkageType, PouType, + TypeNature, }; use plc_diagnostics::diagnostics::Diagnostic; use plc_source::source_location::SourceLocation; @@ -65,12 +66,13 @@ pub struct HardwareBinding { } impl HardwareBinding { - fn from_statement(index: &mut Index, it: &AstStatement, scope: Option) -> Option { - if let AstStatement::HardwareAccess { access, address, direction, location, .. } = it { + fn from_statement(index: &mut Index, it: &AstNode, scope: Option) -> Option { + if let AstStatement::HardwareAccess(data) = it.get_stmt() { Some(HardwareBinding { - access: *access, - direction: *direction, - entries: address + access: data.access, + direction: data.direction, + entries: data + .address .iter() .map(|expr| { index.constant_expressions.add_constant_expression( @@ -80,7 +82,7 @@ impl HardwareBinding { ) }) .collect(), - location: location.clone(), + location: it.get_location(), }) } else { None @@ -1217,7 +1219,7 @@ impl Index { self.get_types().get(&type_name.to_lowercase()).unwrap_or_else(|| panic!("{type_name} not found")) } - pub fn get_initial_value(&self, id: &Option) -> Option<&AstStatement> { + pub fn get_initial_value(&self, id: &Option) -> Option<&AstNode> { self.get_const_expressions().maybe_get_constant_statement(id) } @@ -1235,7 +1237,7 @@ impl Index { /// Returns the initioal value registered for the given data_type. /// If the given dataType has no initial value AND it is an Alias or SubRange (referencing another type) /// this method tries to obtain the default value from the referenced type. - pub fn get_initial_value_for_type(&self, type_name: &str) -> Option<&AstStatement> { + pub fn get_initial_value_for_type(&self, type_name: &str) -> Option<&AstNode> { let mut dt = self.type_index.find_type(type_name); let mut initial_value = dt.and_then(|it| it.initial_value); diff --git a/src/index/const_expressions.rs b/src/index/const_expressions.rs index 5008206119..b1f72f4e76 100644 --- a/src/index/const_expressions.rs +++ b/src/index/const_expressions.rs @@ -1,7 +1,10 @@ // Copyright (c) 2020 Ghaith Hachem and Mathias Rieder use generational_arena::{Arena, Iter}; -use plc_ast::{ast::AstStatement, literals::AstLiteral}; +use plc_ast::{ + ast::{AstNode, AstStatement}, + literals::AstLiteral, +}; use plc_source::source_location::SourceLocation; pub type ConstId = generational_arena::Index; @@ -18,7 +21,7 @@ struct ConstWrapper { } impl ConstWrapper { - pub fn get_statement(&self) -> &AstStatement { + pub fn get_statement(&self) -> &AstNode { self.expr.get_statement() } } @@ -29,22 +32,22 @@ impl ConstWrapper { #[derive(Debug)] pub enum ConstExpression { Unresolved { - statement: AstStatement, + statement: AstNode, /// optional qualifier used when evaluating this expression /// e.g. a const-expression inside a POU would use this POU's name as a /// qualifier. scope: Option, }, - Resolved(AstStatement), + Resolved(AstNode), Unresolvable { - statement: AstStatement, + statement: AstNode, reason: UnresolvableKind, }, } impl ConstExpression { /// returns the const-expression represented as an AST-element - pub fn get_statement(&self) -> &AstStatement { + pub fn get_statement(&self) -> &AstNode { match &self { ConstExpression::Unresolved { statement, .. } | ConstExpression::Resolved(statement) @@ -66,7 +69,7 @@ impl ConstExpression { } pub(crate) fn is_default(&self) -> bool { - matches!(self.get_statement(), AstStatement::DefaultValue { .. }) + self.get_statement().is_default_value() } } @@ -111,7 +114,7 @@ impl ConstExpressions { /// - `scope`: the scope this expression needs to be resolved in (e.g. a POU's name) pub fn add_expression( &mut self, - statement: AstStatement, + statement: AstNode, target_type_name: String, scope: Option, ) -> ConstId { @@ -121,7 +124,7 @@ impl ConstExpressions { /// returns the expression associated with the given `id` together with an optional /// `qualifier` that represents the expressions scope (e.g. the host's POU-name) - pub fn find_expression(&self, id: &ConstId) -> (Option<&AstStatement>, Option<&str>) { + pub fn find_expression(&self, id: &ConstId) -> (Option<&AstNode>, Option<&str>) { self.expressions .get(*id) .filter(|it| !it.expr.is_default()) @@ -141,7 +144,7 @@ impl ConstExpressions { } /// clones the expression in the ConstExpressions and returns all of its elements - pub fn clone(&self, id: &ConstId) -> Option<(AstStatement, String, Option)> { + pub fn clone(&self, id: &ConstId) -> Option<(AstNode, String, Option)> { self.expressions.get(*id).map(|it| match &it.expr { ConstExpression::Unresolved { statement, scope } => { (statement.clone(), it.target_type_name.clone(), scope.clone()) @@ -154,7 +157,7 @@ impl ConstExpressions { /// marks the const-expression represented by the given `id` as resolvend and stores the the /// given `new_statement` as it's resolved value. - pub fn mark_resolved(&mut self, id: &ConstId, new_statement: AstStatement) -> Result<(), String> { + pub fn mark_resolved(&mut self, id: &ConstId, new_statement: AstNode) -> Result<(), String> { let wrapper = self .expressions .get_mut(*id) @@ -183,7 +186,7 @@ impl ConstExpressions { /// - `scope`: the scope this expression needs to be resolved in (e.g. a POU's name) pub fn add_constant_expression( &mut self, - expr: AstStatement, + expr: AstNode, target_type: String, scope: Option, ) -> ConstId { @@ -195,7 +198,7 @@ impl ConstExpressions { /// otherwhise use `add_constant_expression` pub fn maybe_add_constant_expression( &mut self, - expr: Option, + expr: Option, target_type_name: &str, scope: Option, ) -> Option { @@ -206,19 +209,19 @@ impl ConstExpressions { /// if the given `id` is `None`, this method returns `None` /// use this only as a shortcut if you have an Option - e.g. an optional initializer. /// otherwhise use `get_constant_expression` - pub fn maybe_get_constant_statement(&self, id: &Option) -> Option<&AstStatement> { + pub fn maybe_get_constant_statement(&self, id: &Option) -> Option<&AstNode> { id.as_ref().and_then(|it| self.get_constant_statement(it)) } /// query the constants arena for an expression associated with the given `id` - pub fn get_constant_statement(&self, id: &ConstId) -> Option<&AstStatement> { + pub fn get_constant_statement(&self, id: &ConstId) -> Option<&AstNode> { self.find_expression(id).0 } /// query the constants arena for a resolved expression associated with the given `id`. /// this operation returns None, if an unresolved/unresolvable expression was registered /// for the given id (for different behavior see `get_constant_statement`) - pub fn get_resolved_constant_statement(&self, id: &ConstId) -> Option<&AstStatement> { + pub fn get_resolved_constant_statement(&self, id: &ConstId) -> Option<&AstNode> { self.find_const_expression(id).filter(|it| it.is_resolved()).map(ConstExpression::get_statement) } @@ -227,8 +230,8 @@ impl ConstExpressions { /// complex one (not a LiteralInteger) pub fn get_constant_int_statement_value(&self, id: &ConstId) -> Result { self.get_constant_statement(id).ok_or_else(|| "Cannot find constant expression".into()).and_then( - |it| match it { - AstStatement::Literal { kind: AstLiteral::Integer(i), .. } => Ok(*i), + |it| match it.get_stmt() { + AstStatement::Literal(AstLiteral::Integer(i)) => Ok(*i), _ => Err(format!("Cannot extract int constant from {it:#?}")), }, ) @@ -240,7 +243,7 @@ impl ConstExpressions { } impl<'a> IntoIterator for &'a ConstExpressions { - type Item = (ConstId, &'a AstStatement); + type Item = (ConstId, &'a AstNode); type IntoIter = IntoStatementIter<'a>; fn into_iter(self) -> Self::IntoIter { @@ -253,7 +256,7 @@ pub struct IntoStatementIter<'a> { } impl<'a> Iterator for IntoStatementIter<'a> { - type Item = (ConstId, &'a AstStatement); + type Item = (ConstId, &'a AstNode); fn next(&mut self) -> Option { self.inner.next().map(|(idx, expr)| (idx, expr.get_statement())) diff --git a/src/index/tests/index_tests.rs b/src/index/tests/index_tests.rs index 372317d663..f1cfbdbd96 100644 --- a/src/index/tests/index_tests.rs +++ b/src/index/tests/index_tests.rs @@ -1,8 +1,7 @@ // Copyright (c) 2020 Ghaith Hachem and Mathias Rieder use insta::assert_debug_snapshot; use plc_ast::ast::{ - pre_process, AstStatement, DataType, GenericBinding, LinkageType, Operator, TypeNature, - UserTypeDeclaration, + pre_process, AstFactory, DataType, GenericBinding, LinkageType, Operator, TypeNature, UserTypeDeclaration, }; use plc_ast::provider::IdProvider; use plc_source::source_location::{SourceLocation, SourceLocationFactory}; @@ -1089,12 +1088,12 @@ fn array_dimensions_are_stored_in_the_const_expression_arena() { assert_eq!( format!( "{:#?}", - AstStatement::BinaryExpression { - id: 0, - operator: Operator::Minus, - left: Box::new(crate::parser::tests::ref_to("LEN")), - right: Box::new(crate::parser::tests::literal_int(1)) - } + AstFactory::create_binary_expression( + crate::parser::tests::ref_to("LEN"), + Operator::Minus, + crate::parser::tests::literal_int(1), + 0 + ) ), format!("{end_0:#?}") ); @@ -1151,12 +1150,12 @@ fn string_dimensions_are_stored_in_the_const_expression_arena() { assert_eq!( format!( "{:#?}", - &AstStatement::BinaryExpression { - id: actual_len_expression.get_id(), - left: Box::new(actual_len_expression.clone()), - operator: Operator::Plus, - right: Box::new(crate::parser::tests::literal_int(1)) - } + AstFactory::create_binary_expression( + actual_len_expression.clone(), + Operator::Plus, + literal_int(1), + actual_len_expression.get_id() + ) ), format!("{:#?}", index.get_const_expressions().get_constant_statement(expr).unwrap()) ); diff --git a/src/index/visitor.rs b/src/index/visitor.rs index 6db231ef86..bc34fc3ca9 100644 --- a/src/index/visitor.rs +++ b/src/index/visitor.rs @@ -3,8 +3,9 @@ use super::{HardwareBinding, PouIndexEntry, VariableIndexEntry, VariableType}; use crate::index::{ArgumentType, Index, MemberInfo}; use crate::typesystem::{self, *}; use plc_ast::ast::{ - self, ArgumentProperty, AstStatement, CompilationUnit, DataType, DataTypeDeclaration, Implementation, - Pou, PouType, TypeNature, UserTypeDeclaration, Variable, VariableBlock, VariableBlockType, + self, ArgumentProperty, Assignment, AstFactory, AstNode, AstStatement, CompilationUnit, DataType, + DataTypeDeclaration, Implementation, Pou, PouType, RangeStatement, TypeNature, UserTypeDeclaration, + Variable, VariableBlock, VariableBlockType, }; use plc_ast::literals::AstLiteral; use plc_diagnostics::diagnostics::Diagnostic; @@ -337,7 +338,7 @@ fn visit_data_type(index: &mut Index, type_declaration: &UserTypeDeclaration) { for ele in ast::flatten_expression_list(elements) { let element_name = ast::get_enum_element_name(ele); - if let AstStatement::Assignment { right, .. } = ele { + if let AstStatement::Assignment(Assignment { right, .. }) = ele.get_stmt() { let init = index.get_mut_const_expressions().add_constant_expression( right.as_ref().clone(), numeric_type.clone(), @@ -364,7 +365,9 @@ fn visit_data_type(index: &mut Index, type_declaration: &UserTypeDeclaration) { } DataType::SubRangeType { name: Some(name), referenced_type, bounds } => { - let information = if let Some(AstStatement::RangeStatement { start, end, .. }) = bounds { + let information = if let Some(AstStatement::RangeStatement(RangeStatement { start, end })) = + bounds.as_ref().map(|it| it.get_stmt()) + { DataTypeInformation::SubRange { name: name.into(), referenced_type: referenced_type.into(), @@ -421,21 +424,21 @@ fn visit_data_type(index: &mut Index, type_declaration: &UserTypeDeclaration) { let encoding = if *is_wide { StringEncoding::Utf16 } else { StringEncoding::Utf8 }; let size = match size { - Some(AstStatement::Literal { kind: AstLiteral::Integer(value), .. }) => { + Some(AstNode { stmt: AstStatement::Literal(AstLiteral::Integer(value)), .. }) => { TypeSize::from_literal((value + 1) as i64) } Some(statement) => { // construct a "x + 1" expression because we need one additional character for \0 terminator - let len_plus_1 = AstStatement::BinaryExpression { - id: statement.get_id(), - left: Box::new(statement.clone()), - operator: ast::Operator::Plus, - right: Box::new(AstStatement::new_literal( + let len_plus_1 = AstFactory::create_binary_expression( + statement.clone(), + ast::Operator::Plus, + AstNode::new_literal( AstLiteral::new_integer(1), statement.get_id(), statement.get_location(), - )), - }; + ), + statement.get_id(), + ); TypeSize::from_expression(index.get_mut_const_expressions().add_constant_expression( len_plus_1, @@ -504,15 +507,15 @@ fn visit_data_type(index: &mut Index, type_declaration: &UserTypeDeclaration) { /// END_STRUCT /// ``` fn visit_variable_length_array( - bounds: &AstStatement, + bounds: &AstNode, referenced_type: &DataTypeDeclaration, name: &str, index: &mut Index, type_declaration: &UserTypeDeclaration, ) { - let ndims = match bounds { - AstStatement::VlaRangeStatement { .. } => 1, - AstStatement::ExpressionList { expressions, .. } => expressions.len(), + let ndims = match bounds.get_stmt() { + AstStatement::VlaRangeStatement => 1, + AstStatement::ExpressionList(expressions) => expressions.len(), _ => unreachable!("not a bounds statement"), }; @@ -574,24 +577,29 @@ fn visit_variable_length_array( DataTypeDeclaration::DataTypeDefinition { data_type: DataType::ArrayType { name: Some(member_dimensions_name), - bounds: AstStatement::ExpressionList { - expressions: (0..ndims) - .map(|_| AstStatement::RangeStatement { - start: Box::new(AstStatement::new_literal( - AstLiteral::new_integer(0), - 0, - SourceLocation::undefined(), - )), - end: Box::new(AstStatement::new_literal( - AstLiteral::new_integer(1), - 0, - SourceLocation::undefined(), - )), - id: 0, - }) - .collect::<_>(), - id: 0, - }, + bounds: AstNode::new( + AstStatement::ExpressionList( + (0..ndims) + .map(|_| { + AstFactory::create_range_statement( + AstNode::new_literal( + AstLiteral::new_integer(0), + 0, + SourceLocation::undefined(), + ), + AstNode::new_literal( + AstLiteral::new_integer(1), + 0, + SourceLocation::undefined(), + ), + 0, + ) + }) + .collect::<_>(), + ), + 0, + SourceLocation::undefined(), + ), referenced_type: Box::new(DataTypeDeclaration::DataTypeReference { referenced_type: DINT_TYPE.to_string(), location: SourceLocation::undefined(), @@ -644,7 +652,7 @@ fn visit_variable_length_array( } fn visit_array( - bounds: &AstStatement, + bounds: &AstNode, index: &mut Index, scope: &Option, referenced_type: &DataTypeDeclaration, @@ -654,8 +662,8 @@ fn visit_array( let dimensions: Result, Diagnostic> = bounds .get_as_list() .iter() - .map(|it| match it { - AstStatement::RangeStatement { start, end, .. } => { + .map(|it| match it.get_stmt() { + AstStatement::RangeStatement(RangeStatement { start, end }) => { let constants = index.get_mut_const_expressions(); Ok(Dimension { start_offset: TypeSize::from_expression(constants.add_constant_expression( diff --git a/src/lexer.rs b/src/lexer.rs index aa807cea5e..60e1e7d5eb 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -21,7 +21,7 @@ pub struct ParseSession<'a> { /// the range of the `last_token` pub last_range: Range, pub parse_progress: usize, - id_provider: IdProvider, + pub id_provider: IdProvider, pub source_range_factory: SourceLocationFactory, pub scope: Option, } diff --git a/src/parser.rs b/src/parser.rs index c893650021..cc93f01cda 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -4,10 +4,10 @@ use std::ops::Range; use plc_ast::{ ast::{ - AccessModifier, ArgumentProperty, AstStatement, CompilationUnit, DataType, DataTypeDeclaration, - DirectAccessType, GenericBinding, HardwareAccessType, Implementation, LinkageType, PolymorphismMode, - Pou, PouType, ReferenceAccess, TypeNature, UserTypeDeclaration, Variable, VariableBlock, - VariableBlockType, + AccessModifier, ArgumentProperty, AstFactory, AstNode, AstStatement, CompilationUnit, DataType, + DataTypeDeclaration, DirectAccessType, GenericBinding, HardwareAccessType, Implementation, + LinkageType, PolymorphismMode, Pou, PouType, ReferenceAccess, ReferenceExpr, TypeNature, + UserTypeDeclaration, Variable, VariableBlock, VariableBlockType, }, provider::IdProvider, }; @@ -568,7 +568,7 @@ fn parse_type(lexer: &mut ParseSession) -> Vec { }) } -type DataTypeWithInitializer = (DataTypeDeclaration, Option); +type DataTypeWithInitializer = (DataTypeDeclaration, Option); fn parse_full_data_type_definition( lexer: &mut ParseSession, @@ -661,7 +661,7 @@ fn parse_pointer_definition( lexer: &mut ParseSession, name: Option, start_pos: usize, -) -> Option<(DataTypeDeclaration, Option)> { +) -> Option<(DataTypeDeclaration, Option)> { parse_data_type_definition(lexer, None).map(|(decl, initializer)| { ( DataTypeDeclaration::DataTypeDefinition { @@ -677,7 +677,7 @@ fn parse_pointer_definition( fn parse_type_reference_type_definition( lexer: &mut ParseSession, name: Option, -) -> Option<(DataTypeDeclaration, Option)> { +) -> Option<(DataTypeDeclaration, Option)> { let start = lexer.range().start; //Subrange let referenced_type = lexer.slice_and_advance(); @@ -698,19 +698,22 @@ fn parse_type_reference_type_definition( let end = lexer.last_range.end; if name.is_some() || bounds.is_some() { let data_type = match bounds { - Some(AstStatement::ExpressionList { expressions, id }) => { + Some(AstNode { stmt: AstStatement::ExpressionList(expressions), id, location }) => { //this is an enum DataTypeDeclaration::DataTypeDefinition { data_type: DataType::EnumType { name, numeric_type: referenced_type, - elements: AstStatement::ExpressionList { expressions, id }, + elements: AstFactory::create_expression_list(expressions, location, id), }, location: lexer.source_range_factory.create_range(start..end), scope: lexer.scope.clone(), } } - Some(AstStatement::ReferenceExpr { access: ReferenceAccess::Member(_), .. }) => { + Some(AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(_), .. }), + .. + }) => { // a enum with just one element DataTypeDeclaration::DataTypeDefinition { data_type: DataType::EnumType { @@ -741,7 +744,7 @@ fn parse_type_reference_type_definition( } } -fn parse_string_size_expression(lexer: &mut ParseSession) -> Option { +fn parse_string_size_expression(lexer: &mut ParseSession) -> Option { let opening_token = lexer.token.clone(); if lexer.try_consume(&KeywordSquareParensOpen) || lexer.try_consume(&KeywordParensOpen) { let opening_location = lexer.range().start; @@ -775,7 +778,7 @@ fn parse_string_size_expression(lexer: &mut ParseSession) -> Option, -) -> Option<(DataTypeDeclaration, Option)> { +) -> Option<(DataTypeDeclaration, Option)> { let text = lexer.slice().to_string(); let start = lexer.range().start; let is_wide = lexer.token == KeywordWideString; @@ -808,7 +811,7 @@ fn parse_string_type_definition( fn parse_enum_type_definition( lexer: &mut ParseSession, name: Option, -) -> Option<(DataTypeDeclaration, Option)> { +) -> Option<(DataTypeDeclaration, Option)> { let start = lexer.last_location(); let elements = parse_any_in_region(lexer, vec![KeywordParensClose], |lexer| { // Parse Enum - we expect at least one element @@ -829,7 +832,7 @@ fn parse_enum_type_definition( fn parse_array_type_definition( lexer: &mut ParseSession, name: Option, -) -> Option<(DataTypeDeclaration, Option)> { +) -> Option<(DataTypeDeclaration, Option)> { let start = lexer.last_range.start; let range = parse_any_in_region(lexer, vec![KeywordOf], |lexer| { // Parse Array range @@ -850,15 +853,15 @@ fn parse_array_type_definition( let reference_end = reference.get_location().to_range().map(|it| it.end).unwrap_or(0); let location = lexer.source_range_factory.create_range(start..reference_end); - let is_variable_length = match &range { + let is_variable_length = match &range.get_stmt() { // Single dimensions, i.e. ARRAY[0..5] or ARRAY[*] AstStatement::RangeStatement { .. } => Some(false), AstStatement::VlaRangeStatement { .. } => Some(true), // Multi dimensions, i.e. ARRAY [0..5, 5..10] or ARRAY [*, *] - AstStatement::ExpressionList { expressions, .. } => match expressions[0] { - AstStatement::RangeStatement { .. } => Some(false), - AstStatement::VlaRangeStatement { .. } => Some(true), + AstStatement::ExpressionList(expressions) => match expressions[0].get_stmt() { + AstStatement::RangeStatement(..) => Some(false), + AstStatement::VlaRangeStatement => Some(true), _ => None, }, @@ -890,11 +893,11 @@ fn parse_array_type_definition( } /// parse a body and recovers until the given `end_keywords` -fn parse_body_in_region(lexer: &mut ParseSession, end_keywords: Vec) -> Vec { +fn parse_body_in_region(lexer: &mut ParseSession, end_keywords: Vec) -> Vec { parse_any_in_region(lexer, end_keywords, parse_body_standalone) } -fn parse_body_standalone(lexer: &mut ParseSession) -> Vec { +fn parse_body_standalone(lexer: &mut ParseSession) -> Vec { let mut statements = Vec::new(); while !lexer.closes_open_region(&lexer.token) { statements.push(parse_control(lexer)); @@ -903,10 +906,11 @@ fn parse_body_standalone(lexer: &mut ParseSession) -> Vec { } /// parses a statement ending with a ';' -fn parse_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_statement(lexer: &mut ParseSession) -> AstNode { let result = parse_any_in_region(lexer, vec![KeywordSemicolon, KeywordColon], parse_expression); if lexer.last_token == KeywordColon { - AstStatement::CaseCondition { condition: Box::new(result), id: lexer.next_id() } + let location = result.location.span(&lexer.last_location()); + AstFactory::create_case_condition(result, location, lexer.next_id()) } else { result } @@ -939,19 +943,18 @@ pub fn parse_any_in_region T>( result } -fn parse_reference(lexer: &mut ParseSession) -> AstStatement { +fn parse_reference(lexer: &mut ParseSession) -> AstNode { match expressions_parser::parse_call_statement(lexer) { Ok(statement) => statement, Err(diagnostic) => { - let statement = - AstStatement::EmptyStatement { location: diagnostic.get_location(), id: lexer.next_id() }; + let statement = AstFactory::create_empty_statement(diagnostic.get_location(), lexer.next_id()); lexer.accept_diagnostic(diagnostic); statement } } } -fn parse_control(lexer: &mut ParseSession) -> AstStatement { +fn parse_control(lexer: &mut ParseSession) -> AstNode { parse_control_statement(lexer) } @@ -998,8 +1001,7 @@ fn parse_variable_block(lexer: &mut ParseSession, linkage: LinkageType) -> Varia if constant { // sneak in the DefaultValue-Statements if no initializers were defined variables.iter_mut().filter(|it| it.initializer.is_none()).for_each(|it| { - it.initializer = - Some(AstStatement::DefaultValue { location: it.location.clone(), id: lexer.next_id() }); + it.initializer = Some(AstFactory::create_default_value(it.location.clone(), lexer.next_id())); }); } @@ -1083,7 +1085,7 @@ fn parse_hardware_access( lexer: &mut ParseSession, hardware_access_type: HardwareAccessType, access_type: DirectAccessType, -) -> Result { +) -> Result { let start_location = lexer.last_location(); lexer.advance(); //Folowed by an integer @@ -1098,13 +1100,13 @@ fn parse_hardware_access( } } } - Ok(AstStatement::HardwareAccess { - access: access_type, - direction: hardware_access_type, + Ok(AstFactory::create_hardware_access( + access_type, + hardware_access_type, address, - location: start_location.span(&lexer.last_location()), - id: lexer.next_id(), - }) + start_location.span(&lexer.last_location()), + lexer.next_id(), + )) } else { Err(Diagnostic::missing_token("LiteralInteger", lexer.location())) } diff --git a/src/parser/control_parser.rs b/src/parser/control_parser.rs index b709b1ca05..928d13279a 100644 --- a/src/parser/control_parser.rs +++ b/src/parser/control_parser.rs @@ -1,5 +1,5 @@ use plc_ast::{ - ast::{AstFactory, AstStatement}, + ast::{AstFactory, AstNode, AstStatement}, control_statements::ConditionalBlock, }; use plc_diagnostics::diagnostics::Diagnostic; @@ -14,7 +14,7 @@ use crate::{ use super::ParseSession; use super::{parse_expression, parse_reference, parse_statement}; -pub fn parse_control_statement(lexer: &mut ParseSession) -> AstStatement { +pub fn parse_control_statement(lexer: &mut ParseSession) -> AstNode { match lexer.token { KeywordIf => parse_if_statement(lexer), KeywordFor => parse_for_statement(lexer), @@ -28,25 +28,25 @@ pub fn parse_control_statement(lexer: &mut ParseSession) -> AstStatement { } } -fn parse_return_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_return_statement(lexer: &mut ParseSession) -> AstNode { let location = lexer.location(); lexer.advance(); - AstStatement::ReturnStatement { location, id: lexer.next_id() } + AstFactory::create_return_statement(location, lexer.next_id()) } -fn parse_exit_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_exit_statement(lexer: &mut ParseSession) -> AstNode { let location = lexer.location(); lexer.advance(); - AstStatement::ExitStatement { location, id: lexer.next_id() } + AstFactory::create_exit_statement(location, lexer.next_id()) } -fn parse_continue_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_continue_statement(lexer: &mut ParseSession) -> AstNode { let location = lexer.location(); lexer.advance(); - AstStatement::ContinueStatement { location, id: lexer.next_id() } + AstFactory::create_continue_statement(location, lexer.next_id()) } -fn parse_if_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_if_statement(lexer: &mut ParseSession) -> AstNode { let start = lexer.range().start; lexer.advance(); //If let mut conditional_blocks = vec![]; @@ -56,7 +56,7 @@ fn parse_if_statement(lexer: &mut ParseSession) -> AstStatement { expect_token!( lexer, KeywordThen, - AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() } + AstFactory::create_empty_statement(lexer.location(), lexer.next_id()) ); lexer.advance(); @@ -84,7 +84,7 @@ fn parse_if_statement(lexer: &mut ParseSession) -> AstStatement { ) } -fn parse_for_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_for_statement(lexer: &mut ParseSession) -> AstNode { let start = lexer.range().start; lexer.advance(); // FOR @@ -92,16 +92,12 @@ fn parse_for_statement(lexer: &mut ParseSession) -> AstStatement { expect_token!( lexer, KeywordAssignment, - AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() } + AstFactory::create_empty_statement(lexer.location(), lexer.next_id()) ); lexer.advance(); let start_expression = parse_expression(lexer); - expect_token!( - lexer, - KeywordTo, - AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() } - ); + expect_token!(lexer, KeywordTo, AstFactory::create_empty_statement(lexer.location(), lexer.next_id())); lexer.advance(); let end_expression = parse_expression(lexer); @@ -125,7 +121,7 @@ fn parse_for_statement(lexer: &mut ParseSession) -> AstStatement { ) } -fn parse_while_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_while_statement(lexer: &mut ParseSession) -> AstNode { let start = lexer.range().start; lexer.advance(); //WHILE @@ -140,7 +136,7 @@ fn parse_while_statement(lexer: &mut ParseSession) -> AstStatement { ) } -fn parse_repeat_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_repeat_statement(lexer: &mut ParseSession) -> AstNode { let start = lexer.range().start; lexer.advance(); //REPEAT @@ -148,7 +144,7 @@ fn parse_repeat_statement(lexer: &mut ParseSession) -> AstStatement { let condition = if lexer.last_token == KeywordUntil { parse_any_in_region(lexer, vec![KeywordEndRepeat], parse_expression) } else { - AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() } + AstFactory::create_empty_statement(lexer.location(), lexer.next_id()) }; AstFactory::create_repeat_statement( @@ -159,17 +155,13 @@ fn parse_repeat_statement(lexer: &mut ParseSession) -> AstStatement { ) } -fn parse_case_statement(lexer: &mut ParseSession) -> AstStatement { +fn parse_case_statement(lexer: &mut ParseSession) -> AstNode { let start = lexer.range().start; lexer.advance(); // CASE let selector = parse_expression(lexer); - expect_token!( - lexer, - KeywordOf, - AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() } - ); + expect_token!(lexer, KeywordOf, AstFactory::create_empty_statement(lexer.location(), lexer.next_id())); lexer.advance(); @@ -180,7 +172,7 @@ fn parse_case_statement(lexer: &mut ParseSession) -> AstStatement { let mut current_condition = None; let mut current_body = vec![]; for statement in body { - if let AstStatement::CaseCondition { condition, .. } = statement { + if let AstNode { stmt: AstStatement::CaseCondition(condition), .. } = statement { if let Some(condition) = current_condition { let block = ConditionalBlock { condition, body: current_body }; case_blocks.push(block); @@ -194,10 +186,8 @@ fn parse_case_statement(lexer: &mut ParseSession) -> AstStatement { "Missing Case-Condition", lexer.location(), )); - current_condition = Some(Box::new(AstStatement::EmptyStatement { - location: lexer.location(), - id: lexer.next_id(), - })); + current_condition = + Some(Box::new(AstFactory::create_empty_statement(lexer.location(), lexer.next_id()))); } current_body.push(statement); } diff --git a/src/parser/expressions_parser.rs b/src/parser/expressions_parser.rs index 38e7b4b292..2c6295af57 100644 --- a/src/parser/expressions_parser.rs +++ b/src/parser/expressions_parser.rs @@ -7,7 +7,7 @@ use crate::{ }; use core::str::Split; use plc_ast::{ - ast::{AstFactory, AstId, AstStatement, DirectAccessType, Operator}, + ast::{AstFactory, AstId, AstNode, AstStatement, DirectAccessType, Operator}, literals::{AstLiteral, Time}, }; use plc_diagnostics::diagnostics::Diagnostic; @@ -30,12 +30,7 @@ macro_rules! parse_left_associative_expression { }; $lexer.advance(); let right = $action($lexer); - left = AstStatement::BinaryExpression { - operator, - left: Box::new(left), - right: Box::new(right), - id: $lexer.next_id(), - }; + left = AstFactory::create_binary_expression(left, operator, right, $lexer.next_id()); } left } @@ -46,15 +41,16 @@ macro_rules! parse_left_associative_expression { /// is encountered, the erroneous part of the AST will consist of an /// EmptyStatement and a diagnostic will be logged. That case is different from /// only an EmptyStatement returned, which does not denote an error condition. -pub fn parse_expression(lexer: &mut ParseSession) -> AstStatement { +pub fn parse_expression(lexer: &mut ParseSession) -> AstNode { if lexer.token == KeywordSemicolon { - AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() } + AstFactory::create_empty_statement(lexer.location(), lexer.next_id()) } else { parse_expression_list(lexer) } } -pub fn parse_expression_list(lexer: &mut ParseSession) -> AstStatement { +pub fn parse_expression_list(lexer: &mut ParseSession) -> AstNode { + let start = lexer.location(); let left = parse_range_statement(lexer); if lexer.token == KeywordComma { let mut expressions = vec![]; @@ -68,49 +64,49 @@ pub fn parse_expression_list(lexer: &mut ParseSession) -> AstStatement { // we may have parsed no additional expression because of trailing comma if !expressions.is_empty() { expressions.insert(0, left); - return AstStatement::ExpressionList { expressions, id: lexer.next_id() }; + return AstFactory::create_expression_list( + expressions, + start.span(&lexer.last_location()), + lexer.next_id(), + ); } } left } -pub(crate) fn parse_range_statement(lexer: &mut ParseSession) -> AstStatement { +pub(crate) fn parse_range_statement(lexer: &mut ParseSession) -> AstNode { let start = parse_or_expression(lexer); if lexer.token == KeywordDotDot { lexer.advance(); let end = parse_or_expression(lexer); - return AstStatement::RangeStatement { - start: Box::new(start), - end: Box::new(end), - id: lexer.next_id(), - }; + return AstFactory::create_range_statement(start, end, lexer.next_id()); } start } // OR -fn parse_or_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_or_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!(lexer, parse_xor_expression, OperatorOr,) } // XOR -fn parse_xor_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_xor_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!(lexer, parse_and_expression, OperatorXor,) } // AND -fn parse_and_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_and_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!(lexer, parse_equality_expression, OperatorAmp | OperatorAnd,) } //EQUALITY =, <> -fn parse_equality_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_equality_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!(lexer, parse_compare_expression, OperatorEqual | OperatorNotEqual,) } //COMPARE <, >, <=, >= -fn parse_compare_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_compare_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!( lexer, parse_additive_expression, @@ -119,12 +115,12 @@ fn parse_compare_expression(lexer: &mut ParseSession) -> AstStatement { } // Addition +, - -fn parse_additive_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_additive_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!(lexer, parse_multiplication_expression, OperatorPlus | OperatorMinus,) } // Multiplication *, /, MOD -fn parse_multiplication_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_multiplication_expression(lexer: &mut ParseSession) -> AstNode { parse_left_associative_expression!( lexer, parse_exponent_expression, @@ -133,34 +129,22 @@ fn parse_multiplication_expression(lexer: &mut ParseSession) -> AstStatement { } // Expoent ** -fn parse_exponent_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_exponent_expression(lexer: &mut ParseSession) -> AstNode { //This is always parsed as a function call to the EXPT function //Parse left let mut left = parse_unary_expression(lexer); while matches!(lexer.token, OperatorExponent) { - let start_location = lexer.last_location(); - let op_location = lexer.location(); lexer.advance(); let right = parse_unary_expression(lexer); - left = AstStatement::CallStatement { - operator: Box::new(AstFactory::create_member_reference( - AstFactory::create_identifier("EXPT", &op_location, lexer.next_id()), - None, - lexer.next_id(), - )), - parameters: Box::new(Some(AstStatement::ExpressionList { - expressions: vec![left, right], - id: lexer.next_id(), - })), - location: start_location.span(&lexer.last_location()), - id: lexer.next_id(), - } + let span = left.get_location().span(&right.get_location()); + left = + AstFactory::create_call_to_with_ids("EXPT", vec![left, right], &span, lexer.id_provider.clone()); } left } // UNARY -x, NOT x -fn parse_unary_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_unary_expression(lexer: &mut ParseSession) -> AstNode { // collect all consecutive operators let start_location = lexer.location(); let mut operators = vec![]; @@ -179,26 +163,21 @@ fn parse_unary_expression(lexer: &mut ParseSession) -> AstStatement { let expression_location = expression.get_location(); let location = start_location.span(&expression_location); - match (&operator, &expression) { - (Operator::Minus, AstStatement::Literal { kind: AstLiteral::Integer(value), .. }) => { - AstStatement::new_literal(AstLiteral::new_integer(-value), lexer.next_id(), location) + match (&operator, &expression.get_stmt()) { + (Operator::Minus, AstStatement::Literal(AstLiteral::Integer(value))) => { + AstNode::new_literal(AstLiteral::new_integer(-value), lexer.next_id(), location) } - (Operator::Plus, AstStatement::Literal { kind: AstLiteral::Integer(value), .. }) => { - AstStatement::new_literal(AstLiteral::new_integer(*value), lexer.next_id(), location) + (Operator::Plus, AstStatement::Literal(AstLiteral::Integer(value))) => { + AstNode::new_literal(AstLiteral::new_integer(*value), lexer.next_id(), location) } // Return the reference itself instead of wrapping it inside a `AstStatement::UnaryExpression` - (Operator::Plus, AstStatement::Identifier { name, .. }) => { + (Operator::Plus, AstStatement::Identifier(name)) => { AstFactory::create_identifier(name, &location, lexer.next_id()) } - _ => AstStatement::UnaryExpression { - operator: *operator, - value: Box::new(expression), - location, - id: lexer.next_id(), - }, + _ => AstFactory::create_unary_expression(*operator, expression, location, lexer.next_id()), } }) } @@ -226,7 +205,7 @@ fn to_operator(token: &Token) -> Option { } // Literals, Identifiers, etc. -fn parse_leaf_expression(lexer: &mut ParseSession) -> AstStatement { +fn parse_leaf_expression(lexer: &mut ParseSession) -> AstNode { let literal_parse_result = match lexer.token { OperatorMultiplication => parse_vla_range(lexer), _ => parse_call_statement(lexer), @@ -236,25 +215,16 @@ fn parse_leaf_expression(lexer: &mut ParseSession) -> AstStatement { Ok(statement) => { if lexer.token == KeywordAssignment { lexer.advance(); - AstStatement::Assignment { - left: Box::new(statement), - right: Box::new(parse_range_statement(lexer)), - id: lexer.next_id(), - } + AstFactory::create_assignment(statement, parse_range_statement(lexer), lexer.next_id()) } else if lexer.token == KeywordOutputAssignment { lexer.advance(); - AstStatement::OutputAssignment { - left: Box::new(statement), - right: Box::new(parse_range_statement(lexer)), - id: lexer.next_id(), - } + AstFactory::create_output_assignment(statement, parse_range_statement(lexer), lexer.next_id()) } else { statement } } Err(diagnostic) => { - let statement = - AstStatement::EmptyStatement { location: diagnostic.get_location(), id: lexer.next_id() }; + let statement = AstFactory::create_empty_statement(diagnostic.get_location(), lexer.next_id()); lexer.accept_diagnostic(diagnostic); statement } @@ -264,7 +234,7 @@ fn parse_leaf_expression(lexer: &mut ParseSession) -> AstStatement { /// parse an expression at the bottom of the parse-tree. /// leaf-expressions are literals, identifier, direct-access and parenthesized expressions /// (since the parentheses change the parse-priority) -fn parse_atomic_leaf_expression(lexer: &mut ParseSession<'_>) -> Result { +fn parse_atomic_leaf_expression(lexer: &mut ParseSession<'_>) -> Result { // Check if we're dealing with a number that has an explicit '+' or '-' sign... match lexer.token { @@ -314,7 +284,7 @@ fn parse_atomic_leaf_expression(lexer: &mut ParseSession<'_>) -> Result' / ':=' // we are probably in a call statement missing a parameter assignment 'foo(param := ); // optional parameter assignments are allowed, validation should handle any unwanted cases - Ok(AstStatement::EmptyStatement { location: lexer.location(), id: lexer.next_id() }) + Ok(AstFactory::create_empty_statement(lexer.location(), lexer.next_id())) } else { Err(Diagnostic::unexpected_token_found("Literal", lexer.slice(), lexer.location())) } @@ -322,16 +292,16 @@ fn parse_atomic_leaf_expression(lexer: &mut ParseSession<'_>) -> Result) -> AstStatement { +fn parse_identifier(lexer: &mut ParseSession<'_>) -> AstNode { AstFactory::create_identifier(&lexer.slice_and_advance(), &lexer.last_location(), lexer.next_id()) } -fn parse_vla_range(lexer: &mut ParseSession) -> Result { +fn parse_vla_range(lexer: &mut ParseSession) -> Result { lexer.advance(); - Ok(AstStatement::VlaRangeStatement { id: lexer.next_id() }) + Ok(AstFactory::create_vla_range_statement(lexer.last_location(), lexer.next_id())) } -fn parse_array_literal(lexer: &mut ParseSession) -> Result { +fn parse_array_literal(lexer: &mut ParseSession) -> Result { let start = lexer.range().start; lexer.expect(KeywordSquareParensOpen)?; lexer.advance(); @@ -340,7 +310,7 @@ fn parse_array_literal(lexer: &mut ParseSession) -> Result Result Result { +fn parse_bool_literal(lexer: &mut ParseSession, value: bool) -> Result { let location = lexer.location(); lexer.advance(); - Ok(AstStatement::new_literal(AstLiteral::new_bool(value), lexer.next_id(), location)) + Ok(AstNode::new_literal(AstLiteral::new_bool(value), lexer.next_id(), location)) } #[allow(clippy::unnecessary_wraps)] //Allowing the unnecessary wrap here because this method is used along other methods that need to return Results -fn parse_null_literal(lexer: &mut ParseSession) -> Result { +fn parse_null_literal(lexer: &mut ParseSession) -> Result { let location = lexer.location(); lexer.advance(); - Ok(AstStatement::new_literal(AstLiteral::new_null(), lexer.next_id(), location)) + Ok(AstNode::new_literal(AstLiteral::new_null(), lexer.next_id(), location)) } -pub fn parse_call_statement(lexer: &mut ParseSession) -> Result { +pub fn parse_call_statement(lexer: &mut ParseSession) -> Result { let reference = parse_qualified_reference(lexer)?; // is this a callstatement? if lexer.try_consume(&KeywordParensOpen) { - let start_location = reference.get_location(); + let start = reference.get_location(); // Call Statement let call_statement = if lexer.try_consume(&KeywordParensClose) { - AstStatement::CallStatement { - operator: Box::new(reference), - parameters: Box::new(None), - location: start_location.span(&lexer.location()), - id: lexer.next_id(), - } + AstFactory::create_call_statement(reference, None, lexer.next_id(), start.span(&lexer.location())) } else { - parse_any_in_region(lexer, vec![KeywordParensClose], |lexer| AstStatement::CallStatement { - operator: Box::new(reference), - parameters: Box::new(Some(parse_expression_list(lexer))), - location: start_location.span(&lexer.location()), - id: lexer.next_id(), + parse_any_in_region(lexer, vec![KeywordParensClose], |lexer| { + AstFactory::create_call_statement( + reference, + Some(parse_expression_list(lexer)), + lexer.next_id(), + start.span(&lexer.location()), + ) }) }; Ok(call_statement) @@ -392,7 +359,7 @@ pub fn parse_call_statement(lexer: &mut ParseSession) -> Result Result { +pub fn parse_qualified_reference(lexer: &mut ParseSession) -> Result { let mut current = None; let mut pos = lexer.parse_progress - 1; // force an initial loop @@ -410,7 +377,7 @@ pub fn parse_qualified_reference(lexer: &mut ParseSession) -> Result { let exp = parse_atomic_leaf_expression(lexer)?; // pack if this is something to be resolved - current = if matches!(exp, AstStatement::Identifier { .. }) { + current = if exp.is_identifier() { Some(AstFactory::create_member_reference(exp, None, lexer.next_id())) } else { Some(exp) @@ -492,10 +459,7 @@ pub fn parse_qualified_reference(lexer: &mut ParseSession) -> Result Result { +fn parse_direct_access(lexer: &mut ParseSession, access: DirectAccessType) -> Result { //Consume the direct access let location = lexer.location(); lexer.advance(); @@ -512,16 +476,15 @@ fn parse_direct_access( } _ => Err(Diagnostic::unexpected_token_found("Integer or Reference", lexer.slice(), lexer.location())), }?; - let location = location.span(&lexer.last_location()); - Ok(AstStatement::DirectAccess { access, index: Box::new(index), location, id: lexer.next_id() }) + Ok(AstFactory::create_direct_access(access, index, lexer.next_id(), location)) } fn parse_literal_number_with_modifier( lexer: &mut ParseSession, radix: u32, is_negative: bool, -) -> Result { +) -> Result { // we can safely unwrap the number string, since the token has // been matched using regular expressions let location = lexer.location(); @@ -532,10 +495,10 @@ fn parse_literal_number_with_modifier( // again, the parsed number can be safely unwrapped. let value = i128::from_str_radix(number_str.as_str(), radix).expect("valid i128"); let value = if is_negative { -value } else { value }; - Ok(AstStatement::new_literal(AstLiteral::new_integer(value), lexer.next_id(), location)) + Ok(AstNode::new_literal(AstLiteral::new_integer(value), lexer.next_id(), location)) } -fn parse_literal_number(lexer: &mut ParseSession, is_negative: bool) -> Result { +fn parse_literal_number(lexer: &mut ParseSession, is_negative: bool) -> Result { let location = if is_negative { //correct the location if we just parsed a minus before lexer.last_range.start..lexer.range().end @@ -546,7 +509,7 @@ fn parse_literal_number(lexer: &mut ParseSession, is_negative: bool) -> Result Result Result().expect("valid i128"); let value = if is_negative { -value } else { value }; - Ok(AstStatement::new_literal( + Ok(AstNode::new_literal( AstLiteral::new_integer(value), lexer.next_id(), lexer.source_range_factory.create_range(location), @@ -587,7 +550,7 @@ fn parse_literal_number(lexer: &mut ParseSession, is_negative: bool) -> Result Result { +pub fn parse_strict_literal_integer(lexer: &mut ParseSession) -> Result { //correct the location if we just parsed a minus before let location = lexer.location(); let result = lexer.slice_and_advance(); @@ -597,7 +560,7 @@ pub fn parse_strict_literal_integer(lexer: &mut ParseSession) -> Result().expect("valid i128"); - Ok(AstStatement::new_literal(AstLiteral::new_integer(value), lexer.next_id(), location)) + Ok(AstNode::new_literal(AstLiteral::new_integer(value), lexer.next_id(), location)) } } @@ -607,11 +570,7 @@ fn parse_number(text: &str, location: &SourceLocation) -> Result Result { +fn parse_date_from_string(text: &str, location: SourceLocation, id: AstId) -> Result { let mut segments = text.split('-'); //we can safely expect 3 numbers @@ -628,10 +587,10 @@ fn parse_date_from_string( .map(|s| parse_number::(s, &location)) .expect("day-segment - tokenizer broken?")?; - Ok(AstStatement::new_literal(AstLiteral::new_date(year, month, day), id, location)) + Ok(AstNode::new_literal(AstLiteral::new_date(year, month, day), id, location)) } -fn parse_literal_date_and_time(lexer: &mut ParseSession) -> Result { +fn parse_literal_date_and_time(lexer: &mut ParseSession) -> Result { let location = lexer.location(); //get rid of D# or DATE# let slice = lexer.slice_and_advance(); @@ -651,14 +610,14 @@ fn parse_literal_date_and_time(lexer: &mut ParseSession) -> Result Result { +fn parse_literal_date(lexer: &mut ParseSession) -> Result { let location = lexer.location(); //get rid of D# or DATE# let slice = lexer.slice_and_advance(); @@ -668,7 +627,7 @@ fn parse_literal_date(lexer: &mut ParseSession) -> Result Result { +fn parse_literal_time_of_day(lexer: &mut ParseSession) -> Result { let location = lexer.location(); //get rid of TOD# or TIME_OF_DAY# let slice = lexer.slice_and_advance(); @@ -678,11 +637,7 @@ fn parse_literal_time_of_day(lexer: &mut ParseSession) -> Result Result { +fn parse_literal_time(lexer: &mut ParseSession) -> Result { const POS_D: usize = 0; const POS_H: usize = 1; const POS_M: usize = 2; @@ -790,7 +745,7 @@ fn parse_literal_time(lexer: &mut ParseSession) -> Result String { .into() } -fn parse_literal_string(lexer: &mut ParseSession, is_wide: bool) -> Result { +fn parse_literal_string(lexer: &mut ParseSession, is_wide: bool) -> Result { let result = lexer.slice(); let location = lexer.location(); - let string_literal = Ok(AstStatement::new_literal( + let string_literal = Ok(AstNode::new_literal( AstLiteral::new_string(handle_special_chars(&trim_quotes(result), is_wide), is_wide), lexer.next_id(), location, @@ -875,7 +830,7 @@ fn parse_literal_real( integer: String, integer_range: Range, is_negative: bool, -) -> Result { +) -> Result { if lexer.token == LiteralInteger { let start = integer_range.start; let end = lexer.range().end; @@ -883,7 +838,7 @@ fn parse_literal_real( let value = format!("{}{}.{}", if is_negative { "-" } else { "" }, integer, fractional); let new_location = lexer.source_range_factory.create_range(start..end); - Ok(AstStatement::new_literal(AstLiteral::new_real(value), lexer.next_id(), new_location)) + Ok(AstNode::new_literal(AstLiteral::new_real(value), lexer.next_id(), new_location)) } else { Err(Diagnostic::unexpected_token_found( "LiteralInteger or LiteralExponent", diff --git a/src/parser/tests.rs b/src/parser/tests.rs index f4b5959898..8ee0e295b6 100644 --- a/src/parser/tests.rs +++ b/src/parser/tests.rs @@ -1,5 +1,5 @@ use plc_ast::{ - ast::{AstFactory, AstStatement, ReferenceAccess}, + ast::{AstFactory, AstNode}, literals::AstLiteral, }; use plc_source::source_location::SourceLocation; @@ -20,25 +20,20 @@ mod type_parser_tests; mod variable_parser_tests; /// helper function to create references -pub fn ref_to(name: &str) -> AstStatement { - AstStatement::ReferenceExpr { - access: ReferenceAccess::Member(Box::new(AstFactory::create_identifier( - name, - &SourceLocation::undefined(), - 0, - ))), - base: None, - id: 0, - location: SourceLocation::undefined(), - } +pub fn ref_to(name: &str) -> AstNode { + AstFactory::create_member_reference( + AstFactory::create_identifier(name, &SourceLocation::undefined(), 0), + None, + 0, + ) } /// helper function to create literal ints -pub fn literal_int(value: i128) -> AstStatement { - AstStatement::new_literal(AstLiteral::new_integer(value), 0, SourceLocation::undefined()) +pub fn literal_int(value: i128) -> AstNode { + AstNode::new_literal(AstLiteral::new_integer(value), 0, SourceLocation::undefined()) } /// helper function to create empty statements -pub fn empty_stmt() -> AstStatement { - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 } +pub fn empty_stmt() -> AstNode { + AstFactory::create_empty_statement(SourceLocation::undefined(), 0) } diff --git a/src/parser/tests/control_parser_tests.rs b/src/parser/tests/control_parser_tests.rs index c5d5a3acdd..276bc6570f 100644 --- a/src/parser/tests/control_parser_tests.rs +++ b/src/parser/tests/control_parser_tests.rs @@ -5,6 +5,7 @@ use plc_ast::{ ast::AstStatement, control_statements::{AstControlStatement, ForLoopStatement, IfStatement}, }; + use pretty_assertions::*; #[test] @@ -447,9 +448,8 @@ fn if_stmnt_location_test() { END_IF" ); - if let AstStatement::ControlStatement { - kind: AstControlStatement::If(IfStatement { blocks, .. }), .. - } = &unit.statements[0] + if let AstStatement::ControlStatement(AstControlStatement::If(IfStatement { blocks, .. }), ..) = + &unit.statements[0].get_stmt() { let if_location = blocks[0].condition.as_ref().get_location(); assert_eq!(source[if_location.to_range().unwrap()].to_string(), "a > 4"); @@ -480,10 +480,10 @@ fn for_stmnt_location_test() { END_FOR" ); - if let AstStatement::ControlStatement { - kind: AstControlStatement::ForLoop(ForLoopStatement { counter, start, end, by_step, .. }), - .. - } = &unit.statements[0] + if let AstStatement::ControlStatement( + AstControlStatement::ForLoop(ForLoopStatement { counter, start, end, by_step, .. }), + .., + ) = &unit.statements[0].get_stmt() { let counter_location = counter.as_ref().get_location(); assert_eq!(source[counter_location.to_range().unwrap()].to_string(), "x"); @@ -565,11 +565,10 @@ fn call_stmnt_location_test() { let location = &unit.statements[0].get_location(); assert_eq!(source[location.to_range().unwrap()].to_string(), "foo(a:=3, b:=4)"); - if let AstStatement::CallStatement { operator, parameters, .. } = &unit.statements[0] { - let operator_location = operator.as_ref().get_location(); - assert_eq!(source[operator_location.to_range().unwrap()].to_string(), "foo"); + if let AstStatement::CallStatement(data) = &unit.statements[0].get_stmt() { + assert_eq!(source[data.operator.get_location().to_range().unwrap()].to_string(), "foo"); - let parameters_statement = parameters.as_ref().as_ref(); + let parameters_statement = data.parameters.as_deref(); let parameters_location = parameters_statement.map(|it| it.get_location()).unwrap(); assert_eq!(source[parameters_location.to_range().unwrap()].to_string(), "a:=3, b:=4"); } diff --git a/src/parser/tests/expressions_parser_tests.rs b/src/parser/tests/expressions_parser_tests.rs index 4ce0bbf990..cdc2d12a44 100644 --- a/src/parser/tests/expressions_parser_tests.rs +++ b/src/parser/tests/expressions_parser_tests.rs @@ -2,9 +2,7 @@ use crate::parser::tests::ref_to; use crate::test_utils::tests::parse; use insta::{assert_debug_snapshot, assert_snapshot}; -use plc_ast::ast::{ - AstFactory, AstStatement, DataType, DataTypeDeclaration, LinkageType, Operator, Pou, PouType, -}; +use plc_ast::ast::{AstFactory, AstNode, DataType, DataTypeDeclaration, LinkageType, Operator, Pou, PouType}; use plc_ast::literals::AstLiteral; use plc_source::source_location::SourceLocation; use pretty_assertions::*; @@ -732,7 +730,7 @@ fn literal_real_test() { assert_eq!(ast_string, expected_ast); } -fn cast(data_type: &str, value: AstStatement) -> AstStatement { +fn cast(data_type: &str, value: AstNode) -> AstNode { AstFactory::create_cast_statement( AstFactory::create_member_reference( AstFactory::create_identifier(data_type, &SourceLocation::undefined(), 0), @@ -786,8 +784,8 @@ fn literal_cast_parse_test() { let statement = &prg.statements; let ast_string = format!("{statement:#?}"); - fn literal(value: AstLiteral) -> AstStatement { - AstStatement::Literal { kind: value, location: SourceLocation::undefined(), id: 0 } + fn literal(value: AstLiteral) -> AstNode { + AstFactory::create_literal(value, SourceLocation::undefined(), 0) } assert_eq!( @@ -1611,11 +1609,11 @@ fn sized_string_as_function_return() { data_type: DataType::StringType { name: None, is_wide: false, - size: Some(AstStatement::Literal { - kind: AstLiteral::new_integer(10), - location: SourceLocation::undefined(), - id: 0, - }), + size: Some(AstFactory::create_literal( + AstLiteral::new_integer(10), + SourceLocation::undefined(), + 0, + )), }, location: SourceLocation::undefined(), scope: Some("foo".into()), @@ -1651,19 +1649,11 @@ fn array_type_as_function_return() { referenced_type: "INT".into(), location: SourceLocation::undefined(), }), - bounds: AstStatement::RangeStatement { - start: Box::new(AstStatement::Literal { - id: 0, - location: SourceLocation::undefined(), - kind: AstLiteral::new_integer(0), - }), - end: Box::new(AstStatement::Literal { - id: 0, - location: SourceLocation::undefined(), - kind: AstLiteral::new_integer(10), - }), - id: 0, - }, + bounds: AstFactory::create_range_statement( + AstFactory::create_literal(AstLiteral::Integer(0), SourceLocation::undefined(), 0), + AstFactory::create_literal(AstLiteral::Integer(10), SourceLocation::undefined(), 0), + 0, + ), name: None, is_variable_length: false, }, @@ -1707,22 +1697,16 @@ fn plus_minus_parse_tree_priority_test() { END_FUNCTION ", ); - assert_eq!( format!("{:#?}", ast.implementations[0].statements[0]), format!( "{:#?}", - AstStatement::BinaryExpression { - id: 0, - operator: Operator::Plus, - left: Box::new(AstStatement::BinaryExpression { - id: 0, - operator: Operator::Minus, - left: Box::new(ref_to("a")), - right: Box::new(ref_to("b")), - }), - right: Box::new(ref_to("c")), - } + AstFactory::create_binary_expression( + AstFactory::create_binary_expression(ref_to("a"), Operator::Minus, ref_to("b"), 0), + Operator::Plus, + ref_to("c"), + 0 + ) ) ); assert_eq!(diagnostics.is_empty(), true); @@ -1743,22 +1727,22 @@ fn mul_div_mod_parse_tree_priority_test() { format!("{:#?}", ast.implementations[0].statements[0]), format!( "{:#?}", - AstStatement::BinaryExpression { - id: 0, - operator: Operator::Modulo, - left: Box::new(AstStatement::BinaryExpression { - id: 0, - operator: Operator::Division, - left: Box::new(AstStatement::BinaryExpression { - id: 0, - operator: Operator::Multiplication, - left: Box::new(ref_to("a")), - right: Box::new(ref_to("b")), - }), - right: Box::new(ref_to("c")), - }), - right: Box::new(ref_to("d")), - } + AstFactory::create_binary_expression( + AstFactory::create_binary_expression( + AstFactory::create_binary_expression( + ref_to("a"), + Operator::Multiplication, + ref_to("b"), + 0 + ), + Operator::Division, + ref_to("c"), + 0 + ), + Operator::Modulo, + ref_to("d"), + 0 + ) ) ); assert_eq!(diagnostics.is_empty(), true); diff --git a/src/parser/tests/misc_parser_tests.rs b/src/parser/tests/misc_parser_tests.rs index 53b5342ca8..1f087f5f01 100644 --- a/src/parser/tests/misc_parser_tests.rs +++ b/src/parser/tests/misc_parser_tests.rs @@ -1,15 +1,16 @@ // Copyright (c) 2020 Ghaith Hachem and Mathias Rieder use core::panic; -use std::{collections::HashSet, ops::Range}; +use std::collections::HashSet; -use crate::{parser::tests::empty_stmt, test_utils::tests::parse}; +use crate::test_utils::tests::parse; use insta::assert_debug_snapshot; use plc_ast::{ - ast::{AstFactory, AstStatement, LinkageType, Operator, ReferenceAccess}, + ast::{ + Assignment, AstNode, AstStatement, BinaryExpression, CallStatement, LinkageType, ReferenceAccess, + ReferenceExpr, UnaryExpression, + }, control_statements::{AstControlStatement, CaseStatement, ForLoopStatement, IfStatement, LoopStatement}, - literals::AstLiteral, }; -use plc_source::source_location::{SourceLocation, SourceLocationFactory}; use pretty_assertions::*; #[test] @@ -88,10 +89,10 @@ fn ids_are_assigned_to_parsed_assignments() { let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); - if let AstStatement::Assignment { id, left, right } = &implementation.statements[0] { + if let AstStatement::Assignment(Assignment { left, right }) = &implementation.statements[0].get_stmt() { assert!(ids.insert(left.get_id())); assert!(ids.insert(right.get_id())); - assert!(ids.insert(*id)); + assert!(ids.insert(implementation.statements[0].get_id())); } else { panic!("unexpected statement"); } @@ -110,37 +111,49 @@ fn ids_are_assigned_to_callstatements() { let parse_result = parse(src).0; let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); - if let AstStatement::CallStatement { id, operator, .. } = &implementation.statements[0] { + if let AstStatement::CallStatement(CallStatement { operator, .. }, ..) = + &implementation.statements[0].get_stmt() + { assert!(ids.insert(operator.get_id())); - assert!(ids.insert(*id)); } else { panic!("unexpected statement"); } - if let AstStatement::CallStatement { id, operator, parameters, .. } = &implementation.statements[1] { + if let AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..) = + &implementation.statements[1].get_stmt() + { assert!(ids.insert(operator.get_id())); - if let Some(AstStatement::ExpressionList { expressions, id }) = &**parameters { + if let Some(AstNode { stmt: AstStatement::ExpressionList(expressions), id, .. }) = + parameters.as_deref() + { assert!(ids.insert(expressions[0].get_id())); assert!(ids.insert(expressions[1].get_id())); assert!(ids.insert(expressions[2].get_id())); assert!(ids.insert(*id)); } - assert!(ids.insert(*id)); } else { panic!("unexpected statement"); } - if let AstStatement::CallStatement { id, operator, parameters, .. } = &implementation.statements[2] { + if let AstStatement::CallStatement(CallStatement { operator, parameters }, ..) = + &implementation.statements[2].get_stmt() + { assert!(ids.insert(operator.get_id())); - if let Some(AstStatement::ExpressionList { expressions, id }) = &**parameters { - if let AstStatement::Assignment { left, right, id, .. } = &expressions[0] { + if let Some(AstNode { stmt: AstStatement::ExpressionList(expressions), id, .. }) = + parameters.as_deref() + { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right }), id, .. } = + &expressions[0] + { assert!(ids.insert(left.get_id())); assert!(ids.insert(right.get_id())); assert!(ids.insert(*id)); } else { panic!("unexpected statement"); } - if let AstStatement::OutputAssignment { left, right, id, .. } = &expressions[1] { + if let AstNode { stmt: AstStatement::OutputAssignment(Assignment { left, right }), id, .. } = + &expressions[1] + { assert!(ids.insert(left.get_id())); assert!(ids.insert(right.get_id())); assert!(ids.insert(*id)); @@ -150,10 +163,13 @@ fn ids_are_assigned_to_callstatements() { assert!(ids.insert(expressions[2].get_id())); assert!(ids.insert(*id)); } - assert!(ids.insert(*id)); } else { panic!("unexpected statement"); } + + for s in &implementation.statements { + assert!(ids.insert(s.get_id())); + } } #[test] @@ -174,7 +190,10 @@ fn ids_are_assigned_to_expressions() { let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); - if let AstStatement::BinaryExpression { id, left, right, .. } = &implementation.statements[0] { + if let AstNode { + id, stmt: AstStatement::BinaryExpression(BinaryExpression { left, right, .. }), .. + } = &implementation.statements[0] + { assert!(ids.insert(left.get_id())); assert!(ids.insert(right.get_id())); assert!(ids.insert(*id)); @@ -182,13 +201,21 @@ fn ids_are_assigned_to_expressions() { panic!("unexpected statement"); } - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Member(m), base: Some(base), id, .. } = - &implementation.statements[1] + if let AstNode { + stmt: + AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(m), base: Some(base) }), + id, + .. + } = &implementation.statements[1] { assert!(ids.insert(*id)); assert!(ids.insert(m.get_id())); - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Member(m), base: None, .. } = - base.as_ref() + + if let AstNode { + stmt: + AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(m), base: None }), + .. + } = base.as_ref() { assert!(ids.insert(m.get_id())); } else { @@ -198,8 +225,11 @@ fn ids_are_assigned_to_expressions() { panic!("unexpected statement"); } - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Member(m), base: None, id, .. } = - &implementation.statements[2] + if let AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(m), base: None }), + id, + .. + } = &implementation.statements[2] { assert!(ids.insert(*id)); assert!(ids.insert(m.get_id())); @@ -207,9 +237,12 @@ fn ids_are_assigned_to_expressions() { panic!("unexpected statement"); } - if let AstStatement::ReferenceExpr { - access: ReferenceAccess::Index(access), - base: Some(reference), + if let AstNode { + stmt: + AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(access), + base: Some(reference), + }), id, .. } = &implementation.statements[3] @@ -221,14 +254,18 @@ fn ids_are_assigned_to_expressions() { panic!("unexpected statement"); } - if let AstStatement::UnaryExpression { id, value, .. } = &implementation.statements[4] { + if let AstNode { stmt: AstStatement::UnaryExpression(UnaryExpression { value, .. }), id, .. } = + &implementation.statements[4] + { assert!(ids.insert(value.get_id())); assert!(ids.insert(*id)); } else { panic!("unexpected statement"); } - if let AstStatement::ExpressionList { id, expressions, .. } = &implementation.statements[5] { + if let AstNode { stmt: AstStatement::ExpressionList(expressions, ..), id, .. } = + &implementation.statements[5] + { assert!(ids.insert(expressions[0].get_id())); assert!(ids.insert(expressions[1].get_id())); assert!(ids.insert(*id)); @@ -236,16 +273,18 @@ fn ids_are_assigned_to_expressions() { panic!("unexpected statement"); } - if let AstStatement::RangeStatement { id, start, end, .. } = &implementation.statements[6] { - assert!(ids.insert(start.get_id())); - assert!(ids.insert(end.get_id())); + if let AstNode { stmt: AstStatement::RangeStatement(data, ..), id, .. } = &implementation.statements[6] { + assert!(ids.insert(data.start.get_id())); + assert!(ids.insert(data.end.get_id())); assert!(ids.insert(*id)); } else { panic!("unexpected statement"); } - if let AstStatement::MultipliedStatement { id, element, .. } = &implementation.statements[7] { - assert!(ids.insert(element.get_id())); + if let AstNode { stmt: AstStatement::MultipliedStatement(data, ..), id, .. } = + &implementation.statements[7] + { + assert!(ids.insert(data.element.get_id())); assert!(ids.insert(*id)); } else { panic!("unexpected statement"); @@ -267,8 +306,11 @@ fn ids_are_assigned_to_if_statements() { let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); match &implementation.statements[0] { - AstStatement::ControlStatement { - kind: AstControlStatement::If(IfStatement { blocks, else_block, .. }), + AstNode { + stmt: + AstStatement::ControlStatement(AstControlStatement::If(IfStatement { + blocks, else_block, .. + })), .. } => { assert!(ids.insert(blocks[0].condition.get_id())); @@ -295,9 +337,17 @@ fn ids_are_assigned_to_for_statements() { let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); match &implementation.statements[0] { - AstStatement::ControlStatement { + AstNode { + stmt: + AstStatement::ControlStatement(AstControlStatement::ForLoop(ForLoopStatement { + counter, + start, + end, + by_step, + body, + .. + })), id, - kind: AstControlStatement::ForLoop(ForLoopStatement { counter, start, end, by_step, body, .. }), .. } => { assert!(ids.insert(counter.get_id())); @@ -326,14 +376,20 @@ fn ids_are_assigned_to_while_statements() { let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); match &implementation.statements[0] { - AstStatement::ControlStatement { - kind: AstControlStatement::WhileLoop(LoopStatement { condition, body, .. }), + AstNode { + stmt: + AstStatement::ControlStatement(AstControlStatement::WhileLoop(LoopStatement { + condition, + body, + .. + })), + id, .. } => { assert!(ids.insert(condition.get_id())); assert!(ids.insert(body[0].get_id())); assert!(ids.insert(body[1].get_id())); - assert!(ids.insert(implementation.statements[0].get_id())); + assert!(ids.insert(*id)); } _ => panic!("invalid statement"), } @@ -353,8 +409,14 @@ fn ids_are_assigned_to_repeat_statements() { let mut ids = HashSet::new(); match &implementation.statements[0] { - AstStatement::ControlStatement { - kind: AstControlStatement::RepeatLoop(LoopStatement { condition, body, .. }), + AstNode { + stmt: + AstStatement::ControlStatement(AstControlStatement::RepeatLoop(LoopStatement { + condition, + body, + .. + })), + id: _, .. } => { assert!(ids.insert(body[0].get_id())); @@ -384,8 +446,15 @@ fn ids_are_assigned_to_case_statements() { let implementation = &parse_result.implementations[0]; let mut ids = HashSet::new(); match &implementation.statements[0] { - AstStatement::ControlStatement { - kind: AstControlStatement::Case(CaseStatement { case_blocks, else_block, selector, .. }), + AstNode { + stmt: + AstStatement::ControlStatement(AstControlStatement::Case(CaseStatement { + case_blocks, + else_block, + selector, + .. + })), + id: _, .. } => { //1st case block @@ -394,7 +463,9 @@ fn ids_are_assigned_to_case_statements() { assert!(ids.insert(case_blocks[0].body[0].get_id())); //2nd case block - if let AstStatement::ExpressionList { expressions, id, .. } = case_blocks[1].condition.as_ref() { + if let AstNode { stmt: AstStatement::ExpressionList(expressions), id, .. } = + case_blocks[1].condition.as_ref() + { assert!(ids.insert(expressions[0].get_id())); assert!(ids.insert(expressions[1].get_id())); assert!(ids.insert(*id)); @@ -410,235 +481,3 @@ fn ids_are_assigned_to_case_statements() { _ => panic!("invalid statement"), } } - -#[test] -fn id_implementation_for_all_statements() { - assert_eq!( - AstStatement::Assignment { left: Box::new(empty_stmt()), right: Box::new(empty_stmt()), id: 7 } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::BinaryExpression { - left: Box::new(empty_stmt()), - right: Box::new(empty_stmt()), - operator: Operator::And, - id: 7 - } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::BinaryExpression { - left: Box::new(empty_stmt()), - right: Box::new(empty_stmt()), - operator: Operator::And, - id: 7 - } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::CallStatement { - operator: Box::new(empty_stmt()), - parameters: Box::new(None), - id: 7, - location: SourceLocation::undefined() - } - .get_id(), - 7 - ); - assert_eq!(AstStatement::CaseCondition { condition: Box::new(empty_stmt()), id: 7 }.get_id(), 7); - assert_eq!( - AstFactory::create_case_statement(empty_stmt(), vec![], vec![], SourceLocation::undefined(), 7) - .get_id(), - 7 - ); - assert_eq!(AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 7 }.get_id(), 7); - assert_eq!(AstStatement::ExpressionList { expressions: vec![], id: 7 }.get_id(), 7); - assert_eq!( - AstFactory::create_for_loop( - empty_stmt(), - empty_stmt(), - empty_stmt(), - None, - vec![], - SourceLocation::undefined(), - 7 - ) - .get_id(), - 7 - ); - assert_eq!( - AstFactory::create_if_statement(Vec::new(), Vec::new(), SourceLocation::undefined(), 7).get_id(), - 7 - ); - assert_eq!( - AstStatement::Literal { kind: AstLiteral::Null, location: SourceLocation::undefined(), id: 7 } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::MultipliedStatement { - element: Box::new(empty_stmt()), - multiplier: 9, - location: SourceLocation::undefined(), - id: 7 - } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::OutputAssignment { left: Box::new(empty_stmt()), right: Box::new(empty_stmt()), id: 7 } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::RangeStatement { start: Box::new(empty_stmt()), end: Box::new(empty_stmt()), id: 7 } - .get_id(), - 7 - ); - assert_eq!( - AstStatement::Identifier { name: "ab".to_string(), location: SourceLocation::undefined(), id: 7 } - .get_id(), - 7 - ); - assert_eq!( - AstFactory::create_repeat_statement(empty_stmt(), vec![], SourceLocation::undefined(), 7).get_id(), - 7 - ); - assert_eq!( - AstStatement::UnaryExpression { - operator: Operator::Minus, - value: Box::new(empty_stmt()), - location: SourceLocation::undefined(), - id: 7 - } - .get_id(), - 7 - ); - assert_eq!( - AstFactory::create_while_statement(empty_stmt(), vec![], SourceLocation::undefined(), 7).get_id(), - 7 - ); -} - -fn at(location: Range) -> AstStatement { - let factory = SourceLocationFactory::internal(""); - AstStatement::EmptyStatement { id: 7, location: factory.create_range(location) } -} - -#[test] -fn location_implementation_for_all_statements() { - let factory = SourceLocationFactory::internal(""); - assert_eq!( - AstStatement::Assignment { left: Box::new(at(0..2)), right: Box::new(at(3..8)), id: 7 } - .get_location(), - factory.create_range(0..8) - ); - assert_eq!( - AstStatement::BinaryExpression { - left: Box::new(at(0..2)), - right: Box::new(at(3..8)), - operator: Operator::And, - id: 7 - } - .get_location(), - factory.create_range(0..8) - ); - assert_eq!( - AstStatement::CallStatement { - operator: Box::new(empty_stmt()), - parameters: Box::new(None), - id: 7, - location: SourceLocation::undefined() - } - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::CaseCondition { condition: Box::new(at(2..4)), id: 7 }.get_location(), - factory.create_range(2..4) - ); - assert_eq!( - AstFactory::create_case_statement(empty_stmt(), vec![], vec![], SourceLocation::undefined(), 7) - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 7 }.get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::ExpressionList { expressions: vec![at(0..3), at(4..8)], id: 7 }.get_location(), - factory.create_range(0..8) - ); - assert_eq!( - AstFactory::create_for_loop( - empty_stmt(), - empty_stmt(), - empty_stmt(), - None, - vec![], - SourceLocation::undefined(), - 7 - ) - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstFactory::create_if_statement(Vec::new(), Vec::new(), SourceLocation::undefined(), 7) - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::Literal { kind: AstLiteral::Null, location: SourceLocation::undefined(), id: 7 } - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::MultipliedStatement { - element: Box::new(empty_stmt()), - multiplier: 9, - location: SourceLocation::undefined(), - id: 7 - } - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::OutputAssignment { left: Box::new(at(0..3)), right: Box::new(at(4..9)), id: 7 } - .get_location(), - factory.create_range(0..9) - ); - assert_eq!( - AstStatement::RangeStatement { start: Box::new(at(0..3)), end: Box::new(at(6..9)), id: 7 } - .get_location(), - factory.create_range(0..9) - ); - assert_eq!( - AstStatement::Identifier { name: "ab".to_string(), location: SourceLocation::undefined(), id: 7 } - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstFactory::create_repeat_statement(empty_stmt(), vec![], SourceLocation::undefined(), 7) - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstStatement::UnaryExpression { - operator: Operator::Minus, - value: Box::new(empty_stmt()), - location: SourceLocation::undefined(), - id: 7 - } - .get_location(), - SourceLocation::undefined() - ); - assert_eq!( - AstFactory::create_while_statement(empty_stmt(), vec![], SourceLocation::undefined(), 7) - .get_location(), - SourceLocation::undefined() - ); -} diff --git a/src/parser/tests/parse_errors/parse_error_literals_tests.rs b/src/parser/tests/parse_errors/parse_error_literals_tests.rs index 7bb2f82fb1..5a562585f3 100644 --- a/src/parser/tests/parse_errors/parse_error_literals_tests.rs +++ b/src/parser/tests/parse_errors/parse_error_literals_tests.rs @@ -1,4 +1,5 @@ use insta::{assert_debug_snapshot, assert_snapshot}; + use plc_diagnostics::diagnostics::Diagnostic; use crate::test_utils::tests::{parse, parse_and_validate_buffered, parse_buffered}; diff --git a/src/parser/tests/parse_errors/parse_error_statements_tests.rs b/src/parser/tests/parse_errors/parse_error_statements_tests.rs index 3a4935a909..d53683fafb 100644 --- a/src/parser/tests/parse_errors/parse_error_statements_tests.rs +++ b/src/parser/tests/parse_errors/parse_error_statements_tests.rs @@ -2,7 +2,7 @@ use crate::{parser::tests::ref_to, test_utils::tests::parse_buffered}; use insta::{assert_debug_snapshot, assert_snapshot}; use plc_ast::ast::{ - AccessModifier, AstStatement, DataType, DataTypeDeclaration, LinkageType, UserTypeDeclaration, Variable, + AccessModifier, AstFactory, DataType, DataTypeDeclaration, LinkageType, UserTypeDeclaration, Variable, VariableBlock, VariableBlockType, }; use plc_source::source_location::SourceLocation; @@ -58,15 +58,16 @@ fn missing_comma_in_call_parameters() { format!("{:#?}", pou.statements), format!( "{:#?}", - vec![AstStatement::CallStatement { - location: SourceLocation::undefined(), - operator: Box::new(ref_to("buz")), - parameters: Box::new(Some(AstStatement::ExpressionList { - expressions: vec![ref_to("a"), ref_to("b"),], - id: 0 - })), - id: 0 - }] + vec![AstFactory::create_call_statement( + ref_to("buz"), + Some(AstFactory::create_expression_list( + vec![ref_to("a"), ref_to("b")], + SourceLocation::undefined(), + 0 + )), + 0, + SourceLocation::undefined() + )] ) ); } @@ -87,20 +88,22 @@ fn illegal_semicolon_in_call_parameters() { assert_snapshot!(diagnostics); let pou = &compilation_unit.implementations[0]; + assert_eq!( format!("{:#?}", pou.statements), format!( "{:#?}", vec![ - AstStatement::CallStatement { - location: SourceLocation::undefined(), - operator: Box::new(ref_to("buz")), - parameters: Box::new(Some(AstStatement::ExpressionList { - expressions: vec![ref_to("a"), ref_to("b")], - id: 0 - })), - id: 0 - }, + AstFactory::create_call_statement( + ref_to("buz"), + Some(AstFactory::create_expression_list( + vec![ref_to("a"), ref_to("b")], + SourceLocation::undefined(), + 0 + )), + 0, + SourceLocation::undefined() + ), ref_to("c") ] ) diff --git a/src/parser/tests/snapshots/rusty__parser__tests__type_parser_tests__array_type_can_be_parsed_test.snap b/src/parser/tests/snapshots/rusty__parser__tests__type_parser_tests__array_type_can_be_parsed_test.snap new file mode 100644 index 0000000000..089f72cde0 --- /dev/null +++ b/src/parser/tests/snapshots/rusty__parser__tests__type_parser_tests__array_type_can_be_parsed_test.snap @@ -0,0 +1,25 @@ +--- +source: src/parser/tests/type_parser_tests.rs +expression: ast_string +--- +UserTypeDeclaration { + data_type: ArrayType { + name: Some( + "MyArray", + ), + bounds: RangeStatement { + start: LiteralInteger { + value: 0, + }, + end: LiteralInteger { + value: 8, + }, + }, + referenced_type: DataTypeReference { + referenced_type: "INT", + }, + is_variable_length: false, + }, + initializer: None, + scope: None, +} diff --git a/src/parser/tests/statement_parser_tests.rs b/src/parser/tests/statement_parser_tests.rs index fa9494175e..7c932fe27e 100644 --- a/src/parser/tests/statement_parser_tests.rs +++ b/src/parser/tests/statement_parser_tests.rs @@ -1,6 +1,10 @@ -use crate::{parser::tests::ref_to, test_utils::tests::parse, typesystem::DINT_TYPE}; +use crate::{ + parser::tests::{empty_stmt, ref_to}, + test_utils::tests::parse, + typesystem::DINT_TYPE, +}; use insta::assert_snapshot; -use plc_ast::ast::{AstFactory, AstStatement, DataType, DataTypeDeclaration, Variable}; +use plc_ast::ast::{AstFactory, DataType, DataTypeDeclaration, Variable}; use plc_source::source_location::SourceLocation; use pretty_assertions::*; @@ -10,17 +14,10 @@ fn empty_statements_are_are_parsed() { let result = parse(src).0; let prg = &result.implementations[0]; + assert_eq!( format!("{:?}", prg.statements), - format!( - "{:?}", - vec![ - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - ] - ), + format!("{:?}", vec![empty_stmt(), empty_stmt(), empty_stmt(), empty_stmt(),]), ); } @@ -36,10 +33,10 @@ fn empty_statements_are_parsed_before_a_statement() { format!( "{:?}", vec![ - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, - AstStatement::EmptyStatement { location: SourceLocation::undefined(), id: 0 }, + empty_stmt(), + empty_stmt(), + empty_stmt(), + empty_stmt(), AstFactory::create_member_reference( AstFactory::create_identifier("x", &SourceLocation::undefined(), 0), None, @@ -135,10 +132,11 @@ fn inline_enum_declaration_can_be_parsed() { data_type: DataType::EnumType { name: None, numeric_type: DINT_TYPE.to_string(), - elements: AstStatement::ExpressionList { - expressions: vec![ref_to("red"), ref_to("yellow"), ref_to("green")], - id: 0, - }, + elements: AstFactory::create_expression_list( + vec![ref_to("red"), ref_to("yellow"), ref_to("green")], + SourceLocation::undefined(), + 0, + ), }, location: SourceLocation::undefined(), scope: None, diff --git a/src/parser/tests/type_parser_tests.rs b/src/parser/tests/type_parser_tests.rs index 5a6c69c34f..5bb2655a26 100644 --- a/src/parser/tests/type_parser_tests.rs +++ b/src/parser/tests/type_parser_tests.rs @@ -1,9 +1,6 @@ use crate::test_utils::tests::{parse, parse_buffered}; use insta::{assert_debug_snapshot, assert_snapshot}; -use plc_ast::{ - ast::{AstStatement, DataType, DataTypeDeclaration, UserTypeDeclaration, Variable}, - literals::AstLiteral, -}; +use plc_ast::ast::{DataType, DataTypeDeclaration, UserTypeDeclaration, Variable}; use plc_source::source_location::SourceLocation; use pretty_assertions::*; @@ -181,38 +178,7 @@ fn array_type_can_be_parsed_test() { ); let ast_string = format!("{:#?}", &result.user_types[0]); - - let expected_ast = format!( - "{:#?}", - &UserTypeDeclaration { - data_type: DataType::ArrayType { - name: Some("MyArray".to_string()), - bounds: AstStatement::RangeStatement { - start: Box::new(AstStatement::Literal { - kind: AstLiteral::new_integer(0), - location: SourceLocation::undefined(), - id: 0, - }), - end: Box::new(AstStatement::Literal { - kind: AstLiteral::new_integer(8), - location: SourceLocation::undefined(), - id: 0, - }), - id: 0, - }, - referenced_type: Box::new(DataTypeDeclaration::DataTypeReference { - referenced_type: "INT".to_string(), - location: SourceLocation::undefined(), - }), - is_variable_length: false, - }, - initializer: None, - location: SourceLocation::undefined(), - scope: None, - } - ); - - assert_eq!(ast_string, expected_ast); + assert_snapshot!(ast_string); } #[test] diff --git a/src/resolver.rs b/src/resolver.rs index 0313a35b14..4d557c7ec3 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -13,9 +13,9 @@ use std::{ use indexmap::{IndexMap, IndexSet}; use plc_ast::{ ast::{ - self, flatten_expression_list, AstFactory, AstId, AstStatement, CompilationUnit, DataType, - DataTypeDeclaration, DirectAccessType, Operator, Pou, ReferenceAccess, TypeNature, - UserTypeDeclaration, Variable, + self, flatten_expression_list, Assignment, AstFactory, AstId, AstNode, AstStatement, + BinaryExpression, CastStatement, CompilationUnit, DataType, DataTypeDeclaration, DirectAccessType, + Operator, Pou, ReferenceAccess, ReferenceExpr, TypeNature, UserTypeDeclaration, Variable, }, control_statements::AstControlStatement, literals::{Array, AstLiteral, StringValue}, @@ -175,7 +175,7 @@ pub struct TypeAnnotator<'i> { } impl TypeAnnotator<'_> { - pub fn annotate(&mut self, s: &AstStatement, annotation: StatementAnnotation) { + pub fn annotate(&mut self, s: &AstNode, annotation: StatementAnnotation) { match &annotation { StatementAnnotation::Function { return_type, qualified_name, call_name } => { let name = call_name.as_ref().unwrap_or(qualified_name); @@ -207,8 +207,8 @@ impl TypeAnnotator<'_> { self.annotation_map.annotate(s, annotation); } - fn visit_compare_statement(&mut self, ctx: &VisitorContext, statement: &AstStatement) { - let AstStatement::BinaryExpression { operator, left, right, .. } = statement else { + fn visit_compare_statement(&mut self, ctx: &VisitorContext, statement: &AstNode) { + let AstStatement::BinaryExpression ( BinaryExpression{ operator, left, right}) = statement.get_stmt() else { return; }; let mut ctx = ctx.clone(); @@ -248,10 +248,10 @@ impl TypeAnnotator<'_> { &self, ctx: &mut VisitorContext, operator: &Operator, - left: &AstStatement, - right: &AstStatement, - statement: &AstStatement, - ) -> AstStatement { + left: &AstNode, + right: &AstNode, + statement: &AstNode, + ) -> AstNode { let left_type = self .annotation_map .get_type_hint(left, self.index) @@ -271,10 +271,10 @@ impl TypeAnnotator<'_> { &statement.get_location(), ) }) - .unwrap_or(AstStatement::EmptyStatement { - location: statement.get_location(), - id: ctx.id_provider.next_id(), - }) + .unwrap_or(AstFactory::create_empty_statement( + statement.get_location(), + ctx.id_provider.next_id(), + )) } } @@ -315,7 +315,7 @@ pub enum StatementAnnotation { qualified_name: String, }, ReplacementAst { - statement: AstStatement, + statement: AstNode, }, } @@ -389,25 +389,25 @@ impl Dependency { } pub trait AnnotationMap { - fn get(&self, s: &AstStatement) -> Option<&StatementAnnotation>; + fn get(&self, s: &AstNode) -> Option<&StatementAnnotation>; - fn get_hint(&self, s: &AstStatement) -> Option<&StatementAnnotation>; + fn get_hint(&self, s: &AstNode) -> Option<&StatementAnnotation>; - fn get_hidden_function_call(&self, s: &AstStatement) -> Option<&AstStatement>; + fn get_hidden_function_call(&self, s: &AstNode) -> Option<&AstNode>; - fn get_type_or_void<'i>(&'i self, s: &AstStatement, index: &'i Index) -> &'i typesystem::DataType { + fn get_type_or_void<'i>(&'i self, s: &AstNode, index: &'i Index) -> &'i typesystem::DataType { self.get_type(s, index).unwrap_or_else(|| index.get_void_type()) } - fn get_hint_or_void<'i>(&'i self, s: &AstStatement, index: &'i Index) -> &'i typesystem::DataType { + fn get_hint_or_void<'i>(&'i self, s: &AstNode, index: &'i Index) -> &'i typesystem::DataType { self.get_type_hint(s, index).unwrap_or_else(|| index.get_void_type()) } - fn get_type_hint<'i>(&self, s: &AstStatement, index: &'i Index) -> Option<&'i typesystem::DataType> { + fn get_type_hint<'i>(&self, s: &AstNode, index: &'i Index) -> Option<&'i typesystem::DataType> { self.get_hint(s).and_then(|it| self.get_type_for_annotation(index, it)) } - fn get_type<'i>(&'i self, s: &AstStatement, index: &'i Index) -> Option<&'i typesystem::DataType> { + fn get_type<'i>(&'i self, s: &AstNode, index: &'i Index) -> Option<&'i typesystem::DataType> { self.get(s).and_then(|it| self.get_type_for_annotation(index, it)) } @@ -435,7 +435,7 @@ pub trait AnnotationMap { /// returns the name of the callable that is refered by the given statemt /// or none if this thing may not be callable - fn get_call_name(&self, s: &AstStatement) -> Option<&str> { + fn get_call_name(&self, s: &AstNode) -> Option<&str> { match self.get(s) { Some(StatementAnnotation::Function { qualified_name, call_name, .. }) => { call_name.as_ref().map(String::as_str).or(Some(qualified_name.as_str())) @@ -448,16 +448,16 @@ pub trait AnnotationMap { } } - fn get_qualified_name(&self, s: &AstStatement) -> Option<&str> { + fn get_qualified_name(&self, s: &AstNode) -> Option<&str> { match self.get(s) { Some(StatementAnnotation::Function { qualified_name, .. }) => Some(qualified_name.as_str()), _ => self.get_call_name(s), } } - fn has_type_annotation(&self, s: &AstStatement) -> bool; + fn has_type_annotation(&self, s: &AstNode) -> bool; - fn get_generic_nature(&self, s: &AstStatement) -> Option<&TypeNature>; + fn get_generic_nature(&self, s: &AstNode) -> Option<&TypeNature>; } #[derive(Debug)] @@ -469,7 +469,7 @@ pub struct AstAnnotations { } impl AnnotationMap for AstAnnotations { - fn get(&self, s: &AstStatement) -> Option<&StatementAnnotation> { + fn get(&self, s: &AstNode) -> Option<&StatementAnnotation> { if s.get_id() == self.bool_id { Some(&self.bool_annotation) } else { @@ -477,7 +477,7 @@ impl AnnotationMap for AstAnnotations { } } - fn get_hint(&self, s: &AstStatement) -> Option<&StatementAnnotation> { + fn get_hint(&self, s: &AstNode) -> Option<&StatementAnnotation> { if s.get_id() == self.bool_id { Some(&self.bool_annotation) } else { @@ -485,15 +485,15 @@ impl AnnotationMap for AstAnnotations { } } - fn get_hidden_function_call(&self, s: &AstStatement) -> Option<&AstStatement> { + fn get_hidden_function_call(&self, s: &AstNode) -> Option<&AstNode> { self.annotation_map.get_hidden_function_call(s) } - fn has_type_annotation(&self, s: &AstStatement) -> bool { + fn has_type_annotation(&self, s: &AstNode) -> bool { self.annotation_map.has_type_annotation(s) } - fn get_generic_nature(&self, s: &AstStatement) -> Option<&TypeNature> { + fn get_generic_nature(&self, s: &AstNode) -> Option<&TypeNature> { self.annotation_map.get_generic_nature(s) } } @@ -534,7 +534,7 @@ pub struct AnnotationMapImpl { /// ... /// x : BYTE(0..100); /// x := 10; // a call to `CheckRangeUnsigned` is maped to `10` - hidden_function_calls: IndexMap, + hidden_function_calls: IndexMap, //An index of newly created types pub new_index: Index, @@ -554,52 +554,52 @@ impl AnnotationMapImpl { } /// annotates the given statement (using it's `get_id()`) with the given type-name - pub fn annotate(&mut self, s: &AstStatement, annotation: StatementAnnotation) { + pub fn annotate(&mut self, s: &AstNode, annotation: StatementAnnotation) { self.type_map.insert(s.get_id(), annotation); } - pub fn annotate_type_hint(&mut self, s: &AstStatement, annotation: StatementAnnotation) { + pub fn annotate_type_hint(&mut self, s: &AstNode, annotation: StatementAnnotation) { self.type_hint_map.insert(s.get_id(), annotation); } /// annotates the given statement s with the call-statement f so codegen can generate /// a hidden call f instead of generating s - pub fn annotate_hidden_function_call(&mut self, s: &AstStatement, f: AstStatement) { + pub fn annotate_hidden_function_call(&mut self, s: &AstNode, f: AstNode) { self.hidden_function_calls.insert(s.get_id(), f); } /// Annotates the ast statement with its original generic nature - pub fn add_generic_nature(&mut self, s: &AstStatement, nature: TypeNature) { + pub fn add_generic_nature(&mut self, s: &AstNode, nature: TypeNature) { self.generic_nature_map.insert(s.get_id(), nature); } } impl AnnotationMap for AnnotationMapImpl { - fn get(&self, s: &AstStatement) -> Option<&StatementAnnotation> { + fn get(&self, s: &AstNode) -> Option<&StatementAnnotation> { self.type_map.get(&s.get_id()) } - fn get_hint(&self, s: &AstStatement) -> Option<&StatementAnnotation> { + fn get_hint(&self, s: &AstNode) -> Option<&StatementAnnotation> { self.type_hint_map.get(&s.get_id()) } /// returns the function call previously annoted on s via annotate_hidden_function_call(...) - fn get_hidden_function_call(&self, s: &AstStatement) -> Option<&AstStatement> { + fn get_hidden_function_call(&self, s: &AstNode) -> Option<&AstNode> { self.hidden_function_calls.get(&s.get_id()) } - fn get_type<'i>(&'i self, s: &AstStatement, index: &'i Index) -> Option<&'i typesystem::DataType> { + fn get_type<'i>(&'i self, s: &AstNode, index: &'i Index) -> Option<&'i typesystem::DataType> { self.get(s).and_then(|it| { self.get_type_for_annotation(index, it) .or_else(|| self.get_type_for_annotation(&self.new_index, it)) }) } - fn has_type_annotation(&self, s: &AstStatement) -> bool { + fn has_type_annotation(&self, s: &AstNode) -> bool { self.type_map.contains_key(&s.get_id()) } - fn get_generic_nature(&self, s: &AstStatement) -> Option<&TypeNature> { + fn get_generic_nature(&self, s: &AstNode) -> Option<&TypeNature> { self.generic_nature_map.get(&s.get_id()) } } @@ -726,8 +726,8 @@ impl<'i> TypeAnnotator<'i> { fn update_right_hand_side_expected_type( &mut self, ctx: &VisitorContext, - annotated_left_side: &AstStatement, - right_side: &AstStatement, + annotated_left_side: &AstNode, + right_side: &AstNode, ) { if let Some(expected_type) = self.annotation_map.get_type(annotated_left_side, self.index).cloned() { // for assignments on SubRanges check if there are range type check functions @@ -757,7 +757,7 @@ impl<'i> TypeAnnotator<'i> { } } - fn update_right_hand_side(&mut self, expected_type: &typesystem::DataType, right_side: &AstStatement) { + fn update_right_hand_side(&mut self, expected_type: &typesystem::DataType, right_side: &AstNode) { //annotate the right-hand side as a whole self.annotation_map .annotate_type_hint(right_side, StatementAnnotation::value(expected_type.get_name())); @@ -768,17 +768,17 @@ impl<'i> TypeAnnotator<'i> { /// updates the expected types of statements on the right side of an assignment /// e.g. x : ARRAY [0..1] OF BYTE := [2,3]; - fn update_expected_types(&mut self, expected_type: &typesystem::DataType, statement: &AstStatement) { + fn update_expected_types(&mut self, expected_type: &typesystem::DataType, statement: &AstNode) { //see if we need to dive into it - match statement { - AstStatement::Literal { kind: AstLiteral::Array(Array { elements: Some(elements) }), .. } => { + match statement.get_stmt() { + AstStatement::Literal(AstLiteral::Array(Array { elements: Some(elements) }), ..) => { //annotate the literal-array itself self.annotation_map .annotate_type_hint(statement, StatementAnnotation::value(expected_type.get_name())); //TODO exprssionList and MultipliedExpressions are a mess! if matches!( - elements.as_ref(), - AstStatement::ExpressionList { .. } | AstStatement::MultipliedStatement { .. } + elements.get_stmt(), + AstStatement::ExpressionList(..) | AstStatement::MultipliedStatement(..) ) { self.annotation_map .annotate_type_hint(elements, StatementAnnotation::value(expected_type.get_name())); @@ -792,7 +792,7 @@ impl<'i> TypeAnnotator<'i> { } } } - AstStatement::Assignment { left, right, .. } => { + AstStatement::Assignment(Assignment { left, right }, ..) => { // struct initialization (left := right) // find out left's type and update a type hint for right if let ( @@ -812,24 +812,24 @@ impl<'i> TypeAnnotator<'i> { } } } - AstStatement::MultipliedStatement { element: elements, .. } => { + AstStatement::MultipliedStatement(data, ..) => { // n(elements) //annotate the type to all multiplied elements - for ele in AstStatement::get_as_list(elements) { + for ele in AstNode::get_as_list(&data.element) { self.update_expected_types(expected_type, ele); } } - AstStatement::ExpressionList { expressions, .. } => { + AstStatement::ExpressionList(expressions, ..) => { //annotate the type to all elements for ele in expressions { self.update_expected_types(expected_type, ele); } } - AstStatement::RangeStatement { start, end, .. } => { - self.update_expected_types(expected_type, start); - self.update_expected_types(expected_type, end); + AstStatement::RangeStatement(data, ..) => { + self.update_expected_types(expected_type, &data.start); + self.update_expected_types(expected_type, &data.end); } - AstStatement::Literal { kind: AstLiteral::Integer { .. }, .. } => { + AstStatement::Literal(AstLiteral::Integer { .. }, ..) => { //special case -> promote a literal-Integer directly, not via type-hint // (avoid later cast) if expected_type.get_type_information().is_float() { @@ -845,8 +845,7 @@ impl<'i> TypeAnnotator<'i> { .annotate_type_hint(statement, StatementAnnotation::value(expected_type.get_name())) } } - AstStatement::Literal { kind: AstLiteral::String { .. }, .. } - | AstStatement::BinaryExpression { .. } => { + AstStatement::Literal(AstLiteral::String { .. }, ..) | AstStatement::BinaryExpression { .. } => { // needed if we try to initialize an array with an expression-list // without we would annotate a false type this would leed to an error in expression_generator if let DataTypeInformation::Array { inner_type_name, .. } = @@ -883,7 +882,7 @@ impl<'i> TypeAnnotator<'i> { //right side being the local context let ctx = ctx.with_lhs(expected_type.get_name()); - if matches!(initializer, AstStatement::DefaultValue { .. }) { + if initializer.is_default_value() { // the default-placeholder must be annotated with the correct type, // it will be replaced by the appropriate literal later self.annotate(initializer, StatementAnnotation::value(expected_type.get_name())); @@ -903,7 +902,7 @@ impl<'i> TypeAnnotator<'i> { fn type_hint_for_array_of_structs( &mut self, expected_type: &typesystem::DataType, - statement: &AstStatement, + statement: &AstNode, ctx: &VisitorContext, ) { match expected_type.get_type_information() { @@ -918,8 +917,8 @@ impl<'i> TypeAnnotator<'i> { return; } - match statement { - AstStatement::Literal { kind: AstLiteral::Array(array), .. } => match array.elements() { + match statement.get_stmt() { + AstStatement::Literal(AstLiteral::Array(array)) => match array.elements() { Some(elements) if elements.is_expression_list() => { self.type_hint_for_array_of_structs(expected_type, elements, &ctx) } @@ -927,7 +926,7 @@ impl<'i> TypeAnnotator<'i> { _ => (), }, - AstStatement::ExpressionList { expressions, .. } => { + AstStatement::ExpressionList(expressions) => { for expression in expressions { // annotate with the arrays inner_type let name = inner_data_type.get_name().to_string(); @@ -939,8 +938,8 @@ impl<'i> TypeAnnotator<'i> { } } - AstStatement::Assignment { left, right, .. } if left.is_reference() => { - let AstStatement::Literal { kind: AstLiteral::Array(array), .. } = right.as_ref() else { return }; + AstStatement::Assignment(Assignment { left, right, .. }) if left.is_reference() => { + let AstStatement::Literal (AstLiteral::Array(array)) = right.as_ref().get_stmt() else { return }; let Some(elements) = array.elements() else { return }; if let Some(datatype) = self.annotation_map.get_type(left, self.index).cloned() { @@ -959,8 +958,8 @@ impl<'i> TypeAnnotator<'i> { for (idx, member) in members.iter().enumerate() { let data_type = self.index.get_effective_type_or_void_by_name(member.get_type_name()); if data_type.is_array() { - let Some(AstStatement::Assignment { right, .. }) = flattened.get(idx) else { continue }; - self.type_hint_for_array_of_structs(data_type, right, ctx); + let Some(AstStatement::Assignment(data)) = flattened.get(idx).map(|it| it.get_stmt()) else { continue }; + self.type_hint_for_array_of_structs(data_type, &data.right, ctx); } } } @@ -1047,21 +1046,21 @@ impl<'i> TypeAnnotator<'i> { } } - pub fn visit_statement(&mut self, ctx: &VisitorContext, statement: &AstStatement) { + pub fn visit_statement(&mut self, ctx: &VisitorContext, statement: &AstNode) { self.visit_statement_control(ctx, statement); } /// annotate a control statement - fn visit_statement_control(&mut self, ctx: &VisitorContext, statement: &AstStatement) { - match statement { - AstStatement::ControlStatement { kind: AstControlStatement::If(stmt), .. } => { + fn visit_statement_control(&mut self, ctx: &VisitorContext, statement: &AstNode) { + match statement.get_stmt() { + AstStatement::ControlStatement(AstControlStatement::If(stmt), ..) => { stmt.blocks.iter().for_each(|b| { self.visit_statement(ctx, b.condition.as_ref()); b.body.iter().for_each(|s| self.visit_statement(ctx, s)); }); stmt.else_block.iter().for_each(|e| self.visit_statement(ctx, e)); } - AstStatement::ControlStatement { kind: AstControlStatement::ForLoop(stmt), .. } => { + AstStatement::ControlStatement(AstControlStatement::ForLoop(stmt), ..) => { visit_all_statements!(self, ctx, &stmt.counter, &stmt.start, &stmt.end); if let Some(by_step) = &stmt.by_step { self.visit_statement(ctx, by_step); @@ -1081,12 +1080,12 @@ impl<'i> TypeAnnotator<'i> { } stmt.body.iter().for_each(|s| self.visit_statement(ctx, s)); } - AstStatement::ControlStatement { kind: AstControlStatement::WhileLoop(stmt), .. } - | AstStatement::ControlStatement { kind: AstControlStatement::RepeatLoop(stmt), .. } => { + AstStatement::ControlStatement(AstControlStatement::WhileLoop(stmt), ..) + | AstStatement::ControlStatement(AstControlStatement::RepeatLoop(stmt), ..) => { self.visit_statement(ctx, &stmt.condition); stmt.body.iter().for_each(|s| self.visit_statement(ctx, s)); } - AstStatement::ControlStatement { kind: AstControlStatement::Case(stmt), .. } => { + AstStatement::ControlStatement(AstControlStatement::Case(stmt), ..) => { self.visit_statement(ctx, &stmt.selector); let selector_type = self.annotation_map.get_type(&stmt.selector, self.index).cloned(); stmt.case_blocks.iter().for_each(|b| { @@ -1098,7 +1097,7 @@ impl<'i> TypeAnnotator<'i> { }); stmt.else_block.iter().for_each(|s| self.visit_statement(ctx, s)); } - AstStatement::CaseCondition { condition, .. } => self.visit_statement(ctx, condition), + AstStatement::CaseCondition(condition, ..) => self.visit_statement(ctx, condition), _ => { self.visit_statement_expression(ctx, statement); } @@ -1106,25 +1105,25 @@ impl<'i> TypeAnnotator<'i> { } /// annotate an expression statement - fn visit_statement_expression(&mut self, ctx: &VisitorContext, statement: &AstStatement) { - match statement { - AstStatement::DirectAccess { access, index, .. } => { + fn visit_statement_expression(&mut self, ctx: &VisitorContext, statement: &AstNode) { + match statement.get_stmt() { + AstStatement::DirectAccess(data, ..) => { let ctx = VisitorContext { qualifier: None, ..ctx.clone() }; - visit_all_statements!(self, &ctx, index); - let access_type = get_direct_access_type(access); + visit_all_statements!(self, &ctx, &data.index); + let access_type = get_direct_access_type(&data.access); self.annotate(statement, StatementAnnotation::Value { resulting_type: access_type.into() }); } - AstStatement::HardwareAccess { access, .. } => { - let access_type = get_direct_access_type(access); + AstStatement::HardwareAccess(data, ..) => { + let access_type = get_direct_access_type(&data.access); self.annotate(statement, StatementAnnotation::Value { resulting_type: access_type.into() }); } - AstStatement::BinaryExpression { left, right, operator, .. } => { - visit_all_statements!(self, ctx, left, right); + AstStatement::BinaryExpression(data, ..) => { + visit_all_statements!(self, ctx, &data.left, &data.right); let statement_type = { let left_type = self .annotation_map - .get_type_hint(left, self.index) - .or_else(|| self.annotation_map.get_type(left, self.index)) + .get_type_hint(&data.left, self.index) + .or_else(|| self.annotation_map.get_type(&data.left, self.index)) .and_then(|it| self.index.find_effective_type(it)) .unwrap_or_else(|| self.index.get_void_type()); // do not use for is_pointer() check @@ -1132,8 +1131,8 @@ impl<'i> TypeAnnotator<'i> { self.index.get_intrinsic_type_by_name(left_type.get_name()).get_type_information(); let right_type = self .annotation_map - .get_type_hint(right, self.index) - .or_else(|| self.annotation_map.get_type(right, self.index)) + .get_type_hint(&data.right, self.index) + .or_else(|| self.annotation_map.get_type(&data.right, self.index)) .and_then(|it| self.index.find_effective_type(it)) .unwrap_or_else(|| self.index.get_void_type()); // do not use for is_pointer() check @@ -1156,7 +1155,7 @@ impl<'i> TypeAnnotator<'i> { ) }; - let target_name = if operator.is_bool_type() { + let target_name = if data.operator.is_bool_type() { BOOL_TYPE.to_string() } else { bigger_type.get_name().to_string() @@ -1169,10 +1168,10 @@ impl<'i> TypeAnnotator<'i> { // if these types are different we need to update the 'other' type's annotation let bigger_type = bigger_type.clone(); // clone here, so we release the borrow on self if bigger_is_right { - self.update_expected_types(&bigger_type, left); + self.update_expected_types(&bigger_type, &data.left); } if bigger_is_left { - self.update_expected_types(&bigger_type, right); + self.update_expected_types(&bigger_type, &data.right); } } @@ -1181,16 +1180,16 @@ impl<'i> TypeAnnotator<'i> { || right_type.get_type_information().is_pointer() { // get the target type of the binary expression - let target_type = if operator.is_comparison_operator() { + let target_type = if data.operator.is_comparison_operator() { // compare instructions result in BOOL // to generate valid IR code if a pointer is beeing compared to an integer // we need to cast the int to the pointers size if !left_type.get_type_information().is_pointer() { let left_type = left_type.clone(); // clone here, so we release the borrow on self - self.annotate_to_pointer_size_if_necessary(&left_type, left); + self.annotate_to_pointer_size_if_necessary(&left_type, &data.left); } else if !right_type.get_type_information().is_pointer() { let right_type = right_type.clone(); // clone here, so we release the borrow on self - self.annotate_to_pointer_size_if_necessary(&right_type, right); + self.annotate_to_pointer_size_if_necessary(&right_type, &data.right); } BOOL_TYPE } else if left_type.get_type_information().is_pointer() { @@ -1199,7 +1198,7 @@ impl<'i> TypeAnnotator<'i> { right_type.get_name() }; Some(target_type.to_string()) - } else if operator.is_comparison_operator() { + } else if data.operator.is_comparison_operator() { //Annotate as the function call to XXX_EQUALS/LESS/GREATER.. self.visit_compare_statement(ctx, statement); None @@ -1212,19 +1211,19 @@ impl<'i> TypeAnnotator<'i> { self.annotate(statement, StatementAnnotation::new_value(statement_type)); } } - AstStatement::UnaryExpression { value, operator, .. } => { - self.visit_statement(ctx, value); + AstStatement::UnaryExpression(data, ..) => { + self.visit_statement(ctx, &data.value); - let statement_type = if operator == &Operator::Minus { + let statement_type = if data.operator == Operator::Minus { let inner_type = - self.annotation_map.get_type_or_void(value, self.index).get_type_information(); + self.annotation_map.get_type_or_void(&data.value, self.index).get_type_information(); //keep the same type but switch to signed typesystem::get_signed_type(inner_type, self.index).map(|it| it.get_name().to_string()) } else { let inner_type = self .annotation_map - .get_type_or_void(value, self.index) + .get_type_or_void(&data.value, self.index) .get_type_information() .get_name() .to_string(); @@ -1237,38 +1236,38 @@ impl<'i> TypeAnnotator<'i> { } } - AstStatement::ExpressionList { expressions, .. } => { + AstStatement::ExpressionList(expressions, ..) => { expressions.iter().for_each(|e| self.visit_statement(ctx, e)) } - AstStatement::RangeStatement { start, end, .. } => { - visit_all_statements!(self, ctx, start, end); + AstStatement::RangeStatement(data, ..) => { + visit_all_statements!(self, ctx, &data.start, &data.end); } - AstStatement::Assignment { left, right, .. } => { - self.visit_statement(ctx, right); + AstStatement::Assignment(data, ..) => { + self.visit_statement(ctx, &data.right); if let Some(lhs) = ctx.lhs { //special context for left hand side - self.visit_statement(&ctx.with_pou(lhs).with_lhs(lhs), left); + self.visit_statement(&ctx.with_pou(lhs).with_lhs(lhs), &data.left); } else { - self.visit_statement(ctx, left); + self.visit_statement(ctx, &data.left); } // give a type hint that we want the right side to be stored in the left's type - self.update_right_hand_side_expected_type(ctx, left, right); + self.update_right_hand_side_expected_type(ctx, &data.left, &data.right); } - AstStatement::OutputAssignment { left, right, .. } => { - visit_all_statements!(self, ctx, left, right); + AstStatement::OutputAssignment(data, ..) => { + visit_all_statements!(self, ctx, &data.left, &data.right); if let Some(lhs) = ctx.lhs { //special context for left hand side - self.visit_statement(&ctx.with_pou(lhs), left); + self.visit_statement(&ctx.with_pou(lhs), &data.left); } else { - self.visit_statement(ctx, left); + self.visit_statement(ctx, &data.left); } - self.update_right_hand_side_expected_type(ctx, left, right); + self.update_right_hand_side_expected_type(ctx, &data.left, &data.right); } - AstStatement::CallStatement { .. } => { + AstStatement::CallStatement(..) => { self.visit_call_statement(statement, ctx); } - AstStatement::CastStatement { target, type_name, .. } => { + AstStatement::CastStatement(CastStatement { target, type_name }, ..) => { //see if this type really exists let data_type = self.index.find_effective_type_info(type_name); let statement_to_annotation = if let Some(DataTypeInformation::Enum { name, .. }) = data_type @@ -1281,26 +1280,26 @@ impl<'i> TypeAnnotator<'i> { } else if let Some(t) = data_type { // special handling for unlucky casted-strings where caste-type does not match the literal encoding // ´STRING#"abc"´ or ´WSTRING#'abc'´ - match (t, target.as_ref()) { + match (t, target.as_ref().get_stmt()) { ( DataTypeInformation::String { encoding: StringEncoding::Utf8, .. }, - AstStatement::Literal { - kind: AstLiteral::String(StringValue { value, is_wide: is_wide @ true }), - .. - }, + AstStatement::Literal(AstLiteral::String(StringValue { + value, + is_wide: is_wide @ true, + })), ) | ( DataTypeInformation::String { encoding: StringEncoding::Utf16, .. }, - AstStatement::Literal { - kind: AstLiteral::String(StringValue { value, is_wide: is_wide @ false }), - .. - }, + AstStatement::Literal(AstLiteral::String(StringValue { + value, + is_wide: is_wide @ false, + })), ) => { // visit the target-statement as if the programmer used the correct quotes to prevent // a utf16 literal-global-variable that needs to be casted back to utf8 or vice versa self.visit_statement( ctx, - &AstStatement::new_literal( + &AstNode::new_literal( AstLiteral::new_string(value.clone(), !is_wide), target.get_id(), target.get_location(), @@ -1319,8 +1318,8 @@ impl<'i> TypeAnnotator<'i> { self.annotate(stmt, StatementAnnotation::new_value(annotation)); } } - AstStatement::ReferenceExpr { access, base, .. } => { - self.visit_reference_expr(access, base.as_deref(), statement, ctx); + AstStatement::ReferenceExpr(data, ..) => { + self.visit_reference_expr(&data.access, data.base.as_deref(), statement, ctx); } _ => { self.visit_statement_literals(ctx, statement); @@ -1331,8 +1330,8 @@ impl<'i> TypeAnnotator<'i> { fn visit_reference_expr( &mut self, access: &ast::ReferenceAccess, - base: Option<&AstStatement>, - stmt: &AstStatement, + base: Option<&AstNode>, + stmt: &AstNode, ctx: &VisitorContext, ) { // first resolve base @@ -1381,7 +1380,7 @@ impl<'i> TypeAnnotator<'i> { self.annotate(target.as_ref(), annotation); self.annotate(stmt, StatementAnnotation::value(qualifier.as_str())); - if let AstStatement::Literal { .. } = target.as_ref() { + if let AstStatement::Literal(..) = target.get_stmt() { // treate casted literals as the casted type self.annotate(target.as_ref(), StatementAnnotation::value(qualifier.as_str())); } @@ -1425,7 +1424,7 @@ impl<'i> TypeAnnotator<'i> { } } - fn is_const_reference(&self, stmt: &AstStatement, ctx: &VisitorContext<'_>) -> bool { + fn is_const_reference(&self, stmt: &AstNode, ctx: &VisitorContext<'_>) -> bool { self.annotation_map .get(stmt) .map(|it| it.is_const()) @@ -1437,17 +1436,17 @@ impl<'i> TypeAnnotator<'i> { /// Statement annotation if one can be derived. This method the annotation! fn resolve_reference_expression( &mut self, - reference: &AstStatement, + reference: &AstNode, qualifier: Option<&str>, ctx: &VisitorContext<'_>, ) -> Option { - match reference { - AstStatement::Identifier { name, .. } => ctx + match reference.get_stmt() { + AstStatement::Identifier(name, ..) => ctx .resolve_strategy .iter() .find_map(|scope| scope.resolve_name(name, qualifier, self.index, ctx)), - AstStatement::Literal { .. } => { + AstStatement::Literal(..) => { self.visit_statement_literals(ctx, reference); let literal_annotation = self.annotation_map.get(reference).cloned(); // return what we just annotated //TODO not elegant, we need to clone if let Some((base_type, literal_type)) = @@ -1465,10 +1464,10 @@ impl<'i> TypeAnnotator<'i> { literal_annotation } - AstStatement::DirectAccess { access, index, .. } if qualifier.is_some() => { + AstStatement::DirectAccess(data, ..) if qualifier.is_some() => { // x.%X1 - bit access - self.visit_statement(ctx, index.as_ref()); - Some(StatementAnnotation::value(get_direct_access_type(access))) + self.visit_statement(ctx, data.index.as_ref()); + Some(StatementAnnotation::value(get_direct_access_type(&data.access))) } _ => None, } @@ -1476,7 +1475,7 @@ impl<'i> TypeAnnotator<'i> { /// annotates the vla-statement it with a type hint /// referencing the contained array. This is needed to simplify codegen and validation. - fn annotate_vla_hint(&mut self, ctx: &VisitorContext, statement: &AstStatement) { + fn annotate_vla_hint(&mut self, ctx: &VisitorContext, statement: &AstNode) { let DataTypeInformation::Struct { source: StructSource::Internal(InternalType::VariableLengthArray { .. }), members, @@ -1498,7 +1497,7 @@ impl<'i> TypeAnnotator<'i> { let Some(pou) = ctx.pou else { unreachable!("VLA not allowed outside of POUs") }; - let name = if let AstStatement::Identifier { name, .. } = statement { + let name = if let AstStatement::Identifier(name, ..) = statement.get_stmt() { name.as_str() } else { statement.get_flat_reference_name().expect("must be a reference to a VLA") @@ -1526,13 +1525,13 @@ impl<'i> TypeAnnotator<'i> { } } - fn visit_call_statement(&mut self, statement: &AstStatement, ctx: &VisitorContext) { - let (operator, parameters_stmt) = - if let AstStatement::CallStatement { operator, parameters, .. } = statement { - (operator.as_ref(), parameters.as_ref().as_ref()) - } else { - unreachable!("Always a call statement"); - }; + fn visit_call_statement(&mut self, statement: &AstNode, ctx: &VisitorContext) { + let (operator, parameters_stmt) = if let AstStatement::CallStatement(data, ..) = statement.get_stmt() + { + (data.operator.as_ref(), data.parameters.as_deref()) + } else { + unreachable!("Always a call statement"); + }; // #604 needed for recursive function calls self.visit_statement(&ctx.with_resolving_strategy(ResolvingScope::call_operator_scopes()), operator); let operator_qualifier = self.get_call_name(operator); @@ -1663,7 +1662,7 @@ impl<'i> TypeAnnotator<'i> { } } - fn get_call_name(&mut self, operator: &AstStatement) -> String { + fn get_call_name(&mut self, operator: &AstNode) -> String { let operator_qualifier = self .annotation_map .get(operator) @@ -1682,9 +1681,9 @@ impl<'i> TypeAnnotator<'i> { // call statements on array access "arr[1]()" will return a StatementAnnotation::Value StatementAnnotation::Value { resulting_type } => { // make sure we come from an array or function_block access - match operator { - AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_),.. } => Some(resulting_type.clone()), - AstStatement::ReferenceExpr { access: ReferenceAccess::Deref, .. } => + match operator.get_stmt() { + AstStatement::ReferenceExpr ( ReferenceExpr{access: ReferenceAccess::Index(_), ..},.. ) => Some(resulting_type.clone()), + AstStatement::ReferenceExpr ( ReferenceExpr{access: ReferenceAccess::Deref, ..}, .. ) => // AstStatement::ArrayAccess { .. } => Some(resulting_type.clone()), // AstStatement::PointerAccess { .. } => { self.index.find_pou(resulting_type.as_str()).map(|it| it.get_name().to_string()), @@ -1698,8 +1697,8 @@ impl<'i> TypeAnnotator<'i> { operator_qualifier } - pub(crate) fn annotate_parameters(&mut self, p: &AstStatement, type_name: &str) { - if !matches!(p, AstStatement::Assignment { .. } | AstStatement::OutputAssignment { .. }) { + pub(crate) fn annotate_parameters(&mut self, p: &AstNode, type_name: &str) { + if !matches!(p.get_stmt(), AstStatement::Assignment(..) | AstStatement::OutputAssignment(..)) { if let Some(effective_member_type) = self.index.find_effective_type_by_name(type_name) { //update the type hint self.annotation_map @@ -1709,9 +1708,9 @@ impl<'i> TypeAnnotator<'i> { } /// annotate a literal statement - fn visit_statement_literals(&mut self, ctx: &VisitorContext, statement: &AstStatement) { - match statement { - AstStatement::Literal { kind, .. } => { + fn visit_statement_literals(&mut self, ctx: &VisitorContext, statement: &AstNode) { + match statement.get_stmt() { + AstStatement::Literal(kind, ..) => { match kind { AstLiteral::Bool { .. } => { self.annotate(statement, StatementAnnotation::value(BOOL_TYPE)); @@ -1756,8 +1755,8 @@ impl<'i> TypeAnnotator<'i> { _ => {} // ignore literalNull, arrays (they are covered earlier) } } - AstStatement::MultipliedStatement { element, .. } => { - self.visit_statement(ctx, element) + AstStatement::MultipliedStatement(data, ..) => { + self.visit_statement(ctx, &data.element) //TODO as of yet we have no way to derive a name that reflects a fixed size array } _ => {} @@ -1767,7 +1766,7 @@ impl<'i> TypeAnnotator<'i> { fn annotate_to_pointer_size_if_necessary( &mut self, value_type: &typesystem::DataType, - statement: &AstStatement, + statement: &AstNode, ) { // pointer size is 64Bits matching LINT // therefore get the bigger type of current and LINT to check if cast is necessary @@ -2016,23 +2015,17 @@ impl ResolvingScope { fn accept_cast_string_literal( literals: &mut StringLiterals, cast_type: &typesystem::DataType, - literal: &AstStatement, + literal: &AstNode, ) { // check if we need to register an additional string-literal - match (cast_type.get_type_information(), literal) { + match (cast_type.get_type_information(), literal.get_stmt()) { ( DataTypeInformation::String { encoding: StringEncoding::Utf8, .. }, - AstStatement::Literal { - kind: AstLiteral::String(StringValue { value, is_wide: is_wide @ true }), - .. - }, + AstStatement::Literal(AstLiteral::String(StringValue { value, is_wide: is_wide @ true })), ) | ( DataTypeInformation::String { encoding: StringEncoding::Utf16, .. }, - AstStatement::Literal { - kind: AstLiteral::String(StringValue { value, is_wide: is_wide @ false }), - .. - }, + AstStatement::Literal(AstLiteral::String(StringValue { value, is_wide: is_wide @ false })), ) => { // re-register the string-literal in the opposite encoding if *is_wide { diff --git a/src/resolver/const_evaluator.rs b/src/resolver/const_evaluator.rs index a7e546839f..805886c977 100644 --- a/src/resolver/const_evaluator.rs +++ b/src/resolver/const_evaluator.rs @@ -79,7 +79,11 @@ pub fn evaluate_constants(mut index: Index) -> (Index, Vec match (initial_value_literal, &candidates_type) { //we found an Int-Value and we found the const's datatype to be an unsigned Integer type (e.g. WORD) ( - Ok(Some(AstStatement::Literal { kind: AstLiteral::Integer(i), id, location })), + Ok(Some(AstNode { + stmt: AstStatement::Literal(AstLiteral::Integer(i)), + id, + location, + })), Some(DataTypeInformation::Integer { size, signed: false, .. }), ) => { // since we store literal-ints as i128 we need to truncate all of them down to their @@ -91,11 +95,7 @@ pub fn evaluate_constants(mut index: Index) -> (Index, Vec .get_mut_const_expressions() .mark_resolved( &candidate, - AstStatement::Literal { - id, - location, - kind: AstLiteral::new_integer(masked_value), - }, + AstNode::new_literal(AstLiteral::new_integer(masked_value), id, location), ) .expect("unknown id for const-expression"); //panic if we dont know the id failed_tries = 0; @@ -142,7 +142,7 @@ pub fn evaluate_constants(mut index: Index) -> (Index, Vec (index, unresolvable) } -fn do_resolve_candidate(index: &mut Index, candidate: ConstId, new_statement: AstStatement) { +fn do_resolve_candidate(index: &mut Index, candidate: ConstId, new_statement: AstNode) { index .get_mut_const_expressions() .mark_resolved(&candidate, new_statement) @@ -152,11 +152,11 @@ fn do_resolve_candidate(index: &mut Index, candidate: ConstId, new_statement: As /// returns true, if the given expression needs to be evaluated. /// literals must not be further evaluated and can be known at /// compile time -fn needs_evaluation(expr: &AstStatement) -> bool { - match expr { - AstStatement::Literal { kind, .. } => match &kind { - &AstLiteral::Array(Array { elements: Some(elements), .. }) => match elements.as_ref() { - AstStatement::ExpressionList { expressions, .. } => expressions.iter().any(needs_evaluation), +fn needs_evaluation(expr: &AstNode) -> bool { + match expr.get_stmt() { + AstStatement::Literal(kind) => match &kind { + &AstLiteral::Array(Array { elements: Some(elements), .. }) => match &elements.get_stmt() { + AstStatement::ExpressionList(expressions) => expressions.iter().any(needs_evaluation), _ => needs_evaluation(elements.as_ref()), }, @@ -165,9 +165,9 @@ fn needs_evaluation(expr: &AstStatement) -> bool { _ => false, }, - AstStatement::Assignment { right, .. } => needs_evaluation(right.as_ref()), - AstStatement::ExpressionList { expressions, .. } => expressions.iter().any(needs_evaluation), - AstStatement::RangeStatement { start, end, .. } => needs_evaluation(start) || needs_evaluation(end), + AstStatement::Assignment(data) => needs_evaluation(data.right.as_ref()), + AstStatement::ExpressionList(expressions) => expressions.iter().any(needs_evaluation), + AstStatement::RangeStatement(data) => needs_evaluation(&data.start) || needs_evaluation(&data.end), _ => true, } } @@ -178,20 +178,18 @@ fn get_default_initializer( target_type: &str, index: &Index, location: &SourceLocation, -) -> Result, UnresolvableKind> { +) -> Result, UnresolvableKind> { if let Some(init) = index.get_initial_value_for_type(target_type) { evaluate(init, None, index) //TODO do we ave a scope here? } else { let dt = index.get_type_information_or_void(target_type); let init = match dt { DataTypeInformation::Pointer { .. } => { - Some(AstStatement::Literal { kind: AstLiteral::Null, location: location.clone(), id }) + Some(AstFactory::create_literal(AstLiteral::Null, location.clone(), id)) + } + DataTypeInformation::Integer { .. } => { + Some(AstFactory::create_literal(AstLiteral::new_integer(0), location.clone(), id)) } - DataTypeInformation::Integer { .. } => Some(AstStatement::Literal { - kind: AstLiteral::new_integer(0), - location: location.clone(), - id, - }), DataTypeInformation::Enum { name, elements, .. } => elements .get(0) .and_then(|default_enum| index.find_enum_element(name, default_enum)) @@ -200,16 +198,16 @@ fn get_default_initializer( index.get_const_expressions().get_resolved_constant_statement(&initial_val) }) .cloned(), - DataTypeInformation::Float { .. } => Some(AstStatement::Literal { - kind: AstLiteral::new_real("0.0".to_string()), - location: location.clone(), + DataTypeInformation::Float { .. } => Some(AstFactory::create_literal( + AstLiteral::new_real("0.0".to_string()), + location.clone(), id, - }), - DataTypeInformation::String { encoding, .. } => Some(AstStatement::Literal { - kind: AstLiteral::new_string("".to_string(), encoding == &StringEncoding::Utf16), - location: location.clone(), + )), + DataTypeInformation::String { encoding, .. } => Some(AstFactory::create_literal( + AstLiteral::new_string("".to_string(), encoding == &StringEncoding::Utf16), + location.clone(), id, - }), + )), DataTypeInformation::SubRange { referenced_type, .. } | DataTypeInformation::Alias { referenced_type, .. } => { return get_default_initializer(id, referenced_type, index, location) @@ -222,27 +220,24 @@ fn get_default_initializer( /// transforms the given literal to better fit the datatype of the candidate /// effectively this casts an IntLiteral to a RealLiteral if necessary -fn cast_if_necessary( - statement: AstStatement, - target_type_name: &Option<&str>, - index: &Index, -) -> AstStatement { +fn cast_if_necessary(statement: AstNode, target_type_name: &Option<&str>, index: &Index) -> AstNode { let Some(dti) = target_type_name.and_then(|it| index.find_effective_type_info(it)) else { return statement; }; - if let AstStatement::Literal { kind: literal, location, id } = &statement { + if let AstStatement::Literal(literal) = statement.get_stmt() { + let (id, location) = (statement.get_id(), statement.get_location()); match literal { AstLiteral::Integer(value) if dti.is_float() => { - return AstStatement::new_real(value.to_string(), *id, location.to_owned()) + return AstNode::new_real(value.to_string(), id, location) } AstLiteral::String(StringValue { value, is_wide: true }) if dti.is_string_utf8() => { - return AstStatement::new_string(value, false, *id, location.to_owned()) + return AstNode::new_string(value, false, id, location) } AstLiteral::String(StringValue { value, is_wide: false }) if dti.is_string_utf16() => { - return AstStatement::new_string(value, true, *id, location.to_owned()) + return AstNode::new_string(value, true, id, location) } _ => (), @@ -253,9 +248,9 @@ fn cast_if_necessary( } /// Checks if a literal integer or float overflows based on its value, and if so returns true. -fn does_overflow(literal: &AstStatement, dti: Option<&DataTypeInformation>) -> bool { +fn does_overflow(literal: &AstNode, dti: Option<&DataTypeInformation>) -> bool { let Some(dti) = dti else { return false }; - let AstStatement::Literal { kind, .. } = literal else { return false }; + let AstStatement::Literal(kind) = literal.get_stmt() else { return false }; if !matches!(kind, AstLiteral::Integer(_) | AstLiteral::Real(_)) { return false; @@ -304,10 +299,10 @@ fn does_overflow(literal: &AstStatement, dti: Option<&DataTypeInformation>) -> b } pub fn evaluate( - initial: &AstStatement, + initial: &AstNode, scope: Option<&str>, index: &Index, -) -> Result, UnresolvableKind> { +) -> Result, UnresolvableKind> { evaluate_with_target_hint(initial, scope, index, None) } @@ -320,38 +315,41 @@ pub fn evaluate( /// - returns an Err if resolving caused an internal error (e.g. number parsing) /// - returns None if the initializer cannot be resolved (e.g. missing value) fn evaluate_with_target_hint( - initial: &AstStatement, + initial: &AstNode, scope: Option<&str>, index: &Index, target_type: Option<&str>, -) -> Result, UnresolvableKind> { +) -> Result, UnresolvableKind> { if !needs_evaluation(initial) { return Ok(Some(initial.clone())); // TODO hmm ... } - let literal = match initial { - AstStatement::Literal { kind, location, id } => match kind { + let (id, location) = (initial.get_id(), initial.get_location()); + let literal = match initial.get_stmt() { + AstStatement::Literal(kind) => match kind { AstLiteral::Array(Array { elements: Some(elements) }) => { let tt = target_type .and_then(|it| index.find_effective_type_info(it)) .and_then(|it| it.get_inner_array_type_name()) .or(target_type); - let inner_elements = AstStatement::get_as_list(elements) + let inner_elements = AstNode::get_as_list(elements) .iter() .map(|e| evaluate_with_target_hint(e, scope, index, tt)) - .collect::>, UnresolvableKind>>()? + .collect::>, UnresolvableKind>>()? .into_iter() - .collect::>>(); - - //return a new array, or return none if one was not resolvable - inner_elements.map(|ie| AstStatement::Literal { - id: *id, - kind: AstLiteral::new_array(Some(Box::new(AstStatement::ExpressionList { - expressions: ie, - id: *id, - }))), - location: location.clone(), + .collect::>>(); + + inner_elements.map(|ie| { + AstFactory::create_literal( + AstLiteral::new_array(Some(Box::new(AstFactory::create_expression_list( + ie, + location.clone(), + id, + )))), + location.clone(), + id, + ) }) } @@ -370,23 +368,24 @@ fn evaluate_with_target_hint( _ => return Ok(Some(initial.clone())), }, - AstStatement::DefaultValue { location, .. } => { + AstStatement::DefaultValue(_) => { return get_default_initializer( initial.get_id(), target_type.unwrap_or(VOID_TYPE), index, - location, + &location, ) } - AstStatement::ReferenceExpr { - access: ReferenceAccess::Cast(target), base: Some(type_name), .. - } => { + AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Cast(target), + base: Some(type_name), + }) => { let dti = type_name .get_flat_reference_name() .and_then(|type_name| index.find_effective_type_info(type_name)); match dti { Some(DataTypeInformation::Enum { name: enum_name, .. }) => { - if let AstStatement::Identifier { name: ref_name, .. } = target.as_ref() { + if let AstStatement::Identifier(ref_name) = target.get_stmt() { return index .find_enum_element(enum_name, ref_name) .ok_or_else(|| { @@ -406,7 +405,7 @@ fn evaluate_with_target_hint( None => return Err(UnresolvableKind::Misc("Cannot resolve unknown Type-Cast.".to_string())), } } - AstStatement::ReferenceExpr { access: ReferenceAccess::Member(reference), base, .. } => { + AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(reference), base }) => { if let Some(name) = reference.get_flat_reference_name() { index .find_variable( @@ -420,33 +419,33 @@ fn evaluate_with_target_hint( None } } - AstStatement::BinaryExpression { left, right, operator, id, .. } => { + AstStatement::BinaryExpression(BinaryExpression { left, right, operator }) => { let eval_left = evaluate(left, scope, index)?; let eval_right = evaluate(right, scope, index)?; if let Some((left, right)) = eval_left.zip(eval_right).as_ref() { let evalualted = match operator { - Operator::Plus => arithmetic_expression!(left, +, right, "+", *id)?, - Operator::Minus => arithmetic_expression!(left, -, right, "-", *id)?, - Operator::Multiplication => arithmetic_expression!(left, *, right, "*", *id)?, + Operator::Plus => arithmetic_expression!(left, +, right, "+", id)?, + Operator::Minus => arithmetic_expression!(left, -, right, "-", id)?, + Operator::Multiplication => arithmetic_expression!(left, *, right, "*", id)?, Operator::Division if right.is_zero() => { return Err(UnresolvableKind::Misc("Attempt to divide by zero".to_string())) } - Operator::Division => arithmetic_expression!(left, /, right, "/", *id)?, + Operator::Division => arithmetic_expression!(left, /, right, "/", id)?, Operator::Modulo if right.is_zero() => { return Err(UnresolvableKind::Misc( "Attempt to calculate the remainder with a divisor of zero".to_string(), )) } - Operator::Modulo => arithmetic_expression!(left, %, right, "MOD", *id)?, - Operator::Equal => compare_expression!(left, ==, right, "=", *id)?, - Operator::NotEqual => compare_expression!(left, !=, right, "<>", *id)?, - Operator::Greater => compare_expression!(left, >, right, ">", *id)?, - Operator::GreaterOrEqual => compare_expression!(left, >=, right, ">=", *id)?, - Operator::Less => compare_expression!(left, <, right, "<", *id)?, - Operator::LessOrEqual => compare_expression!(left, <=, right, "<=", *id)?, - Operator::And => bitwise_expression!(left, & , right, "AND", *id)?, - Operator::Or => bitwise_expression!(left, | , right, "OR", *id)?, - Operator::Xor => bitwise_expression!(left, ^, right, "XOR", *id)?, + Operator::Modulo => arithmetic_expression!(left, %, right, "MOD", id)?, + Operator::Equal => compare_expression!(left, ==, right, "=", id)?, + Operator::NotEqual => compare_expression!(left, !=, right, "<>", id)?, + Operator::Greater => compare_expression!(left, >, right, ">", id)?, + Operator::GreaterOrEqual => compare_expression!(left, >=, right, ">=", id)?, + Operator::Less => compare_expression!(left, <, right, "<", id)?, + Operator::LessOrEqual => compare_expression!(left, <=, right, "<=", id)?, + Operator::And => bitwise_expression!(left, & , right, "AND", id)?, + Operator::Or => bitwise_expression!(left, | , right, "OR", id)?, + Operator::Xor => bitwise_expression!(left, ^, right, "XOR", id)?, _ => { return Err(UnresolvableKind::Misc(format!( "Cannot resolve operator {operator:?} in constant evaluation" @@ -462,15 +461,15 @@ fn evaluate_with_target_hint( } // NOT x - AstStatement::UnaryExpression { operator: Operator::Not, value, .. } => { + AstStatement::UnaryExpression(UnaryExpression { operator: Operator::Not, value }) => { let eval = evaluate(value, scope, index)?; - match eval.clone() { - Some(AstStatement::Literal { kind: AstLiteral::Bool(v), id, location }) => { - Some(AstStatement::Literal { kind: AstLiteral::Bool(!v), id, location }) + match eval.as_ref() { + Some(AstNode { stmt: AstStatement::Literal(AstLiteral::Bool(v)), id, location }) => { + Some(AstFactory::create_literal(AstLiteral::Bool(!v), location.clone(), *id)) } - Some(AstStatement::Literal { kind: AstLiteral::Integer(v), id, location }) => { + Some(AstNode { stmt: AstStatement::Literal(AstLiteral::Integer(v)), id, location }) => { evaluate_with_target_hint(eval.as_ref().unwrap(), scope, index, target_type)?; - Some(AstStatement::Literal { kind: AstLiteral::Integer(!v), id, location }) + Some(AstFactory::create_literal(AstLiteral::Integer(!v), location.clone(), *id)) } None => { None //not yet resolvable @@ -479,21 +478,21 @@ fn evaluate_with_target_hint( } } // - x - AstStatement::UnaryExpression { operator: Operator::Minus, value, .. } => { + AstStatement::UnaryExpression(UnaryExpression { operator: Operator::Minus, value }) => { match evaluate(value, scope, index)? { - Some(AstStatement::Literal { kind: AstLiteral::Integer(v), id, location }) => { - Some(AstStatement::Literal { kind: AstLiteral::Integer(-v), id, location }) + Some(AstNode { stmt: AstStatement::Literal(AstLiteral::Integer(v)), id, location }) => { + Some(AstNode::new(AstStatement::Literal(AstLiteral::Integer(-v)), id, location)) } - Some(AstStatement::Literal { kind: AstLiteral::Real(v), id, location }) => { - let lit = AstStatement::Literal { - kind: AstLiteral::new_real(format!( + Some(AstNode { stmt: AstStatement::Literal(AstLiteral::Real(v)), id, location }) => { + let lit = AstNode::new( + AstStatement::Literal(AstLiteral::new_real(format!( "{:}", -(v.parse::()) .map_err(|err| UnresolvableKind::Misc(format!("{err:}: {v:}")))? - )), + ))), id, location, - }; + ); evaluate_with_target_hint(&lit, scope, index, target_type)? } None => { @@ -502,56 +501,51 @@ fn evaluate_with_target_hint( _ => return Err(UnresolvableKind::Misc(format!("Cannot resolve constant Minus {value:?}"))), } } - AstStatement::ExpressionList { expressions, id } => { + AstStatement::ExpressionList(expressions) => { let inner_elements = expressions .iter() .map(|e| evaluate(e, scope, index)) - .collect::>, UnresolvableKind>>()? + .collect::>, UnresolvableKind>>()? .into_iter() - .collect::>>(); + .collect::>>(); //return a new array, or return none if one was not resolvable - inner_elements.map(|ie| AstStatement::ExpressionList { expressions: ie, id: *id }) + inner_elements.map(|ie| AstNode::new(AstStatement::ExpressionList(ie), id, location)) } - AstStatement::MultipliedStatement { element, id, multiplier, location } => { - let inner_elements = AstStatement::get_as_list(element.as_ref()) + AstStatement::MultipliedStatement(MultipliedStatement { element, multiplier }) => { + let inner_elements = AstNode::get_as_list(element.as_ref()) .iter() .map(|e| evaluate(e, scope, index)) - .collect::>, UnresolvableKind>>()? + .collect::>, UnresolvableKind>>()? .into_iter() - .collect::>>(); + .collect::>>(); //return a new array, or return none if one was not resolvable inner_elements.map(|ie| { if let [ie] = ie.as_slice() { - AstStatement::MultipliedStatement { - id: *id, - element: Box::new(ie.clone()), //TODO - multiplier: *multiplier, - location: location.clone(), - } + AstFactory::create_multiplied_statement(*multiplier, ie.clone(), location.clone(), id) } else { - AstStatement::MultipliedStatement { - id: *id, - element: Box::new(AstStatement::ExpressionList { expressions: ie, id: *id }), - multiplier: *multiplier, - location: location.clone(), - } + AstFactory::create_multiplied_statement( + *multiplier, + AstFactory::create_expression_list(ie, location.clone(), id), + location.clone(), + id, + ) } }) } - AstStatement::Assignment { left, right, id } => { + AstStatement::Assignment(data) => { //Right needs evaluation - if let Some(right) = evaluate(right, scope, index)? { - Some(AstStatement::Assignment { left: left.clone(), right: Box::new(right), id: *id }) + if let Some(right) = evaluate(&data.right, scope, index)? { + Some(AstFactory::create_assignment(*data.left.clone(), right, id)) } else { Some(initial.clone()) } } - AstStatement::RangeStatement { start, end, id } => { - let start = Box::new(evaluate(start, scope, index)?.unwrap_or_else(|| *start.to_owned())); - let end = Box::new(evaluate(end, scope, index)?.unwrap_or_else(|| *end.to_owned())); - Some(AstStatement::RangeStatement { start, end, id: *id }) + AstStatement::RangeStatement(data) => { + let start = evaluate(&data.start, scope, index)?.unwrap_or_else(|| *data.start.to_owned()); + let end = evaluate(&data.end, scope, index)?.unwrap_or_else(|| *data.end.to_owned()); + Some(AstFactory::create_range_statement(start, end, id)) } _ => return Err(UnresolvableKind::Misc(format!("Cannot resolve constant: {initial:#?}"))), }; @@ -565,7 +559,7 @@ fn resolve_const_reference( variable: &crate::index::VariableIndexEntry, name: &str, index: &Index, -) -> Result, UnresolvableKind> { +) -> Result, UnresolvableKind> { if !variable.is_constant() { return Err(UnresolvableKind::Misc(format!("'{name}' is no const reference"))); } @@ -583,18 +577,18 @@ fn resolve_const_reference( /// [`AstLiteral::Integer`] with value `65_535` whereas `INT#FFFF` will not evaluate because it overflows /// (see also [`does_overflow`] and [`evaluate_with_target_hint`]). fn get_cast_statement_literal( - cast_statement: &AstStatement, + cast_statement: &AstNode, type_name: &str, scope: Option<&str>, index: &Index, -) -> Result { +) -> Result { let dti = index.find_effective_type_info(type_name); match dti { Some(&DataTypeInformation::Integer { .. }) => { let evaluated_initial = evaluate_with_target_hint(cast_statement, scope, index, Some(type_name))? .as_ref() .map(|v| { - if let AstStatement::Literal { kind: AstLiteral::Integer(value), .. } = v { + if let AstStatement::Literal(AstLiteral::Integer(value)) = v.get_stmt() { Ok(*value) } else { Err(UnresolvableKind::Misc(format!("Expected integer value, found {v:?}"))) @@ -603,11 +597,11 @@ fn get_cast_statement_literal( .transpose()?; if let Some(value) = evaluated_initial { - return Ok(AstStatement::Literal { - kind: AstLiteral::new_integer(value), - id: cast_statement.get_id(), - location: cast_statement.get_location(), - }); + return Ok(AstNode::new( + AstStatement::Literal(AstLiteral::new_integer(value)), + cast_statement.get_id(), + cast_statement.get_location(), + )); } Err(UnresolvableKind::Misc(format!("Cannot resolve constant: {type_name}#{cast_statement:?}"))) @@ -615,11 +609,9 @@ fn get_cast_statement_literal( Some(DataTypeInformation::Float { .. }) => { let evaluated = evaluate(cast_statement, scope, index)?; - let value = match evaluated { - Some(AstStatement::Literal { kind: AstLiteral::Integer(value), .. }) => Some(value as f64), - Some(AstStatement::Literal { kind: AstLiteral::Real(value), .. }) => { - value.parse::().ok() - } + let value = match evaluated.as_ref().map(|it| it.get_stmt()) { + Some(AstStatement::Literal(AstLiteral::Integer(value))) => Some(*value as f64), + Some(AstStatement::Literal(AstLiteral::Real(value))) => value.parse::().ok(), _ => { return Err(UnresolvableKind::Misc(format!( "Expected floating point type, got: {evaluated:?}" @@ -633,11 +625,11 @@ fn get_cast_statement_literal( ))); }; - Ok(AstStatement::Literal { - kind: AstLiteral::new_real(value.to_string()), - id: cast_statement.get_id(), - location: cast_statement.get_location(), - }) + Ok(AstNode::new( + AstStatement::Literal(AstLiteral::new_real(value.to_string())), + cast_statement.get_id(), + cast_statement.get_location(), + )) } _ => Err(UnresolvableKind::Misc(format!("Cannot resolve constant: {type_name}#{cast_statement:?}"))), @@ -653,92 +645,94 @@ use cannot_eval_error; macro_rules! arithmetic_expression { ($left:expr, $op:tt, $right:expr, $op_text:expr, $resulting_id:expr) => { - match ($left, $right) { - ( AstStatement::Literal{kind: AstLiteral::Integer(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Integer(rvalue), location: loc_right, ..}) => { - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_integer(lvalue $op rvalue), location: loc_left.span(loc_right) - }) + { + let loc_left = $left.get_location(); + let loc_right = $right.get_location(); + match ($left.get_stmt(), $right.get_stmt()) { + ( AstStatement::Literal(AstLiteral::Integer(lvalue)), + AstStatement::Literal(AstLiteral::Integer(rvalue))) => { + Ok(AstStatement::Literal(AstLiteral::new_integer(lvalue $op rvalue))) }, - ( AstStatement::Literal{kind: AstLiteral::Integer(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Real(rvalue), location: loc_right, ..}) => { + + ( AstStatement::Literal(AstLiteral::Integer(lvalue)), + AstStatement::Literal(AstLiteral::Real(rvalue))) => { let rvalue = rvalue.parse::() .map_err(|err| UnresolvableKind::Misc(err.to_string()))?; - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_real((*lvalue as f64 $op rvalue).to_string()), location: loc_left.span(loc_right) - }) + Ok(AstStatement::Literal( + AstLiteral::new_real((*lvalue as f64 $op rvalue).to_string()))) }, - ( AstStatement::Literal{kind: AstLiteral::Real(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Integer(rvalue), location: loc_right, ..}) => { + + ( AstStatement::Literal(AstLiteral::Real(lvalue)), + AstStatement::Literal(AstLiteral::Integer(rvalue))) => { let lvalue = lvalue.parse::() .map_err(|err| UnresolvableKind::Misc(err.to_string()))?; - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_real((lvalue $op *rvalue as f64).to_string()), location: loc_left.span(loc_right) - }) + Ok(AstStatement::Literal(AstLiteral::new_real((lvalue $op *rvalue as f64).to_string()))) }, - ( AstStatement::Literal{kind: AstLiteral::Real(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Real(rvalue), location: loc_right, ..}) => { + + ( AstStatement::Literal(AstLiteral::Real(lvalue)), + AstStatement::Literal(AstLiteral::Real(rvalue))) => { let lvalue = lvalue.parse::() .map_err(|err| UnresolvableKind::Misc(err.to_string()))?; let rvalue = rvalue.parse::() .map_err(|err| UnresolvableKind::Misc(err.to_string()))?; - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_real((lvalue $op rvalue).to_string()), location: loc_left.span(loc_right) - }) + Ok(AstStatement::Literal( + AstLiteral::new_real((lvalue $op rvalue).to_string()), + )) }, _ => cannot_eval_error!($left, $op_text, $right), + }.map(|it| AstNode::new(it, $resulting_id, loc_left.span(&loc_right))) } } } use arithmetic_expression; macro_rules! bitwise_expression { - ($left:expr, $op:tt, $right:expr, $op_text:expr, $resulting_id:expr) => { - match ($left, $right) { - ( AstStatement::Literal{kind: AstLiteral::Integer(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Integer(rvalue), location: loc_right, ..}) => { - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_integer(lvalue $op rvalue), location: loc_left.span(loc_right) - }) + ($left:expr, $op:tt, $right:expr, $op_text:expr, $resulting_id:expr) => {{ + let loc_left = $left.get_location(); + let loc_right = $right.get_location(); + match ($left.get_stmt(), $right.get_stmt()) { + ( AstStatement::Literal(AstLiteral::Integer(lvalue)), + AstStatement::Literal(AstLiteral::Integer(rvalue))) => { + Ok(AstStatement::Literal(AstLiteral::new_integer(lvalue $op rvalue))) }, - ( AstStatement::Literal{kind: AstLiteral::Bool(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Bool(rvalue), location: loc_right, ..}) => { - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_bool(lvalue $op rvalue), location: loc_left.span(loc_right) - }) + ( AstStatement::Literal(AstLiteral::Bool(lvalue)), + AstStatement::Literal(AstLiteral::Bool(rvalue))) => { + Ok(AstStatement::Literal(AstLiteral::new_bool(lvalue $op rvalue))) }, _ => cannot_eval_error!($left, $op_text, $right), - } - }; -} + }.map(|it| AstNode::new(it, $resulting_id, loc_left.span(&loc_right))) + } +}} use bitwise_expression; macro_rules! compare_expression { - ($left:expr, $op:tt, $right:expr, $op_text:expr, $resulting_id:expr) => { - match ($left, $right) { - ( AstStatement::Literal{kind: AstLiteral::Integer(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Integer(rvalue), location: loc_right, ..}) => { - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_bool(lvalue $op rvalue), location: loc_left.span(loc_right) - }) + ($left:expr, $op:tt, $right:expr, $op_text:expr, $resulting_id:expr) => {{ + let loc_left = $left.get_location(); + let loc_right = $right.get_location(); + match ($left.get_stmt(), $right.get_stmt()) { + ( AstStatement::Literal(AstLiteral::Integer(lvalue)), + AstStatement::Literal(AstLiteral::Integer(rvalue))) => { + Ok(AstStatement::Literal( + AstLiteral::new_bool(lvalue $op rvalue))) }, - ( AstStatement::Literal{kind: AstLiteral::Real{..}, ..}, - AstStatement::Literal{kind: AstLiteral::Real{..}, ..}) => { - Err(UnresolvableKind::Misc("Cannot compare Reals without epsilon".into())) + ( AstStatement::Literal(AstLiteral::Real(..)), + AstStatement::Literal(AstLiteral::Real(..))) => { + Err(UnresolvableKind::Misc("Cannot compare Reals without epsilon".into())) }, - ( AstStatement::Literal{kind: AstLiteral::Bool(lvalue), location: loc_left, ..}, - AstStatement::Literal{kind: AstLiteral::Bool(rvalue), location: loc_right, ..}) => { - Ok(AstStatement::Literal{ - id: $resulting_id, kind: AstLiteral::new_bool(lvalue $op rvalue), location: loc_left.span(loc_right) - }) + ( AstStatement::Literal(AstLiteral::Bool(lvalue)), + AstStatement::Literal(AstLiteral::Bool(rvalue))) => { + Ok(AstStatement::Literal(AstLiteral::new_bool(lvalue $op rvalue))) }, _ => cannot_eval_error!($left, $op_text, $right), - } - } + }.map(|it| AstNode::new(it, $resulting_id, loc_left.span(&loc_right))) + }} } use compare_expression; use plc_ast::{ - ast::{AstId, AstStatement, Operator, ReferenceAccess}, + ast::{ + AstFactory, AstId, AstNode, AstStatement, BinaryExpression, MultipliedStatement, Operator, + ReferenceAccess, ReferenceExpr, UnaryExpression, + }, literals::{Array, AstLiteral, StringValue}, }; diff --git a/src/resolver/generics.rs b/src/resolver/generics.rs index 4bdb5c2235..81e0d7103d 100644 --- a/src/resolver/generics.rs +++ b/src/resolver/generics.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use plc_ast::ast::{flatten_expression_list, AstStatement, GenericBinding, LinkageType, TypeNature}; +use plc_ast::ast::{flatten_expression_list, AstNode, AstStatement, GenericBinding, LinkageType, TypeNature}; use plc_source::source_location::SourceLocation; use crate::{ @@ -32,7 +32,7 @@ impl<'i> TypeAnnotator<'i> { index: &'idx Index, annotation_map: &'idx AnnotationMapImpl, type_name: &str, - statement: &AstStatement, + statement: &AstNode, ) -> Option<(&'idx str, &'idx str)> { //find inner type if this was turned into an array or pointer (if this is `POINTER TO T` lets find out what T is) let effective_type = index.find_effective_type_info(type_name); @@ -46,9 +46,9 @@ impl<'i> TypeAnnotator<'i> { //If generic add a generic annotation if let Some(DataTypeInformation::Generic { generic_symbol, .. }) = candidate { - let statement = match statement { + let statement = match statement.get_stmt() { //The right side of the assignment is the source of truth - AstStatement::Assignment { right, .. } => right, + AstStatement::Assignment(data) => &data.right, _ => statement, }; //Find the statement's type @@ -64,8 +64,8 @@ impl<'i> TypeAnnotator<'i> { &mut self, generics_candidates: HashMap>, implementation_name: &str, - operator: &AstStatement, - parameters: Option<&AstStatement>, + operator: &AstNode, + parameters: Option<&AstNode>, ctx: VisitorContext, ) { if let Some(PouIndexEntry::Function { generics, .. }) = self.index.find_pou(implementation_name) { @@ -231,7 +231,7 @@ impl<'i> TypeAnnotator<'i> { fn update_generic_function_parameters( &mut self, - s: &AstStatement, + s: &AstNode, function_name: &str, generic_map: &HashMap, ) { @@ -279,10 +279,9 @@ impl<'i> TypeAnnotator<'i> { self.annotation_map.add_generic_nature(passed_parameter, generic.generic_nature); // for assignments we need to annotate the left side aswell - match parameter_stmt { - AstStatement::Assignment { left, .. } - | AstStatement::OutputAssignment { left, .. } => { - self.annotate(left, StatementAnnotation::value(datatype.get_name())); + match parameter_stmt.get_stmt() { + AstStatement::Assignment(data) | AstStatement::OutputAssignment(data) => { + self.annotate(&data.left, StatementAnnotation::value(datatype.get_name())); } _ => {} } diff --git a/src/resolver/tests/const_resolver_tests.rs b/src/resolver/tests/const_resolver_tests.rs index 578438a0ef..0eb2d572e9 100644 --- a/src/resolver/tests/const_resolver_tests.rs +++ b/src/resolver/tests/const_resolver_tests.rs @@ -1,4 +1,4 @@ -use plc_ast::ast::AstStatement; +use plc_ast::ast::{AstFactory, AstNode, AstStatement}; use plc_ast::literals::{Array, AstLiteral}; use plc_ast::provider::IdProvider; use plc_source::source_location::SourceLocation; @@ -26,40 +26,32 @@ macro_rules! global { }; } -fn find_member_value<'a>(index: &'a Index, pou: &str, reference: &str) -> Option<&'a AstStatement> { +fn find_member_value<'a>(index: &'a Index, pou: &str, reference: &str) -> Option<&'a AstNode> { index .find_member(pou, reference) .and_then(|it| index.get_const_expressions().maybe_get_constant_statement(&it.initial_value)) } -fn find_constant_value<'a>(index: &'a Index, reference: &str) -> Option<&'a AstStatement> { +fn find_constant_value<'a>(index: &'a Index, reference: &str) -> Option<&'a AstNode> { index .find_global_variable(reference) .and_then(|it| index.get_const_expressions().maybe_get_constant_statement(&it.initial_value)) } -fn create_int_literal(v: i128) -> AstStatement { - AstStatement::Literal { kind: AstLiteral::new_integer(v), id: 0, location: SourceLocation::undefined() } +fn create_int_literal(v: i128) -> AstNode { + AstFactory::create_literal(AstLiteral::new_integer(v), SourceLocation::undefined(), 0) } -fn create_string_literal(v: &str, wide: bool) -> AstStatement { - AstStatement::Literal { - kind: AstLiteral::new_string(v.to_string(), wide), - id: 0, - location: SourceLocation::undefined(), - } +fn create_string_literal(v: &str, wide: bool) -> AstNode { + AstFactory::create_literal(AstLiteral::new_string(v.to_string(), wide), SourceLocation::undefined(), 0) } -fn create_real_literal(v: f64) -> AstStatement { - AstStatement::Literal { - kind: AstLiteral::new_real(format!("{v:}")), - id: 0, - location: SourceLocation::undefined(), - } +fn create_real_literal(v: f64) -> AstNode { + AstFactory::create_literal(AstLiteral::new_real(format!("{v:}")), SourceLocation::undefined(), 0) } -fn create_bool_literal(v: bool) -> AstStatement { - AstStatement::Literal { kind: AstLiteral::new_bool(v), id: 0, location: SourceLocation::undefined() } +fn create_bool_literal(v: bool) -> AstNode { + AstFactory::create_literal(AstLiteral::new_bool(v), SourceLocation::undefined(), 0) } #[test] @@ -887,22 +879,8 @@ fn const_string_initializers_should_be_converted() { // AND the globals should have gotten their values - debug_assert_eq!( - find_constant_value(&index, "aa"), - Some(AstStatement::Literal { - kind: AstLiteral::new_string("World".into(), false), - id: 0, - location: SourceLocation::undefined() - }) - ); - debug_assert_eq!( - find_constant_value(&index, "bb"), - Some(AstStatement::Literal { - kind: AstLiteral::new_string("Hello".into(), true), - id: 0, - location: SourceLocation::undefined() - }) - ); + debug_assert_eq!(find_constant_value(&index, "aa"), Some(create_string_literal("World", false))); + debug_assert_eq!(find_constant_value(&index, "bb"), Some(create_string_literal("Hello", true))); } #[test] @@ -931,14 +909,7 @@ fn const_lreal_initializers_should_be_resolved_correctly() { debug_assert_eq!(EMPTY, unresolvable); // AND the globals should have gotten their values - debug_assert_eq!( - find_constant_value(&index, "tau"), - Some(AstStatement::Literal { - kind: AstLiteral::new_real("6.283".into()), - id: 0, - location: SourceLocation::undefined() - }) - ); + debug_assert_eq!(find_constant_value(&index, "tau"), Some(create_real_literal("6.283".parse().unwrap()))); //AND the type is correctly associated let i = index.find_global_variable("tau").unwrap().initial_value.unwrap(); @@ -1007,10 +978,10 @@ fn array_literals_type_resolving() { ); // AND the array-literals types are associated correctly - if let AstStatement::Literal { kind: AstLiteral::Array(Array { elements: Some(elements) }), .. } = - parse_result.global_vars[0].variables[0].initializer.as_ref().unwrap() + if let AstStatement::Literal(AstLiteral::Array(Array { elements: Some(elements) })) = + parse_result.global_vars[0].variables[0].initializer.as_ref().unwrap().get_stmt() { - if let AstStatement::ExpressionList { expressions, .. } = elements.as_ref() { + if let AstStatement::ExpressionList(expressions) = elements.as_ref().get_stmt() { for ele in expressions.iter() { assert_eq!(annotations.get_type_hint(ele, &index), index.find_effective_type_by_name("BYTE")); } @@ -1063,7 +1034,7 @@ fn nested_array_literals_type_resolving() { ); //check the initializer's array-element's types - if let AstStatement::Literal { kind: AstLiteral::Array(Array { elements: Some(e) }), .. } = initializer { + if let AstStatement::Literal(AstLiteral::Array(Array { elements: Some(e) })) = initializer.get_stmt() { if let Some(DataTypeInformation::Array { inner_type_name, .. }) = index.find_effective_type_by_name(a.get_type_name()).map(|t| t.get_type_information()) { @@ -1074,7 +1045,7 @@ fn nested_array_literals_type_resolving() { );*/ // check if the array's elements have the array's inner type - for ele in AstStatement::get_as_list(e) { + for ele in AstNode::get_as_list(e) { let element_hint = annotations.get_type_hint(ele, &index).unwrap(); assert_eq!(Some(element_hint), index.find_effective_type_by_name(inner_type_name)) } @@ -1124,10 +1095,8 @@ fn nested_array_literals_multiplied_statement_type_resolving() { //check the initializer's array-element's types // [[2(2)],[2(3)]] - if let AstStatement::Literal { - kind: AstLiteral::Array(Array { elements: Some(outer_expression_list) }), - .. - } = initializer + if let AstStatement::Literal(AstLiteral::Array(Array { elements: Some(outer_expression_list) })) = + initializer.get_stmt() { // outer_expression_list = [2(2)],[2(3)] if let Some(DataTypeInformation::Array { inner_type_name: array_of_byte, .. }) = @@ -1140,26 +1109,24 @@ fn nested_array_literals_multiplied_statement_type_resolving() { ); // check if the array's elements have the array's inner type - for inner_array in AstStatement::get_as_list(outer_expression_list) { + for inner_array in AstNode::get_as_list(outer_expression_list) { // [2(2)] let element_hint = annotations.get_type_hint(inner_array, &index).unwrap(); assert_eq!(Some(element_hint), index.find_effective_type_by_name(array_of_byte)); //check if the inner array statement's also got the type-annotations - if let AstStatement::Literal { - kind: AstLiteral::Array(Array { elements: Some(inner_multiplied_stmt) }), - .. - } = inner_array + + if let AstStatement::Literal(AstLiteral::Array(Array { + elements: Some(inner_multiplied_stmt), + })) = inner_array.get_stmt() { // inner_multiplied_stmt = 2(2) - for inner_multiplied_stmt in AstStatement::get_as_list(inner_multiplied_stmt) { - if let AstStatement::MultipliedStatement { element: multiplied_element, .. } = - inner_multiplied_stmt - { + for inner_multiplied_stmt in AstNode::get_as_list(inner_multiplied_stmt) { + if let AstStatement::MultipliedStatement(data) = inner_multiplied_stmt.get_stmt() { //check if the inner thing really got the BYTE hint // multiplied-element = 2 assert_eq!( - annotations.get_type_hint(multiplied_element.as_ref(), &index), + annotations.get_type_hint(data.element.as_ref(), &index), index.find_effective_type_by_name("BYTE") ); } else { diff --git a/src/resolver/tests/resolve_control_statments.rs b/src/resolver/tests/resolve_control_statments.rs index 4f383ac67b..c911b37417 100644 --- a/src/resolver/tests/resolve_control_statments.rs +++ b/src/resolver/tests/resolve_control_statments.rs @@ -23,11 +23,13 @@ fn binary_expressions_resolves_types() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); let statements = &unit.implementations[0].statements; - if let AstStatement::ControlStatement { - kind: - AstControlStatement::ForLoop(ForLoopStatement { counter, start, end, by_step: Some(by_step), .. }), + if let AstStatement::ControlStatement(AstControlStatement::ForLoop(ForLoopStatement { + counter, + start, + end, + by_step: Some(by_step), .. - } = &statements[0] + })) = statements[0].get_stmt() { assert_type_and_hint!(&annotations, &index, counter, "INT", None); assert_type_and_hint!(&annotations, &index, start, "DINT", Some("INT")); diff --git a/src/resolver/tests/resolve_expressions_tests.rs b/src/resolver/tests/resolve_expressions_tests.rs index 1af40898f6..0950850c6e 100644 --- a/src/resolver/tests/resolve_expressions_tests.rs +++ b/src/resolver/tests/resolve_expressions_tests.rs @@ -2,7 +2,11 @@ use core::panic; use insta::{assert_debug_snapshot, assert_snapshot}; use plc_ast::{ - ast::{flatten_expression_list, AstStatement, DataType, Pou, ReferenceAccess, UserTypeDeclaration}, + ast::{ + flatten_expression_list, Assignment, AstNode, AstStatement, BinaryExpression, CallStatement, + DataType, DirectAccess, MultipliedStatement, Pou, RangeStatement, ReferenceAccess, ReferenceExpr, + UnaryExpression, UserTypeDeclaration, + }, control_statements::{AstControlStatement, CaseStatement}, literals::{Array, AstLiteral}, provider::IdProvider, @@ -72,9 +76,7 @@ fn cast_expressions_resolves_types() { let statements = &unit.implementations[0].statements; assert_type_and_hint!(&annotations, &index, &statements[0], BYTE_TYPE, None); assert_type_and_hint!(&annotations, &index, &statements[1], INT_TYPE, None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(target), .. } = &statements[1] else { - unreachable!() - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{access: ReferenceAccess::Cast(target), ..}), ..} = &statements[1] else {unreachable!()}; assert_type_and_hint!(&annotations, &index, target.as_ref(), SINT_TYPE, None); assert_type_and_hint!(&annotations, &index, &statements[2], UINT_TYPE, None); @@ -95,38 +97,20 @@ fn cast_expression_literals_get_casted_types() { let statements = &unit.implementations[0].statements; { assert_type_and_hint!(&annotations, &index, &statements[0], INT_TYPE, None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(target), .. } = &statements[0] else { - unreachable!() - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{access: ReferenceAccess::Cast(target), ..}), ..} = &statements[0] else {unreachable!()}; let t = target.as_ref(); assert_eq!( - format!( - "{:#?}", - AstStatement::Literal { - kind: AstLiteral::Integer(0xFFFF), - location: SourceLocation::undefined(), - id: 0 - } - ), + format!("{:#?}", AstNode::new_integer(0xFFFF, 0, SourceLocation::undefined())), format!("{t:#?}") ); assert_type_and_hint!(&annotations, &index, target.as_ref(), INT_TYPE, None); } { assert_type_and_hint!(&annotations, &index, &statements[1], WORD_TYPE, None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(target), .. } = &statements[1] else { - unreachable!() - }; + let AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Cast(target), ..} )= &statements[1].get_stmt() else {unreachable!()}; let t = target.as_ref(); assert_eq!( - format!( - "{:#?}", - AstStatement::Literal { - kind: AstLiteral::Integer(0xFFFF), - location: SourceLocation::undefined(), - id: 0 - } - ), + format!("{:#?}", AstNode::new_integer(0xFFFF, 0, SourceLocation::undefined())), format!("{t:#?}") ); assert_type_and_hint!(&annotations, &index, target.as_ref(), WORD_TYPE, None); @@ -150,9 +134,7 @@ fn cast_expressions_of_enum_with_resolves_types() { assert_type_and_hint!(&annotations, &index, &statements[0], "MyEnum", None); assert_type_and_hint!(&annotations, &index, &statements[1], "MyEnum", None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(access), .. } = &statements[0] else { - unreachable!() - }; + let AstStatement::ReferenceExpr (ReferenceExpr { access: ReferenceAccess::Cast(access), ..}) = &statements[0].get_stmt() else { unreachable!()}; assert_eq!( annotations.get(access), Some(&StatementAnnotation::Variable { @@ -164,9 +146,7 @@ fn cast_expressions_of_enum_with_resolves_types() { }) ); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(access), .. } = &statements[1] else { - unreachable!() - }; + let AstStatement::ReferenceExpr (ReferenceExpr{ access: ReferenceAccess::Cast(access), ..}) = &statements[1].get_stmt() else { unreachable!()}; assert_eq!( annotations.get(access), Some(&StatementAnnotation::Variable { @@ -235,7 +215,9 @@ fn binary_expressions_resolves_types_for_mixed_signed_ints() { ); let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statements = &unit.implementations[0].statements; - if let AstStatement::BinaryExpression { left, right, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::BinaryExpression(BinaryExpression { left, right, .. }), .. } = + &statements[0] + { assert_type_and_hint!(&annotations, &index, left, INT_TYPE, Some(DINT_TYPE)); assert_type_and_hint!(&annotations, &index, right, UINT_TYPE, Some(DINT_TYPE)); assert_type_and_hint!(&annotations, &index, &statements[0], DINT_TYPE, None); @@ -247,8 +229,8 @@ fn binary_expressions_resolves_types_for_mixed_signed_ints() { #[test] #[ignore = "Types on builtin types are not correctly annotated"] fn expt_binary_expression() { - fn get_params(stmt: &AstStatement) -> (&AstStatement, &AstStatement) { - if let AstStatement::CallStatement { parameters, .. } = stmt { + fn get_params(stmt: &AstNode) -> (&AstNode, &AstNode) { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = stmt { if let &[left, right] = flatten_expression_list(parameters.as_ref().as_ref().unwrap()).as_slice() { return (left, right); @@ -342,10 +324,15 @@ fn binary_expressions_resolves_types_for_literals_directly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statements = &unit.implementations[0].statements; - if let AstStatement::Assignment { right: addition, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: addition, .. }), .. } = &statements[0] + { // a + 7 --> DINT (BYTE hint) assert_type_and_hint!(&annotations, &index, addition, DINT_TYPE, Some(BYTE_TYPE)); - if let AstStatement::BinaryExpression { left: a, right: seven, .. } = addition.as_ref() { + if let AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { left: a, right: seven, .. }), + .. + } = addition.as_ref() + { // a --> BYTE (DINT hint) assert_type_and_hint!(&annotations, &index, a, BYTE_TYPE, Some(DINT_TYPE)); // 7 --> DINT (no hint) @@ -357,7 +344,7 @@ fn binary_expressions_resolves_types_for_literals_directly() { unreachable!() } - if let AstStatement::Assignment { right: seven, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: seven, .. }), .. } = &statements[1] { assert_type_and_hint!(&annotations, &index, seven, DINT_TYPE, Some(BYTE_TYPE)); } else { unreachable!() @@ -379,16 +366,21 @@ fn addition_subtraction_expression_with_pointers_resolves_to_pointer_type() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statements = &unit.implementations[0].statements; - if let AstStatement::Assignment { right: addition, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: addition, .. }), .. } = &statements[0] + { assert_type_and_hint!(&annotations, &index, addition, "__POINTER_TO_BYTE", Some("__PRG_a")); } - if let AstStatement::Assignment { right: addition, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: addition, .. }), .. } = &statements[1] + { assert_type_and_hint!(&annotations, &index, addition, "__PRG_a", Some("__PRG_a")); - if let AstStatement::BinaryExpression { left, .. } = &**addition { + if let AstNode { stmt: AstStatement::BinaryExpression(BinaryExpression { left, .. }), .. } = + &**addition + { assert_type_and_hint!(&annotations, &index, left, "__PRG_a", None); } } - if let AstStatement::Assignment { right: addition, .. } = &statements[2] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: addition, .. }), .. } = &statements[2] + { assert_type_and_hint!(&annotations, &index, addition, "__POINTER_TO_BYTE", Some("__PRG_a")); } } @@ -407,10 +399,12 @@ fn equality_with_pointers_is_bool() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statements = &unit.implementations[0].statements; - if let AstStatement::Assignment { right: addition, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: addition, .. }), .. } = &statements[0] + { assert_type_and_hint!(&annotations, &index, addition, BOOL_TYPE, Some(BOOL_TYPE)); } - if let AstStatement::Assignment { right: addition, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right: addition, .. }), .. } = &statements[1] + { assert_type_and_hint!(&annotations, &index, addition, BOOL_TYPE, Some(BOOL_TYPE)); } } @@ -432,16 +426,23 @@ fn complex_expressions_resolves_types_for_literals_directly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statements = &unit.implementations[0].statements; - if let AstStatement::Assignment { right, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = &statements[0] { // ((b + USINT#7) - c) assert_type_and_hint!(&annotations, &index, right, DINT_TYPE, Some(BYTE_TYPE)); - if let AstStatement::BinaryExpression { left, right: c, .. } = right.as_ref() { + if let AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { left, right: c, .. }), .. + } = right.as_ref() + { // c assert_type_and_hint!(&annotations, &index, c, INT_TYPE, Some(DINT_TYPE)); // (b + USINT#7) assert_type_and_hint!(&annotations, &index, left, DINT_TYPE, None); - if let AstStatement::BinaryExpression { left: b, right: seven, .. } = left.as_ref() { + if let AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { left: b, right: seven, .. }), + .. + } = left.as_ref() + { //b assert_type_and_hint!(&annotations, &index, b, SINT_TYPE, Some(DINT_TYPE)); // USINT#7 @@ -522,7 +523,10 @@ fn binary_expressions_resolves_types_with_float_comparisons() { for s in statements.iter() { assert_type_and_hint!(&annotations, &index, s, BOOL_TYPE, None); - if let AstStatement::BinaryExpression { left, right, .. } = s { + if let AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { left, right, .. }), .. + } = s + { assert_type_and_hint!(&annotations, &index, left, REAL_TYPE, None); assert_type_and_hint!(&annotations, &index, right, REAL_TYPE, None); } else { @@ -551,7 +555,10 @@ fn binary_expressions_resolves_types_of_literals_with_float_comparisons() { for s in statements.iter() { assert_type_and_hint!(&annotations, &index, s, BOOL_TYPE, None); - if let AstStatement::BinaryExpression { left, right, .. } = s { + if let AstNode { + stmt: AstStatement::BinaryExpression(BinaryExpression { left, right, .. }), .. + } = s + { assert_type_and_hint!(&annotations, &index, left, REAL_TYPE, None); assert_type_and_hint!(&annotations, &index, right, REAL_TYPE, None); } else { @@ -678,7 +685,7 @@ fn global_initializers_resolves_types() { id_provider.clone(), ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - let statements: Vec<&AstStatement> = + let statements: Vec<&AstNode> = unit.global_vars[0].variables.iter().map(|it| it.initializer.as_ref().unwrap()).collect(); let expected_types = @@ -760,7 +767,9 @@ fn necessary_promotions_should_be_type_hinted() { let statements = &unit.implementations[0].statements; // THEN we want a hint to promote b to DINT, BYTE + DINT should be treated as DINT - if let AstStatement::BinaryExpression { left, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::BinaryExpression(BinaryExpression { left, .. }), .. } = + &statements[0] + { assert_eq!(annotations.get_type(&statements[0], &index), index.find_effective_type_by_name("DINT")); assert_eq!( (annotations.get_type(left.as_ref(), &index), annotations.get_type_hint(left.as_ref(), &index)), @@ -771,7 +780,9 @@ fn necessary_promotions_should_be_type_hinted() { } // THEN we want a hint to promote b to DINT, BYTE < DINT should be treated as BOOL - if let AstStatement::BinaryExpression { left, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::BinaryExpression(BinaryExpression { left, .. }), .. } = + &statements[1] + { assert_eq!(annotations.get_type(&statements[1], &index), index.find_effective_type_by_name("BOOL")); assert_eq!( (annotations.get_type(left.as_ref(), &index), annotations.get_type_hint(left.as_ref(), &index)), @@ -803,7 +814,9 @@ fn necessary_promotions_between_real_and_literal_should_be_type_hinted() { let statements = &unit.implementations[0].statements; // THEN we want '0' to be treated as a REAL right away, the result of f > 0 should be type bool - if let AstStatement::BinaryExpression { right, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::BinaryExpression(BinaryExpression { right, .. }), .. } = + &statements[0] + { assert_eq!(annotations.get_type(&statements[0], &index), index.find_effective_type_by_name("BOOL")); assert_type_and_hint!(&annotations, &index, &statements[0], BOOL_TYPE, None); @@ -1058,13 +1071,13 @@ fn assignment_expressions_resolve_types() { assert_eq!(format!("{expected_types:?}"), format!("{type_names:?}")); - if let AstStatement::Assignment { left, right, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = &statements[0] { assert_eq!(annotations.get_type_or_void(left, &index).get_name(), "INT"); assert_eq!(annotations.get_type_or_void(right, &index).get_name(), "BYTE"); } else { panic!("expected assignment") } - if let AstStatement::Assignment { left, right, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = &statements[1] { assert_eq!(annotations.get_type_or_void(left, &index).get_name(), "LWORD"); assert_eq!(annotations.get_type_or_void(right, &index).get_name(), "INT"); } else { @@ -1234,7 +1247,7 @@ fn function_call_expression_resolves_to_the_function_itself_not_its_return_type( assert_eq!(index.find_effective_type_by_name("INT"), associated_type); // AND the reference itself should be ... - let AstStatement::CallStatement { operator, .. } = &statements[0] else { unreachable!() }; + let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator,..}), ..} = &statements[0] else {unreachable!()}; assert_eq!( Some(&StatementAnnotation::Function { return_type: "INT".into(), @@ -1491,7 +1504,9 @@ fn function_parameter_assignments_resolve_types() { assert_eq!(annotations.get_type_or_void(&statements[0], &index).get_name(), "INT"); assert_eq!(annotations.get(&statements[0]), Some(&StatementAnnotation::value("INT"))); - if let AstStatement::CallStatement { operator, parameters, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }), .. } = + &statements[0] + { //make sure the call's operator resolved correctly assert_eq!(annotations.get_type_or_void(operator, &index).get_name(), VOID_TYPE); assert_eq!( @@ -1503,14 +1518,19 @@ fn function_parameter_assignments_resolve_types() { }) ); - if let Some(AstStatement::ExpressionList { expressions, .. }) = &**parameters { - if let AstStatement::Assignment { left, right, .. } = &expressions[0] { + let param = ¶meters.as_ref().unwrap(); + if let AstStatement::ExpressionList(expressions, ..) = param.get_stmt() { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = + &expressions[0] + { assert_eq!(annotations.get_type_or_void(left, &index).get_name(), "INT"); assert_eq!(annotations.get_type_or_void(right, &index).get_name(), "DINT"); } else { panic!("assignment expected") } - if let AstStatement::OutputAssignment { left, right, .. } = &expressions[1] { + if let AstNode { stmt: AstStatement::OutputAssignment(Assignment { left, right, .. }), .. } = + &expressions[1] + { assert_eq!(annotations.get_type_or_void(left, &index).get_name(), "SINT"); assert_eq!(annotations.get_type_or_void(right, &index).get_name(), "DINT"); } else { @@ -1553,16 +1573,23 @@ fn nested_function_parameter_assignments_resolve_types() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); let statements = &unit.implementations[2].statements; - if let AstStatement::CallStatement { parameters, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + &statements[0] + { + let parameters = parameters.as_deref(); //check the two parameters assert_parameter_assignment(parameters, 0, "INT", "DINT", &annotations, &index); assert_parameter_assignment(parameters, 1, "BOOL", "REAL", &annotations, &index); //check the inner call in the first parameter assignment of the outer call `x := baz(...)` - if let AstStatement::Assignment { right, .. } = get_expression_from_list(parameters, 0) { - if let AstStatement::CallStatement { parameters, .. } = right.as_ref() { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + get_expression_from_list(parameters, 0) + { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + right.as_ref() + { // the left side here should be `x` - so lets see if it got mixed up with the outer call's `x` - assert_parameter_assignment(parameters, 0, "DINT", "DINT", &annotations, &index); + assert_parameter_assignment(parameters.as_deref(), 0, "DINT", "DINT", &annotations, &index); } else { panic!("inner call") } @@ -1646,7 +1673,7 @@ fn actions_are_resolved() { let annotation = annotations.get(foo_reference); assert_eq!(Some(&StatementAnnotation::Program { qualified_name: "prg.foo".into() }), annotation); let method_call = &unit.implementations[2].statements[0]; - if let AstStatement::CallStatement { operator, .. } = method_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = method_call { assert_eq!( Some(&StatementAnnotation::Program { qualified_name: "prg.foo".into() }), annotations.get(operator) @@ -1689,7 +1716,7 @@ fn method_references_are_resolved() { annotation ); let method_call = &unit.implementations[2].statements[0]; - if let AstStatement::CallStatement { operator, .. } = method_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = method_call { assert_eq!( Some(&StatementAnnotation::Function { return_type: "INT".into(), @@ -1753,20 +1780,14 @@ fn variable_direct_access_type_resolved() { { let a_x1 = &statements[0]; assert_type_and_hint!(&annotations, &index, a_x1, BOOL_TYPE, None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Member(x1), base: Some(a), .. } = a_x1 - else { - unreachable!() - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(x1), base: Some(a), ..}), ..} = a_x1 else { unreachable!()}; assert_type_and_hint!(&annotations, &index, a, INT_TYPE, None); assert_type_and_hint!(&annotations, &index, x1, BOOL_TYPE, None); } { let a_w2 = &statements[1]; assert_type_and_hint!(&annotations, &index, a_w2, WORD_TYPE, None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Member(w2), base: Some(a), .. } = a_w2 - else { - unreachable!() - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Member(w2), base: Some(a), ..}), ..} = a_w2 else { unreachable!()}; assert_type_and_hint!(&annotations, &index, a, INT_TYPE, None); assert_type_and_hint!(&annotations, &index, w2, WORD_TYPE, None); } @@ -1797,12 +1818,8 @@ fn variable_direct_access_type_resolved2() { let type_names: Vec<&str> = statements .iter() .map(|s| { - let AstStatement::ReferenceExpr { access: ReferenceAccess::Member(reference), .. } = s else { - unreachable!("expected ReferenceExpr") - }; - let AstStatement::DirectAccess { index, .. } = reference.as_ref() else { - unreachable!("expected DirectAccess") - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{access: ReferenceAccess::Member(reference) ,.. }), .. } = s else { unreachable!("expected ReferenceExpr") }; + let AstNode { stmt: AstStatement::DirectAccess(DirectAccess { index, .. }), .. } = reference.as_ref() else { unreachable!("expected DirectAccess") }; index }) .map(|s| annotations.get_type_or_void(s, &index).get_name()) @@ -1811,8 +1828,8 @@ fn variable_direct_access_type_resolved2() { assert_eq!(format!("{expected_types:?}"), format!("{type_names:?}")); } -fn get_expression_from_list(stmt: &Option, index: usize) -> &AstStatement { - if let Some(AstStatement::ExpressionList { expressions, .. }) = stmt { +fn get_expression_from_list(stmt: Option<&AstNode>, index: usize) -> &AstNode { + if let Some(AstStatement::ExpressionList(expressions, ..)) = stmt.map(|it| it.get_stmt()) { &expressions[index] } else { panic!("no expression_list, found {:#?}", stmt) @@ -1820,15 +1837,17 @@ fn get_expression_from_list(stmt: &Option, index: usize) -> &AstSt } fn assert_parameter_assignment( - parameters: &Option, + parameters: Option<&AstNode>, param_index: usize, left_type: &str, right_type: &str, annotations: &AnnotationMapImpl, index: &Index, ) { - if let Some(AstStatement::ExpressionList { expressions, .. }) = parameters { - if let AstStatement::Assignment { left, right, .. } = &expressions[param_index] { + if let Some(AstStatement::ExpressionList(expressions)) = parameters.map(|it| it.get_stmt()) { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = + &expressions[param_index] + { assert_eq!(annotations.get_type_or_void(left, index).get_name(), left_type); assert_eq!(annotations.get_type_or_void(right, index).get_name(), right_type); } else { @@ -2134,7 +2153,9 @@ fn enum_element_initialization_is_annotated_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let data_type = &unit.user_types[0].data_type; if let DataType::EnumType { elements, .. } = data_type { - if let AstStatement::Assignment { right, .. } = flatten_expression_list(elements)[2] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + flatten_expression_list(elements)[2] + { assert_type_and_hint!(&annotations, &index, right, "DINT", Some("MyEnum")); } else { unreachable!() @@ -2191,17 +2212,17 @@ fn enum_initialization_is_annotated_correctly() { ); let statements = &unit.implementations[0].statements; - if let AstStatement::Assignment { right, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = &statements[0] { assert_type_and_hint!(&annotations, &index, right.as_ref(), "MyEnum", Some("MyEnum")); } else { unreachable!() } - if let AstStatement::Assignment { right, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = &statements[1] { assert_type_and_hint!(&annotations, &index, right.as_ref(), "MyEnum", Some("MyEnum")); } else { unreachable!() } - if let AstStatement::Assignment { right, .. } = &statements[2] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = &statements[2] { assert_type_and_hint!(&annotations, &index, right.as_ref(), "MyEnum", Some("MyEnum")); } else { unreachable!() @@ -2272,12 +2293,10 @@ fn struct_member_explicit_initialization_test() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); // THEN the initializers assignments have correct annotations - let AstStatement::Assignment { right, .. } = &unit.implementations[0].statements[0] else { - unreachable!() - }; - let AstStatement::ExpressionList { expressions, .. } = right.as_ref() else { unreachable!() }; + let AstNode { stmt: AstStatement::Assignment(Assignment { right, ..}), ..} = &unit.implementations[0].statements[0] else { unreachable!()}; + let AstStatement::ExpressionList ( expressions) = right.get_stmt() else {unreachable!()}; - let AstStatement::Assignment { left, .. } = &expressions[0] else { unreachable!() }; + let AstStatement::Assignment(Assignment { left, ..}) = &expressions[0].get_stmt() else {unreachable!()}; assert_eq!( Some(&StatementAnnotation::Variable { resulting_type: "DINT".to_string(), @@ -2289,7 +2308,7 @@ fn struct_member_explicit_initialization_test() { annotations.get(left) ); - let AstStatement::Assignment { left, .. } = &expressions[1] else { unreachable!() }; + let AstStatement::Assignment(Assignment { left, ..}) = &expressions[1].get_stmt() else {unreachable!()}; assert_eq!( Some(&StatementAnnotation::Variable { resulting_type: "BYTE".to_string(), @@ -2359,10 +2378,12 @@ fn data_type_initializers_type_hint_test() { assert_eq!(Some(index.get_type("MyArray").unwrap()), annotations.get_type_hint(initializer, &index)); let initializer = index.get_type("MyArray").unwrap().initial_value.unwrap(); - if let AstStatement::Literal { kind: AstLiteral::Array(Array { elements: Some(exp_list) }), .. } = - index.get_const_expressions().get_constant_statement(&initializer).unwrap() + if let AstNode { + stmt: AstStatement::Literal(AstLiteral::Array(Array { elements: Some(exp_list) })), + .. + } = index.get_const_expressions().get_constant_statement(&initializer).unwrap() { - if let AstStatement::ExpressionList { expressions: elements, .. } = exp_list.as_ref() { + if let AstStatement::ExpressionList(elements, ..) = exp_list.get_stmt() { for ele in elements { assert_eq!( index.get_type("INT").unwrap(), @@ -2401,13 +2422,15 @@ fn data_type_initializers_multiplied_statement_type_hint_test() { assert_eq!(Some(my_array_type), annotations.get_type_hint(my_array_initializer, &index)); let my_array_type_const_initializer = my_array_type.initial_value.unwrap(); - if let AstStatement::Literal { - kind: AstLiteral::Array(Array { elements: Some(multiplied_statement) }), - .. - } = index.get_const_expressions().get_constant_statement(&my_array_type_const_initializer).unwrap() + if let AstStatement::Literal(AstLiteral::Array(Array { elements: Some(multiplied_statement) })) = + index + .get_const_expressions() + .get_constant_statement(&my_array_type_const_initializer) + .unwrap() + .get_stmt() { - if let AstStatement::MultipliedStatement { element: literal_seven, .. } = - multiplied_statement.as_ref() + if let AstStatement::MultipliedStatement(MultipliedStatement { element: literal_seven, .. }) = + multiplied_statement.get_stmt() { assert_eq!( index.find_effective_type_by_name(BYTE_TYPE), @@ -2430,13 +2453,15 @@ fn data_type_initializers_multiplied_statement_type_hint_test() { ); let global_var_const_initializer = global.initial_value.unwrap(); - if let AstStatement::Literal { - kind: AstLiteral::Array(Array { elements: Some(multiplied_statement) }), - .. - } = index.get_const_expressions().get_constant_statement(&global_var_const_initializer).unwrap() + if let AstStatement::Literal(AstLiteral::Array(Array { elements: Some(multiplied_statement) })) = + index + .get_const_expressions() + .get_constant_statement(&global_var_const_initializer) + .unwrap() + .get_stmt() { - if let AstStatement::MultipliedStatement { element: literal_seven, .. } = - multiplied_statement.as_ref() + if let AstStatement::MultipliedStatement(MultipliedStatement { element: literal_seven, .. }) = + multiplied_statement.get_stmt() { assert_eq!( index.find_effective_type_by_name(BYTE_TYPE), @@ -2480,10 +2505,11 @@ fn case_conditions_type_hint_test() { // THEN we want the case-bocks (1:, 2: , 3:) to have the type hint of the case-selector (x) - in this case BYTE //check if 'CASE x' got the type BYTE - if let AstStatement::ControlStatement { - kind: AstControlStatement::Case(CaseStatement { selector, case_blocks, .. }), + if let AstStatement::ControlStatement(AstControlStatement::Case(CaseStatement { + selector, + case_blocks, .. - } = &unit.implementations[0].statements[0] + })) = &unit.implementations[0].statements[0].get_stmt() { let type_of_x = annotations.get_type(selector, &index).unwrap(); @@ -2513,8 +2539,10 @@ fn range_type_min_max_type_hint_test() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); // THEN we want the range-limits (0 and 100) to have proper type-associations - if let DataType::SubRangeType { bounds: Some(AstStatement::RangeStatement { start, end, .. }), .. } = - &unit.user_types[0].data_type + if let DataType::SubRangeType { + bounds: Some(AstNode { stmt: AstStatement::RangeStatement(RangeStatement { start, end, .. }), .. }), + .. + } = &unit.user_types[0].data_type { //lets see if start and end got their type-annotations assert_eq!( @@ -2622,16 +2650,20 @@ fn deep_struct_variable_initialization_annotates_initializer() { assert_eq!(annotations.get_type_hint(initializer, &index), index.find_effective_type_by_name("MyStruct")); //check the initializer-part - if let AstStatement::ExpressionList { expressions, .. } = initializer { + if let AstStatement::ExpressionList(expressions, ..) = initializer.get_stmt() { // v := (a := 1, b := 2) - if let AstStatement::Assignment { left, right, .. } = &expressions[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = + &expressions[0] + { assert_eq!(annotations.get_type(left, &index), index.find_effective_type_by_name("Point")); assert_eq!(annotations.get_type_hint(right, &index), index.find_effective_type_by_name("Point")); // (a := 1, b := 2) - if let AstStatement::ExpressionList { expressions, .. } = right.as_ref() { + if let AstStatement::ExpressionList(expressions, ..) = right.get_stmt() { // a := 1 - if let AstStatement::Assignment { left, right, .. } = &expressions[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = + &expressions[0] + { assert_eq!( annotations.get_type(left.as_ref(), &index), index.find_effective_type_by_name("BYTE") @@ -2645,7 +2677,9 @@ fn deep_struct_variable_initialization_annotates_initializer() { } // b := 2 - if let AstStatement::Assignment { left, right, .. } = &expressions[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = + &expressions[1] + { assert_eq!( annotations.get_type(left.as_ref(), &index), index.find_effective_type_by_name("SINT") @@ -2718,7 +2752,7 @@ fn action_call_should_be_annotated() { let action_call = &unit.implementations[0].statements[0]; // then accessing inout should be annotated with DINT, because it is auto-dereferenced - if let AstStatement::CallStatement { operator, .. } = action_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = action_call { let a = annotations.get(operator); assert_eq!(Some(&StatementAnnotation::Program { qualified_name: "prg.foo".to_string() }), a); } @@ -2750,7 +2784,7 @@ fn action_body_gets_resolved() { let x_assignment = &unit.implementations[1].statements[0]; // then accessing inout should be annotated with DINT, because it is auto-dereferenced - if let AstStatement::Assignment { left, right, .. } = x_assignment { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = x_assignment { let a = annotations.get(left); assert_eq!( Some(&StatementAnnotation::Variable { @@ -2824,19 +2858,19 @@ fn nested_bitwise_access_resolves_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let assignment = &unit.implementations[0].statements[0]; - let AstStatement::Assignment { right, .. } = assignment else { unreachable!() }; + let AstNode { stmt: AstStatement::Assignment(Assignment { right, ..}), ..} = assignment else {unreachable!()}; assert_type_and_hint!(&annotations, &index, right, "BOOL", Some("BOOL")); //strange - let AstStatement::ReferenceExpr { base: Some(base), .. } = right.as_ref() else { unreachable!() }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{ base: Some(base),..}), ..} = right.as_ref() else {unreachable!()}; assert_type_and_hint!(&annotations, &index, base, "BYTE", None); - let AstStatement::ReferenceExpr { base: Some(base), .. } = base.as_ref() else { unreachable!() }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{ base: Some(base),..}), ..} = base.as_ref() else {unreachable!()}; assert_type_and_hint!(&annotations, &index, base, "WORD", None); - let AstStatement::ReferenceExpr { base: Some(base), .. } = base.as_ref() else { unreachable!() }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{ base: Some(base),..}), ..} = base.as_ref() else {unreachable!()}; assert_type_and_hint!(&annotations, &index, base, "DWORD", None); - let AstStatement::ReferenceExpr { base: Some(base), .. } = base.as_ref() else { unreachable!() }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr{ base: Some(base),..}), ..} = base.as_ref() else {unreachable!()}; assert_type_and_hint!(&annotations, &index, base, "LWORD", None); } @@ -2862,7 +2896,7 @@ fn literals_passed_to_function_get_annotated() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let call_stmt = &unit.implementations[1].statements[0]; - if let AstStatement::CallStatement { parameters, .. } = call_stmt { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = call_stmt { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, parameters[0], DINT_TYPE, Some(BYTE_TYPE)); assert_type_and_hint!(&annotations, &index, parameters[1], "__STRING_3", Some("STRING")); @@ -2900,12 +2934,12 @@ fn array_accessor_in_struct_array_is_annotated() { let qr = &unit.implementations[0].statements[0]; assert_type_and_hint!(&annotations, &index, qr, "INT", None); - let AstStatement::ReferenceExpr { base: Some(base), .. } = qr else { + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { base: Some(base), ..}), ..} = qr else { panic!("expected ReferenceExpr for {:?}", qr); }; assert_type_and_hint!(&annotations, &index, base, "__MyStruct_arr1", None); - let AstStatement::ReferenceExpr { base: Some(base), .. } = base.as_ref() else { + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { base: Some(base), ..}), ..} = base.as_ref() else { panic!("expected ReferenceExpr for {:?}", base); }; assert_type_and_hint!(&annotations, &index, base, "MyStruct", None); @@ -2935,7 +2969,7 @@ fn type_hint_should_not_hint_to_the_effective_type_but_to_the_original() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::Assignment { left, right, .. } = stmt { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = stmt { assert_type_and_hint!(&annotations, &index, left, "MyInt", None); assert_type_and_hint!(&annotations, &index, right, "DINT", Some("MyInt")); } else { @@ -2965,7 +2999,7 @@ fn null_statement_should_get_a_valid_type_hint() { let var_x_type = &unit.units[0].variable_blocks[0].variables[0].data_type_declaration.get_name().unwrap(); - if let AstStatement::Assignment { right, .. } = stmt { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = stmt { assert_type_and_hint!(&annotations, &index, right, "VOID", Some(var_x_type)); } else { unreachable!(); @@ -3076,7 +3110,7 @@ fn assigning_lword_to_ptr_will_annotate_correctly() { let ptr_type = unit.units[0].variable_blocks[0].variables[0].data_type_declaration.get_name().unwrap(); - if let AstStatement::Assignment { left, right, .. } = a_eq_b { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = a_eq_b { assert_type_and_hint!(&annotations, &index, left, DWORD_TYPE, None); assert_type_and_hint!(&annotations, &index, right, ptr_type, Some(DWORD_TYPE)); } @@ -3105,7 +3139,7 @@ fn assigning_ptr_to_lword_will_annotate_correctly() { let ptr_type = unit.units[0].variable_blocks[0].variables[0].data_type_declaration.get_name().unwrap(); - if let AstStatement::Assignment { left, right, .. } = a_eq_b { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = a_eq_b { assert_type_and_hint!(&annotations, &index, left, ptr_type, None); assert_type_and_hint!(&annotations, &index, right, DWORD_TYPE, Some(ptr_type)); } @@ -3134,12 +3168,19 @@ fn assigning_ptr_to_lword_will_annotate_correctly2() { let ptr_type = unit.units[0].variable_blocks[0].variables[0].data_type_declaration.get_name().unwrap(); - if let AstStatement::Assignment { left, right, .. } = a_eq_b { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = a_eq_b { assert_type_and_hint!(&annotations, &index, left, DWORD_TYPE, None); assert_type_and_hint!(&annotations, &index, right, INT_TYPE, Some(DWORD_TYPE)); - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Deref, base: Some(reference), .. } = - right.as_ref() + if let AstNode { + stmt: + AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Deref, + base: Some(reference), + .. + }), + .. + } = right.as_ref() { assert_type_and_hint!(&annotations, &index, reference, ptr_type, None); } else { @@ -3201,7 +3242,7 @@ fn pointer_assignment_with_incompatible_types_hints_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let assignment = &unit.implementations[0].statements[0]; - if let AstStatement::Assignment { left, right, .. } = assignment { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = assignment { assert_type_and_hint!(&annotations, &index, left, "__PRG_pt", None); assert_type_and_hint!(&annotations, &index, right, "__POINTER_TO_INT", Some("__PRG_pt")); } @@ -3234,19 +3275,9 @@ fn call_explicit_parameter_name_is_resolved() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); // should be the call statement // should contain array access as operator - let AstStatement::CallStatement { parameters, .. } = &unit.implementations[1].statements[0] else { - unreachable!("expected callstatement") - }; - let AstStatement::Assignment { left: b, .. } = - flatten_expression_list(parameters.as_ref().as_ref().unwrap())[0] - else { - unreachable!() - }; - let AstStatement::Assignment { left: a, .. } = - flatten_expression_list(parameters.as_ref().as_ref().unwrap())[1] - else { - unreachable!() - }; + let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, ..}), ..} = &unit.implementations[1].statements[0] else { unreachable!("expected callstatement")}; + let AstNode { stmt: AstStatement::Assignment(Assignment { left: b, ..}), ..} = flatten_expression_list(parameters.as_ref().as_ref().unwrap())[0] else { unreachable!()}; + let AstNode { stmt: AstStatement::Assignment(Assignment { left: a, ..}), ..} = flatten_expression_list(parameters.as_ref().as_ref().unwrap())[1] else { unreachable!()}; assert_eq!( Some(&StatementAnnotation::Variable { @@ -3295,12 +3326,13 @@ fn call_on_function_block_array() { // should be the call statement let statements = &unit.implementations[1].statements[0]; // should contain array access as operator - let AstStatement::CallStatement { operator, .. } = statements else { - unreachable!("expected callstatement") - }; + let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, ..}), ..} = statements else { unreachable!("expected callstatement")}; assert!(matches!( operator.as_ref(), - &AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), .. } + &AstNode { + stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Index(_), .. }), + .. + } ),); let annotation = annotations.get(operator.as_ref()); @@ -3390,9 +3422,7 @@ fn resolve_recursive_function_call() { let call = &unit.implementations[0].statements[0]; - let AstStatement::CallStatement { operator, .. } = call else { - unreachable!(); - }; + let AstStatement::CallStatement ( data) = call.get_stmt() else { unreachable!(); }; assert_eq!( Some(&StatementAnnotation::Function { @@ -3400,7 +3430,7 @@ fn resolve_recursive_function_call() { qualified_name: "foo".into(), call_name: None }), - type_map.get(&operator.get_id()) + type_map.get(&data.operator.get_id()) ); // insta::assert_snapshot!(annotated_types); @@ -3436,13 +3466,11 @@ fn resolve_recursive_program_call() { let type_map = annotations.type_map; let call = &unit.implementations[0].statements[0]; - let AstStatement::CallStatement { operator, .. } = call else { - unreachable!(); - }; + let AstStatement::CallStatement( data) = call.get_stmt() else { unreachable!(); }; assert_eq!( Some(&StatementAnnotation::Program { qualified_name: "mainProg".into() }), - type_map.get(&operator.get_id()) + type_map.get(&data.operator.get_id()) ); // insta::assert_snapshot!(annotated_types); @@ -3473,7 +3501,7 @@ fn function_block_initialization_test() { //PT will be a TIME variable, qualified name will be TON.PT let statement = unit.units[1].variable_blocks[0].variables[0].initializer.as_ref().unwrap(); - if let AstStatement::Assignment { left, .. } = statement { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = statement { let left = left.as_ref(); let annotation = annotations.get(left).unwrap(); assert_eq!( @@ -3523,7 +3551,7 @@ fn undeclared_varargs_type_hint_promoted_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let call_stmt = &unit.implementations[1].statements[0]; // THEN types smaller than LREAL/DINT get promoted while booleans and other types stay untouched. - if let AstStatement::CallStatement { parameters, .. } = call_stmt { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = call_stmt { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, parameters[0], REAL_TYPE, Some(LREAL_TYPE)); assert_type_and_hint!(&annotations, &index, parameters[1], LREAL_TYPE, Some(LREAL_TYPE)); @@ -3565,7 +3593,7 @@ fn passing_a_function_as_param_correctly_resolves_as_variable() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let call_stmt = &unit.implementations[1].statements[0]; // THEN the type of the parameter resolves to the original function type - if let AstStatement::CallStatement { parameters, .. } = call_stmt { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = call_stmt { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, parameters[1], DINT_TYPE, Some(DINT_TYPE)); assert_type_and_hint!(&annotations, &index, parameters[2], DINT_TYPE, Some(DINT_TYPE)); @@ -3599,11 +3627,16 @@ fn resolve_return_variable_in_nested_call() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let ass = &unit.implementations[0].statements[0]; - if let AstStatement::Assignment { right, .. } = ass { - if let AstStatement::CallStatement { parameters, .. } = right.as_ref() { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = ass { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + right.as_ref() + { let inner_ass = flatten_expression_list(parameters.as_ref().as_ref().unwrap())[0]; - if let AstStatement::Assignment { right, .. } = inner_ass { - if let AstStatement::CallStatement { parameters, .. } = right.as_ref() { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = inner_ass { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. + } = right.as_ref() + { let main = flatten_expression_list(parameters.as_ref().as_ref().unwrap())[0]; let a = annotations.get(main).unwrap(); assert_eq!( @@ -3645,27 +3678,37 @@ fn hardware_access_types_annotated() { ); let annotations = annotate_with_ids(&unit, &mut index, id_provider); - if let AstStatement::Assignment { right, .. } = &unit.implementations[0].statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + &unit.implementations[0].statements[0] + { assert_type_and_hint!(&annotations, &index, right, BYTE_TYPE, Some(BYTE_TYPE)); } else { unreachable!("Must be assignment") } - if let AstStatement::Assignment { right, .. } = &unit.implementations[0].statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + &unit.implementations[0].statements[1] + { assert_type_and_hint!(&annotations, &index, right, WORD_TYPE, Some(BYTE_TYPE)); } else { unreachable!("Must be assignment") } - if let AstStatement::Assignment { right, .. } = &unit.implementations[0].statements[2] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + &unit.implementations[0].statements[2] + { assert_type_and_hint!(&annotations, &index, right, DWORD_TYPE, Some(INT_TYPE)); } else { unreachable!("Must be assignment") } - if let AstStatement::Assignment { right, .. } = &unit.implementations[0].statements[3] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + &unit.implementations[0].statements[3] + { assert_type_and_hint!(&annotations, &index, right, BOOL_TYPE, Some(INT_TYPE)); } else { unreachable!("Must be assignment") } - if let AstStatement::Assignment { right, .. } = &unit.implementations[0].statements[4] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + &unit.implementations[0].statements[4] + { assert_type_and_hint!(&annotations, &index, right, LWORD_TYPE, Some(LINT_TYPE)); } else { unreachable!("Must be assignment") @@ -3723,24 +3766,13 @@ fn multiple_pointer_with_dereference_annotates_and_nests_correctly() { index.import(std::mem::take(&mut annotations.new_index)); // THEN the expressions are nested and annotated correctly - let AstStatement::ReferenceExpr { access: ReferenceAccess::Deref, base: Some(value), .. } = &statement - else { - unreachable!("expected ReferenceExpr, but got {statement:#?}") - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Deref, base: Some(value), .. }), ..} = &statement else { unreachable!("expected ReferenceExpr, but got {statement:#?}")}; assert_type_and_hint!(&annotations, &index, value, "__POINTER_TO___POINTER_TO_BYTE", None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Address, base: Some(base), .. } = - value.as_ref() - else { - unreachable!("expected ReferenceExpr, but got {value:#?}") - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Address, base: Some(base), .. }), .. } = value.as_ref() else { unreachable!("expected ReferenceExpr, but got {value:#?}")}; assert_type_and_hint!(&annotations, &index, base, "__POINTER_TO_BYTE", None); - let AstStatement::ReferenceExpr { access: ReferenceAccess::Address, base: Some(base), .. } = - base.as_ref() - else { - unreachable!("expected ReferenceExpr, but got {base:#?}") - }; + let AstNode { stmt: AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Address, base: Some(base), .. }), .. } = base.as_ref() else { unreachable!("expected ReferenceExpr, but got {base:#?}")}; assert_type_and_hint!(&annotations, &index, base, "BYTE", None); // AND the overall type of the statement is annotated correctly @@ -3768,18 +3800,24 @@ fn multiple_negative_annotates_correctly() { index.import(std::mem::take(&mut annotations.new_index)); // THEN it is correctly annotated - if let AstStatement::UnaryExpression { value, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::UnaryExpression(UnaryExpression { value, .. }), .. } = &statements[0] + { assert_type_and_hint!(&annotations, &index, value, DINT_TYPE, None); - if let AstStatement::UnaryExpression { value, .. } = &value.as_ref() { + if let AstNode { stmt: AstStatement::UnaryExpression(UnaryExpression { value, .. }), .. } = + &value.as_ref() + { assert_type_and_hint!(&annotations, &index, value, DINT_TYPE, None); } } - if let AstStatement::UnaryExpression { value, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::UnaryExpression(UnaryExpression { value, .. }), .. } = &statements[1] + { assert_type_and_hint!(&annotations, &index, value, DINT_TYPE, None); - if let AstStatement::UnaryExpression { value, .. } = &value.as_ref() { + if let AstNode { stmt: AstStatement::UnaryExpression(UnaryExpression { value, .. }), .. } = + &value.as_ref() + { assert_type_and_hint!(&annotations, &index, value, DINT_TYPE, None); } } @@ -3813,8 +3851,10 @@ fn array_of_struct_with_inital_values_annotated_correctly() { // there is only one member => main.arr assert_eq!(1, members.len()); - if let Some(AstStatement::ExpressionList { expressions, .. }) = - index.get_const_expressions().maybe_get_constant_statement(&members[0].initial_value) + if let Some(AstStatement::ExpressionList(expressions)) = index + .get_const_expressions() + .maybe_get_constant_statement(&members[0].initial_value) + .map(|it| it.get_stmt()) { // we initialized the array with 2 structs assert_eq!(2, expressions.len()); @@ -3830,7 +3870,7 @@ fn array_of_struct_with_inital_values_annotated_correctly() { let assignments = flatten_expression_list(e); assert_eq!(3, assignments.len()); // the last expression of the list is the assignment to myStruct.c (array initialization) - if let AstStatement::Assignment { left, right, .. } = + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = assignments.last().expect("this should be the array initialization for myStruct.c") { // the array initialization should be annotated with the correct type hint (myStruct.c type) @@ -3890,7 +3930,9 @@ fn parameter_down_cast_test() { let statements = &unit.implementations[1].statements; // THEN check if downcasts are detected for implicit parameters - if let AstStatement::CallStatement { parameters, .. } = &statements[0] { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + &statements[0] + { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, parameters[0], INT_TYPE, Some(SINT_TYPE)); // downcast from type to type-hint! assert_type_and_hint!(&annotations, &index, parameters[1], DINT_TYPE, Some(INT_TYPE)); // downcast! @@ -3900,11 +3942,13 @@ fn parameter_down_cast_test() { } // THEN check if downcasts are detected for explicit parameters - if let AstStatement::CallStatement { parameters, .. } = &statements[1] { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + &statements[1] + { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()) .iter() .map(|it| { - if let AstStatement::Assignment { right, .. } = it { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = it { return right.as_ref(); } unreachable!() @@ -3943,7 +3987,9 @@ fn mux_generic_with_strings_is_annotated_correctly() { let mut annotations = annotate_with_ids(&unit, &mut index, id_provider); index.import(std::mem::take(&mut annotations.new_index)); - if let AstStatement::CallStatement { parameters, .. } = &unit.implementations[0].statements[0] { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + &unit.implementations[0].statements[0] + { let list = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); // MUX(2, str2, str3, str4) @@ -3984,7 +4030,11 @@ fn array_passed_to_function_with_vla_param_is_annotated_correctly() { let stmt = &unit.implementations[0].statements[0]; assert_type_and_hint!(&annotations, &index, stmt, "INT", None); - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), base: Some(reference), .. } = stmt + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = stmt.get_stmt() { assert_type_and_hint!(&annotations, &index, reference.as_ref(), "__foo_arr", Some("__arr_vla_1_int")); } else { @@ -3992,8 +4042,10 @@ fn array_passed_to_function_with_vla_param_is_annotated_correctly() { } let stmt = &unit.implementations[1].statements[0]; - if let AstStatement::CallStatement { parameters, .. } = stmt { - let Some(param) = parameters.as_ref() else { unreachable!() }; + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = stmt { + let Some(param) = parameters.as_ref() else { + unreachable!() + }; assert_type_and_hint!(&annotations, &index, param, "__main_a", Some("__foo_arr")); } else { @@ -4028,8 +4080,10 @@ fn vla_with_two_arrays() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[1].statements[0]; - if let AstStatement::CallStatement { parameters, .. } = stmt { - let Some(param) = parameters.as_ref() else { unreachable!() }; + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = stmt { + let Some(param) = parameters.as_ref() else { + unreachable!() + }; assert_type_and_hint!(&annotations, &index, param, "__main_a", Some("__foo_arr")); } else { @@ -4037,8 +4091,10 @@ fn vla_with_two_arrays() { } let stmt = &unit.implementations[1].statements[1]; - if let AstStatement::CallStatement { parameters, .. } = stmt { - let Some(param) = parameters.as_ref() else { unreachable!() }; + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = stmt { + let Some(param) = parameters.as_ref() else { + unreachable!() + }; assert_type_and_hint!(&annotations, &index, param, "__main_b", Some("__foo_arr")); } else { @@ -4082,7 +4138,9 @@ fn action_call_statement_parameters_are_annotated_with_a_type_hint() { let mut annotations = annotate_with_ids(&unit, &mut index, id_provider); index.import(std::mem::take(&mut annotations.new_index)); - if let AstStatement::CallStatement { parameters, .. } = &unit.implementations[2].statements[0] { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }), .. } = + &unit.implementations[2].statements[0] + { let list = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, list[0], "STRING", Some("DINT")); @@ -4148,7 +4206,12 @@ fn vla_access_is_annotated_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), base: Some(base), .. } = stmt { + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(base), + .. + }) = stmt.get_stmt() + { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, stmt, "INT", None); @@ -4178,10 +4241,12 @@ fn vla_write_access_is_annotated_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::Assignment { left, .. } = stmt { - if let AstStatement::ReferenceExpr { - access: ReferenceAccess::Index(_), base: Some(reference), .. - } = left.as_ref() + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = stmt { + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = left.get_stmt() { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, left.as_ref(), "INT", None); @@ -4221,12 +4286,14 @@ fn writing_value_read_from_vla_to_vla() { let stmt = &unit.implementations[0].statements[0]; // both VLA references should receive the same type hint - if let AstStatement::Assignment { left, right, .. } = stmt { - if let AstStatement::ReferenceExpr { - access: ReferenceAccess::Index(_), base: Some(reference), .. - } = left.as_ref() + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. } = stmt { + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = left.get_stmt() { - // if let AstStatement::ArrayAccess { reference, .. } = left.as_ref() { + // if let AstStatement { stmt: AstStatement::ArrayAccess(ArrayAccess { reference, ..}), ..} = left.as_ref() { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, left.as_ref(), "INT", None); @@ -4242,9 +4309,11 @@ fn writing_value_read_from_vla_to_vla() { panic!("expected an array access, got none") } - if let AstStatement::ReferenceExpr { - access: ReferenceAccess::Index(_), base: Some(reference), .. - } = right.as_ref() + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = right.as_ref().get_stmt() { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, right.as_ref(), "INT", Some("INT")); @@ -4283,10 +4352,12 @@ fn address_of_works_on_vla() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::Assignment { right, .. } = stmt { - if let AstStatement::ReferenceExpr { - access: ReferenceAccess::Address, base: Some(reference), .. - } = right.as_ref() + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = stmt { + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Address, + base: Some(reference), + .. + }) = right.get_stmt() { // rhs of assignment resolves to LWORD assert_type_and_hint!( @@ -4334,7 +4405,11 @@ fn by_ref_vla_access_is_annotated_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), base: Some(reference), .. } = stmt + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = stmt.get_stmt() { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, stmt, "INT", None); @@ -4347,7 +4422,11 @@ fn by_ref_vla_access_is_annotated_correctly() { let stmt = &unit.implementations[0].statements[1]; - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), base: Some(reference), .. } = stmt + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = stmt.get_stmt() { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, stmt, "INT", None); @@ -4390,12 +4469,12 @@ fn vla_call_statement() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - let AstStatement::CallStatement { parameters, .. } = stmt else { + let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, ..}), ..} = stmt else { unreachable!(); }; - let param = parameters.as_ref().clone().unwrap(); - let statement = flatten_expression_list(¶m)[0]; + let param = parameters.as_ref().unwrap(); + let statement = flatten_expression_list(param)[0]; assert_type_and_hint!(&annotations, &index, statement, "__main_arr", Some("__foo_vla")); } @@ -4425,12 +4504,12 @@ fn vla_call_statement_with_nested_arrays() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - let AstStatement::CallStatement { parameters, .. } = stmt else { + let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, ..}), ..} = stmt else { unreachable!(); }; - let param = parameters.as_ref().clone().unwrap(); - let statement = flatten_expression_list(¶m)[0]; + let param = parameters.as_ref().unwrap(); + let statement = flatten_expression_list(param)[0]; assert_type_and_hint!(&annotations, &index, statement, "__main_arr_", Some("__foo_vla")); } @@ -4454,7 +4533,11 @@ fn multi_dimensional_vla_access_is_annotated_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), base: Some(reference), .. } = stmt + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Index(_), + base: Some(reference), + .. + }) = stmt.get_stmt() { // entire statement resolves to INT assert_type_and_hint!(&annotations, &index, stmt, "INT", None); @@ -4485,7 +4568,9 @@ fn vla_access_assignment_receives_the_correct_type_hint() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - let AstStatement::Assignment { right, .. } = stmt else { panic!("expected an assignment, got none") }; + let AstNode { stmt: AstStatement::Assignment(Assignment { right, ..}), ..} = stmt else { + panic!("expected an assignment, got none") + }; // RHS resolves to INT and receives type-hint to DINT assert_type_and_hint!(&annotations, &index, right.as_ref(), "INT", Some("DINT")); } @@ -4509,7 +4594,9 @@ fn multi_dim_vla_access_assignment_receives_the_correct_type_hint() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[0].statements[0]; - let AstStatement::Assignment { right, .. } = stmt else { panic!("expected an assignment, got none") }; + let AstNode { stmt: AstStatement::Assignment(Assignment { right, ..}), ..} = stmt else { + panic!("expected an assignment, got none") + }; // RHS resolves to INT and receives type-hint to DINT assert_type_and_hint!(&annotations, &index, right.as_ref(), "INT", Some("DINT")); } @@ -4541,8 +4628,8 @@ fn function_call_resolves_correctly_to_pou_rather_than_local_variable() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let stmt = &unit.implementations[1].statements[0]; - let AstStatement::CallStatement { operator, .. } = stmt else { unreachable!() }; - assert_type_and_hint!(&annotations, &index, operator, "C", None); + let AstStatement::CallStatement ( data) = stmt.get_stmt() else { unreachable!() }; + assert_type_and_hint!(&annotations, &index, &data.operator, "C", None); } #[test] @@ -4577,7 +4664,7 @@ fn override_is_resolved() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); let method_call = &unit.implementations[5].statements[0]; - if let AstStatement::CallStatement { operator, .. } = method_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = method_call { assert_eq!( Some(&StatementAnnotation::Function { return_type: "INT".to_string(), @@ -4588,7 +4675,7 @@ fn override_is_resolved() { ); } let method_call = &unit.implementations[5].statements[1]; - if let AstStatement::CallStatement { operator, .. } = method_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = method_call { assert_eq!( Some(&StatementAnnotation::Function { return_type: "INT".to_string(), @@ -4636,7 +4723,7 @@ fn override_in_grandparent_is_resolved() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); let method_call = &unit.implementations[7].statements[0]; - if let AstStatement::CallStatement { operator, .. } = method_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = method_call { assert_eq!( Some(&StatementAnnotation::Function { return_type: "INT".to_string(), @@ -4647,7 +4734,7 @@ fn override_in_grandparent_is_resolved() { ); } let method_call = &unit.implementations[7].statements[1]; - if let AstStatement::CallStatement { operator, .. } = method_call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }), .. } = method_call { assert_eq!( Some(&StatementAnnotation::Function { return_type: "INT".to_string(), @@ -4682,7 +4769,9 @@ fn annotate_variable_in_parent_class() { ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - if let AstStatement::Assignment { right, .. } = &unit.implementations[1].statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }), .. } = + &unit.implementations[1].statements[1] + { let annotation = annotations.get(right); assert_eq!( &StatementAnnotation::Variable { @@ -4695,7 +4784,9 @@ fn annotate_variable_in_parent_class() { annotation.unwrap() ); } - if let AstStatement::Assignment { left, .. } = &unit.implementations[1].statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[1].statements[1] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4731,7 +4822,9 @@ fn annotate_variable_in_grandparent_class() { id_provider.clone(), ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - if let AstStatement::Assignment { left, .. } = &unit.implementations[2].statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[2].statements[0] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4774,7 +4867,9 @@ fn annotate_variable_in_field() { id_provider.clone(), ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - if let AstStatement::Assignment { left, .. } = &unit.implementations[3].statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[3].statements[0] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4829,7 +4924,9 @@ fn annotate_method_in_super() { id_provider.clone(), ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - if let AstStatement::Assignment { left, .. } = &unit.implementations[2].statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[2].statements[0] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4842,7 +4939,9 @@ fn annotate_method_in_super() { annotation.unwrap() ); } - if let AstStatement::Assignment { left, .. } = &unit.implementations[2].statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[2].statements[1] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4855,7 +4954,9 @@ fn annotate_method_in_super() { annotation.unwrap() ); } - if let AstStatement::Assignment { left, .. } = &unit.implementations[4].statements[0] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[4].statements[0] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4868,7 +4969,9 @@ fn annotate_method_in_super() { annotation.unwrap() ); } - if let AstStatement::Assignment { left, .. } = &unit.implementations[4].statements[1] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[4].statements[1] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { @@ -4881,7 +4984,9 @@ fn annotate_method_in_super() { annotation.unwrap() ); } - if let AstStatement::Assignment { left, .. } = &unit.implementations[4].statements[2] { + if let AstNode { stmt: AstStatement::Assignment(Assignment { left, .. }), .. } = + &unit.implementations[4].statements[2] + { let annotation = annotations.get(left); assert_eq!( &StatementAnnotation::Variable { diff --git a/src/resolver/tests/resolve_generic_calls.rs b/src/resolver/tests/resolve_generic_calls.rs index 7b59173b13..7e35997a97 100644 --- a/src/resolver/tests/resolve_generic_calls.rs +++ b/src/resolver/tests/resolve_generic_calls.rs @@ -1,5 +1,5 @@ use plc_ast::{ - ast::{flatten_expression_list, AstStatement}, + ast::{flatten_expression_list, Assignment, AstNode, AstStatement, CallStatement}, provider::IdProvider, }; @@ -96,11 +96,16 @@ fn generic_call_annotated_with_correct_type() { //The return type should have the correct type assert_type_and_hint!(&annotations, &index, call, INT_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__INT"), annotations.get_call_name(operator)); //parameters should have the correct type - if let Some(AstStatement::Assignment { left, right, .. }) = &**parameters { + if let Some(AstNode { stmt: AstStatement::Assignment(Assignment { left, right, .. }), .. }) = + parameters.as_deref() + { assert_type_and_hint!(&annotations, &index, left, INT_TYPE, None); assert_type_and_hint!(&annotations, &index, right, INT_TYPE, Some(INT_TYPE)); } else { @@ -115,10 +120,13 @@ fn generic_call_annotated_with_correct_type() { //The return type should have the correct type assert_type_and_hint!(&annotations, &index, call, DINT_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__DINT"), annotations.get_call_name(operator)); - if let Some(parameter) = &**parameters { + if let Some(parameter) = parameters.as_deref() { //parameters should have the correct type assert_type_and_hint!(&annotations, &index, parameter, DINT_TYPE, Some(DINT_TYPE)); } else { @@ -133,10 +141,13 @@ fn generic_call_annotated_with_correct_type() { //The return type should have the correct type assert_type_and_hint!(&annotations, &index, call, REAL_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__REAL"), annotations.get_call_name(operator)); - if let Some(parameter) = &**parameters { + if let Some(parameter) = parameters.as_deref() { //parameters should have the correct type assert_type_and_hint!(&annotations, &index, parameter, REAL_TYPE, Some(REAL_TYPE)); } else { @@ -178,27 +189,39 @@ fn generic_call_multi_params_annotated_with_correct_type() { //The return type should have the correct type assert_type_and_hint!(&annotations, &index, call, DINT_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__DINT__INT"), annotations.get_call_name(operator)); //parameters should have the correct type - if let Some(parameters) = &**parameters { + if let Some(parameters) = parameters.as_deref() { if let [x, y, z] = flatten_expression_list(parameters)[..] { - if let AstStatement::Assignment { left, right, .. } = x { + if let AstNode { + stmt: AstStatement::Assignment(Assignment { left, right, .. }, ..), .. + } = x + { assert_type_and_hint!(&annotations, &index, left, DINT_TYPE, None); assert_type_and_hint!(&annotations, &index, right, INT_TYPE, Some(DINT_TYPE)); } else { unreachable!("Not an assignment"); } - if let AstStatement::Assignment { left, right, .. } = y { + if let AstNode { + stmt: AstStatement::Assignment(Assignment { left, right, .. }, ..), .. + } = y + { assert_type_and_hint!(&annotations, &index, left, DINT_TYPE, None); assert_type_and_hint!(&annotations, &index, right, DINT_TYPE, Some(DINT_TYPE)); } else { unreachable!("Not an assignment"); } - if let AstStatement::Assignment { left, right, .. } = z { + if let AstNode { + stmt: AstStatement::Assignment(Assignment { left, right, .. }, ..), .. + } = z + { assert_type_and_hint!(&annotations, &index, left, INT_TYPE, None); assert_type_and_hint!(&annotations, &index, right, INT_TYPE, Some(INT_TYPE)); } else { @@ -216,11 +239,14 @@ fn generic_call_multi_params_annotated_with_correct_type() { //The return type should have the correct type assert_type_and_hint!(&annotations, &index, call, DINT_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__DINT__INT"), annotations.get_call_name(operator)); //parameters should have the correct type - if let Some(parameters) = &**parameters { + if let Some(parameters) = parameters.as_deref() { if let [x, y, z] = flatten_expression_list(parameters)[..] { assert_type_and_hint!(&annotations, &index, x, INT_TYPE, Some(DINT_TYPE)); assert_type_and_hint!(&annotations, &index, y, DINT_TYPE, Some(DINT_TYPE)); @@ -238,11 +264,14 @@ fn generic_call_multi_params_annotated_with_correct_type() { //The return type should have the correct type assert_type_and_hint!(&annotations, &index, call, REAL_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__REAL__SINT"), annotations.get_call_name(operator)); //parameters should have the correct type - if let Some(parameters) = &**parameters { + if let Some(parameters) = parameters.as_deref() { if let [x, y, z] = flatten_expression_list(parameters)[..] { assert_type_and_hint!(&annotations, &index, x, REAL_TYPE, Some(REAL_TYPE)); assert_type_and_hint!(&annotations, &index, y, DINT_TYPE, Some(REAL_TYPE)); @@ -282,15 +311,12 @@ fn call_order_of_parameters_does_not_change_annotations() { ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - fn get_parameter_with_name<'a>( - parameters_list: &[&'a AstStatement], - expected_name: &str, - ) -> &'a AstStatement { + fn get_parameter_with_name<'a>(parameters_list: &[&'a AstNode], expected_name: &str) -> &'a AstNode { parameters_list .iter() .find(|it| { - matches!(it, AstStatement::Assignment { left, .. } - if { matches!(&**left, AstStatement::ReferenceExpr {..} if { left.get_flat_reference_name() == Some(expected_name)})}) + matches!(it, AstNode { stmt: AstStatement::Assignment(Assignment { left, ..}), ..} + if { matches!(&**left, AstNode { stmt: AstStatement::ReferenceExpr(..), ..} if { left.get_flat_reference_name() == Some(expected_name)})}) }) .unwrap() } @@ -298,19 +324,28 @@ fn call_order_of_parameters_does_not_change_annotations() { // all three call-statements should give the exact same annotations // the order of the parameters should not matter for call in &unit.implementations[1].statements { - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), + .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc"), annotations.get_call_name(operator)); //parameters should have the correct type - if let Some(parameters) = &**parameters { + if let Some(parameters) = parameters.as_deref() { let parameters_list = flatten_expression_list(parameters); let [x, y, z] = [ get_parameter_with_name(¶meters_list, "x"), get_parameter_with_name(¶meters_list, "y"), get_parameter_with_name(¶meters_list, "z"), ]; - if let [AstStatement::Assignment { left: x, right: a, .. }, AstStatement::Assignment { left: y, right: b, .. }, AstStatement::Assignment { left: z, right: c, .. }] = - [x, y, z] + if let [AstNode { + stmt: AstStatement::Assignment(Assignment { left: x, right: a, .. }), .. + }, AstNode { + stmt: AstStatement::Assignment(Assignment { left: y, right: b, .. }), .. + }, AstNode { + stmt: AstStatement::Assignment(Assignment { left: z, right: c, .. }), .. + }] = [x, y, z] { assert_type_and_hint!(&annotations, &index, x, DINT_TYPE, None); assert_type_and_hint!(&annotations, &index, a, INT_TYPE, Some(DINT_TYPE)); @@ -356,15 +391,12 @@ fn call_order_of_generic_parameters_does_not_change_annotations() { ); let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); - fn get_parameter_with_name<'a>( - parameters_list: &[&'a AstStatement], - expected_name: &str, - ) -> &'a AstStatement { + fn get_parameter_with_name<'a>(parameters_list: &[&'a AstNode], expected_name: &str) -> &'a AstNode { parameters_list .iter() .find(|it| { - matches!(it, AstStatement::Assignment { left, .. } - if { matches!(&**left, AstStatement::ReferenceExpr{ ..} if {left.get_flat_reference_name() == Some(expected_name)})}) + matches!(it, AstNode { stmt: AstStatement::Assignment(Assignment { left, ..}), ..} + if { matches!(&**left, AstNode { stmt: AstStatement::ReferenceExpr(..), ..} if {left.get_flat_reference_name() == Some(expected_name)})}) }) .unwrap() } @@ -372,19 +404,28 @@ fn call_order_of_generic_parameters_does_not_change_annotations() { // all three call-statements should give the exact same annotations // the order of the parameters should not matter for call in &unit.implementations[1].statements { - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), + .. + } = call + { //The call name should nave the correct type assert_eq!(Some("myFunc__DINT__INT"), annotations.get_call_name(operator)); //parameters should have the correct type - if let Some(parameters) = &**parameters { + if let Some(parameters) = parameters.as_deref() { let parameters_list = flatten_expression_list(parameters); let [x, y, z] = [ get_parameter_with_name(¶meters_list, "x"), get_parameter_with_name(¶meters_list, "y"), get_parameter_with_name(¶meters_list, "z"), ]; - if let [AstStatement::Assignment { left: x, right: a, .. }, AstStatement::Assignment { left: y, right: b, .. }, AstStatement::Assignment { left: z, right: c, .. }] = - [x, y, z] + if let [AstNode { + stmt: AstStatement::Assignment(Assignment { left: x, right: a, .. }), .. + }, AstNode { + stmt: AstStatement::Assignment(Assignment { left: y, right: b, .. }), .. + }, AstNode { + stmt: AstStatement::Assignment(Assignment { left: z, right: c, .. }), .. + }] = [x, y, z] { assert_type_and_hint!(&annotations, &index, x, DINT_TYPE, None); assert_type_and_hint!(&annotations, &index, a, INT_TYPE, Some(DINT_TYPE)); @@ -452,14 +493,14 @@ fn builtin_generic_functions_do_not_get_specialized_calls() { assert_type_and_hint!(&annotations, &index, call, LWORD_TYPE, None); //The parameter should have the correct (original) type - if let AstStatement::CallStatement { parameters, .. } = call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }, ..), .. } = call { let params = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, params[0], DINT_TYPE, Some(DINT_TYPE)); } else { panic!("Expected call statement") } let call = &unit.implementations[0].statements[2]; - if let AstStatement::CallStatement { parameters, .. } = call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }, ..), .. } = call { let params = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, params[0], REAL_TYPE, Some(REAL_TYPE)); } else { @@ -485,7 +526,7 @@ fn builtin_adr_ref_return_annotated() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); let stmt = &unit.implementations[0].statements[0]; - if let AstStatement::Assignment { right, .. } = stmt { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }, ..), .. } = stmt { let actual_type = AnnotationMap::get_type(&annotations, right, &index); let reference_type = annotations.get_type_or_void(right, &index); @@ -518,7 +559,7 @@ fn builtin_sel_param_type_is_not_changed() { let (annotations, ..) = TypeAnnotator::visit_unit(&index, &unit, id_provider); //get the type/hints for a and b in the call, they should be unchanged (DINT, None) let call = &unit.implementations[0].statements[0]; - if let AstStatement::CallStatement { parameters, .. } = call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }, ..), .. } = call { let params = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, params[1], DINT_TYPE, Some(DINT_TYPE)); assert_type_and_hint!(&annotations, &index, params[2], DINT_TYPE, Some(DINT_TYPE)); @@ -554,7 +595,10 @@ fn resolve_variadic_generics() { let call = &unit.implementations[1].statements[0]; //The call statement should return a DINT assert_type_and_hint!(&annotations, &index, call, DINT_TYPE, None); - if let AstStatement::CallStatement { operator, parameters, .. } = call { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = call + { assert_eq!(Some("ex__DINT"), annotations.get_call_name(operator)); let params = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, params[0], DINT_TYPE, Some(DINT_TYPE)); @@ -588,7 +632,7 @@ fn generic_call_gets_cast_to_biggest_type() { let call = &unit.implementations[1].statements[0]; assert_type_and_hint!(&annotations, &index, call, LREAL_TYPE, None); //Call returns LREAL - if let AstStatement::CallStatement { parameters, .. } = call { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }, ..), .. } = call { let params = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, params[0], SINT_TYPE, Some(LREAL_TYPE)); assert_type_and_hint!(&annotations, &index, params[1], DINT_TYPE, Some(LREAL_TYPE)); @@ -708,7 +752,9 @@ fn string_ref_as_generic_resolved() { let call_statement = &unit.implementations[2].statements[0]; - if let AstStatement::CallStatement { parameters, .. } = call_statement { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { parameters, .. }, ..), .. } = + call_statement + { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, parameters[0], "STRING", Some("STRING")); @@ -852,9 +898,13 @@ fn generic_string_functions_without_specific_implementation_are_annotated_correc let annotations = annotate_with_ids(&unit, &mut index, id_provider); let assignment = &unit.implementations[1].statements[0]; - if let AstStatement::Assignment { right, .. } = assignment { + if let AstNode { stmt: AstStatement::Assignment(Assignment { right, .. }, ..), .. } = assignment { assert_type_and_hint!(&annotations, &index, right, DINT_TYPE, Some(DINT_TYPE)); - if let AstStatement::CallStatement { operator, parameters, .. } = &**right { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), + .. + } = &**right + { let function_annotation = annotations.get(operator).unwrap(); assert_eq!( function_annotation, @@ -1003,7 +1053,10 @@ fn literal_string_as_parameter_resolves_correctly() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statement = &unit.implementations[1].statements[0]; - if let AstStatement::CallStatement { operator, parameters, .. } = statement { + if let AstNode { + stmt: AstStatement::CallStatement(CallStatement { operator, parameters, .. }, ..), .. + } = statement + { let parameters = flatten_expression_list(parameters.as_ref().as_ref().unwrap()); assert_type_and_hint!(&annotations, &index, parameters[0], "__STRING_54", Some(STRING_TYPE)); assert_eq!( @@ -1040,7 +1093,7 @@ fn generic_function_sharing_a_datatype_name_resolves() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statement = &unit.implementations[1].statements[0]; - if let AstStatement::CallStatement { operator, .. } = statement { + if let AstNode { stmt: AstStatement::CallStatement(CallStatement { operator, .. }, ..), .. } = statement { assert_eq!( annotations.get(operator).unwrap(), &StatementAnnotation::Function { @@ -1080,7 +1133,7 @@ fn generic_external_function_having_same_name_as_local_variable() { let annotations = annotate_with_ids(&unit, &mut index, id_provider); let statement = &unit.implementations[1].statements[0]; - let AstStatement::Assignment { right, .. } = statement else { unreachable!() }; + let AstNode { stmt: AstStatement::Assignment(Assignment { right, ..}), ..} = statement else { unreachable!() }; assert_eq!( annotations.get(right).unwrap(), &StatementAnnotation::Value { resulting_type: "INT".to_string() } diff --git a/src/resolver/tests/resolve_literals_tests.rs b/src/resolver/tests/resolve_literals_tests.rs index f80ea68c99..2894940dea 100644 --- a/src/resolver/tests/resolve_literals_tests.rs +++ b/src/resolver/tests/resolve_literals_tests.rs @@ -1,5 +1,5 @@ use plc_ast::{ - ast::{AstStatement, ReferenceAccess, TypeNature}, + ast::{AstStatement, ReferenceAccess, ReferenceExpr, TypeNature}, provider::IdProvider, }; use plc_source::source_location::SourceLocation; @@ -283,7 +283,9 @@ fn enum_literals_target_are_annotated() { annotations.get_type_or_void(color_red, &index).get_type_information() ); - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(target), .. } = color_red { + if let AstStatement::ReferenceExpr(ReferenceExpr { access: ReferenceAccess::Cast(target), .. }) = + color_red.get_stmt() + { // right type gets annotated assert_eq!( &DataTypeInformation::Enum { @@ -331,7 +333,10 @@ fn casted_inner_literals_are_annotated() { let actual_types: Vec<&str> = statements .iter() .map(|it| { - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(target), .. } = it { + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Cast(target), .. + }) = it.get_stmt() + { target.as_ref() } else { panic!("no cast") @@ -363,7 +368,10 @@ fn casted_literals_enums_are_annotated_correctly() { let actual_types: Vec<&str> = statements .iter() .map(|it| { - if let AstStatement::ReferenceExpr { access: ReferenceAccess::Cast(target), .. } = it { + if let AstStatement::ReferenceExpr(ReferenceExpr { + access: ReferenceAccess::Cast(target), .. + }) = it.get_stmt() + { target.as_ref() } else { unreachable!(); @@ -389,7 +397,7 @@ fn expression_list_members_are_annotated() { let expected_types = vec!["DINT", "BOOL", "REAL"]; - if let AstStatement::ExpressionList { expressions, .. } = exp_list { + if let AstStatement::ExpressionList(expressions, ..) = exp_list.get_stmt() { let actual_types: Vec<&str> = expressions.iter().map(|it| annotations.get_type_or_void(it, &index).get_name()).collect(); @@ -420,7 +428,7 @@ fn expression_lists_with_expressions_are_annotated() { let expected_types = vec!["DINT", "BOOL", "LREAL", "LREAL"]; - if let AstStatement::ExpressionList { expressions, .. } = exp_list { + if let AstStatement::ExpressionList(expressions, ..) = exp_list.get_stmt() { let actual_types: Vec<&str> = expressions.iter().map(|it| annotations.get_type_or_void(it, &index).get_name()).collect(); @@ -469,7 +477,7 @@ fn expression_list_as_array_initilization_is_annotated_correctly() { // THEN for the first statement let a_init = unit.global_vars[0].variables[0].initializer.as_ref().unwrap(); // all expressions should be annotated with the right type [INT] - if let AstStatement::ExpressionList { expressions, .. } = a_init { + if let AstStatement::ExpressionList(expressions, ..) = a_init.get_stmt() { for exp in expressions { if let Some(data_type) = annotations.get_type_hint(exp, &index) { let type_info = data_type.get_type_information(); @@ -485,7 +493,7 @@ fn expression_list_as_array_initilization_is_annotated_correctly() { // AND for the second statement let b_init = unit.global_vars[0].variables[1].initializer.as_ref().unwrap(); // all expressions should be annotated with the right type [STRING] - if let AstStatement::ExpressionList { expressions, .. } = b_init { + if let AstStatement::ExpressionList(expressions, ..) = b_init.get_stmt() { for exp in expressions { let data_type = annotations.get_type_hint(exp, &index).unwrap(); let type_info = data_type.get_type_information(); diff --git a/src/tests/adr/annotated_ast_adr.rs b/src/tests/adr/annotated_ast_adr.rs index f5042c469b..b9e10caab4 100644 --- a/src/tests/adr/annotated_ast_adr.rs +++ b/src/tests/adr/annotated_ast_adr.rs @@ -1,4 +1,4 @@ -use plc_ast::ast::{AstStatement, ReferenceAccess}; +use plc_ast::ast::{AstStatement, ReferenceAccess, ReferenceExpr}; use crate::{ index::{ArgumentType, VariableType}, @@ -134,12 +134,7 @@ fn different_types_of_annotations() { // Main.in let qualified_reference = &statements[3]; - let AstStatement::ReferenceExpr { - access: ReferenceAccess::Member(member), base: Some(qualifier), .. - } = qualified_reference - else { - unreachable!() - }; + let AstStatement::ReferenceExpr(ReferenceExpr{access: ReferenceAccess::Member(member), base: Some(qualifier)}) = qualified_reference.get_stmt() else {unreachable!()}; // // Main resolves to a Program assert_eq!( annotations.get(qualifier), diff --git a/src/tests/adr/util_macros.rs b/src/tests/adr/util_macros.rs index eb767c9f09..3adc762f11 100644 --- a/src/tests/adr/util_macros.rs +++ b/src/tests/adr/util_macros.rs @@ -11,8 +11,8 @@ pub(crate) use annotate; macro_rules! deconstruct_assignment { ($src:expr) => {{ - if let plc_ast::ast::AstStatement::Assignment { left, right, .. } = $src { - (left, right) + if let plc_ast::ast::AstNode { stmt: plc_ast::ast::AstStatement::Assignment(data), .. } = $src { + (&data.left, &data.right) } else { unreachable!(); } @@ -22,10 +22,10 @@ pub(crate) use deconstruct_assignment; macro_rules! deconstruct_call_statement { ($src:expr) => {{ - if let plc_ast::ast::AstStatement::CallStatement { operator, parameters, .. } = $src { + if let plc_ast::ast::AstNode { stmt: plc_ast::ast::AstStatement::CallStatement(data), .. } = $src { ( - operator, - parameters.as_ref().as_ref().map(plc_ast::ast::flatten_expression_list).unwrap_or_default(), + &data.operator, + data.parameters.as_deref().map(plc_ast::ast::flatten_expression_list).unwrap_or_default(), ) } else { unreachable!(); @@ -36,8 +36,9 @@ pub(crate) use deconstruct_call_statement; macro_rules! deconstruct_binary_expression { ($src:expr) => {{ - if let plc_ast::ast::AstStatement::BinaryExpression { left, right, .. } = &$src { - (left, right) + if let plc_ast::ast::AstNode { stmt: plc_ast::ast::AstStatement::BinaryExpression(data), .. } = &$src + { + (&data.left, &data.right) } else { unreachable!(); } diff --git a/src/typesystem.rs b/src/typesystem.rs index 71e719974d..0464632cb4 100644 --- a/src/typesystem.rs +++ b/src/typesystem.rs @@ -6,7 +6,7 @@ use std::{ }; use plc_ast::{ - ast::{AstStatement, Operator, PouType, TypeNature}, + ast::{AstNode, Operator, PouType, TypeNature}, literals::{AstLiteral, StringValue}, }; use plc_source::source_location::SourceLocation; @@ -310,7 +310,7 @@ impl TypeSize { /// returns the const expression represented by this TypeSize or None if this TypeSize /// is a compile-time literal - pub fn as_const_expression<'i>(&self, index: &'i Index) -> Option<&'i AstStatement> { + pub fn as_const_expression<'i>(&self, index: &'i Index) -> Option<&'i AstNode> { match self { TypeSize::LiteralInteger(_) => None, TypeSize::ConstExpression(id) => index.get_const_expressions().get_constant_statement(id), @@ -385,7 +385,7 @@ pub enum DataTypeInformation { SubRange { name: TypeId, referenced_type: TypeId, - sub_range: Range, + sub_range: Range, }, Alias { name: TypeId, diff --git a/src/validation.rs b/src/validation.rs index f5c89ed10e..d42f7ca71c 100644 --- a/src/validation.rs +++ b/src/validation.rs @@ -1,4 +1,4 @@ -use plc_ast::ast::{AstStatement, CompilationUnit}; +use plc_ast::ast::{AstNode, CompilationUnit}; use plc_derive::Validators; use plc_diagnostics::diagnostics::Diagnostic; @@ -48,7 +48,7 @@ impl<'s, T: AnnotationMap> ValidationContext<'s, T> { } } - fn find_pou(&self, stmt: &AstStatement) -> Option<&PouIndexEntry> { + fn find_pou(&self, stmt: &AstNode) -> Option<&PouIndexEntry> { self.annotations.get_call_name(stmt).and_then(|pou_name| { self.index // check if this is an instance of a function block and get the type name diff --git a/src/validation/array.rs b/src/validation/array.rs index bec71aa42e..f3b8dac843 100644 --- a/src/validation/array.rs +++ b/src/validation/array.rs @@ -5,12 +5,12 @@ //! violates both the syntax and semantic of array assignments. //! //! Design note: Because we distinguish between variables inside VAR blocks [`plc_ast::ast::Variable`] -//! and POU bodies [`plc_ast::ast::AstStatement`] and how we interact with them (e.g. infering types of +//! and POU bodies [`plc_ast::ast::AstStatementKind`] and how we interact with them (e.g. infering types of //! [`plc_ast::ast::Variable`] from the AstAnnotation being impossible right now) a wrapper enum was //! introduced to make the validation code as generic as possible. use plc_ast::{ - ast::{AstStatement, Variable}, + ast::{AstNode, AstStatement, Variable}, literals::AstLiteral, }; use plc_diagnostics::diagnostics::Diagnostic; @@ -21,7 +21,7 @@ use super::{ValidationContext, Validator, Validators}; /// Indicates whether an array was defined in a VAR block or a POU body pub(super) enum Wrapper<'a> { - Statement(&'a AstStatement), + Statement(&'a AstNode), Variable(&'a Variable), } @@ -54,14 +54,14 @@ pub(super) fn validate_array_assignment( } } -/// Takes an [`AstStatement`] and returns its length as if it was an array. For example calling this function +/// Takes an [`AstStatementKind`] and returns its length as if it was an array. For example calling this function /// on an expression-list such as `[(...), (...)]` would return 2. -fn statement_to_array_length(statement: &AstStatement) -> usize { - match statement { +fn statement_to_array_length(statement: &AstNode) -> usize { + match statement.get_stmt() { AstStatement::ExpressionList { .. } => 1, - AstStatement::MultipliedStatement { multiplier, .. } => *multiplier as usize, - AstStatement::Literal { kind: AstLiteral::Array(arr), .. } => match arr.elements() { - Some(AstStatement::ExpressionList { expressions, .. }) => { + AstStatement::MultipliedStatement(data) => data.multiplier as usize, + AstStatement::Literal(AstLiteral::Array(arr)) => match arr.elements() { + Some(AstNode { stmt: AstStatement::ExpressionList(expressions), .. }) => { expressions.iter().map(statement_to_array_length).sum::() } @@ -72,18 +72,18 @@ fn statement_to_array_length(statement: &AstStatement) -> usize { // Any literal other than an array can be counted as 1 AstStatement::Literal { .. } => 1, - any => { + _any => { // XXX: Not sure what else could be in here - log::warn!("Array size-counting for {any:?} not covered; validation _might_ be wrong"); + log::warn!("Array size-counting for {statement:?} not covered; validation _might_ be wrong"); 0 } } } impl<'a> Wrapper<'a> { - fn get_rhs(&self) -> Option<&'a AstStatement> { + fn get_rhs(&self) -> Option<&'a AstNode> { match self { - Wrapper::Statement(AstStatement::Assignment { right, .. }) => Some(right), + Wrapper::Statement(AstNode { stmt: AstStatement::Assignment(data), .. }) => Some(&data.right), Wrapper::Variable(variable) => variable.initializer.as_ref(), _ => None, } @@ -95,8 +95,8 @@ impl<'a> Wrapper<'a> { { match self { Wrapper::Statement(statement) => { - let AstStatement::Assignment { left, .. } = statement else { return None }; - context.annotations.get_type(left, context.index).map(|it| it.get_type_information()) + let AstNode{ stmt: AstStatement::Assignment ( data), ..} = statement else { return None }; + context.annotations.get_type(&data.left, context.index).map(|it| it.get_type_information()) } Wrapper::Variable(variable) => variable diff --git a/src/validation/statement.rs b/src/validation/statement.rs index 58fd0bce6b..c548480021 100644 --- a/src/validation/statement.rs +++ b/src/validation/statement.rs @@ -1,7 +1,10 @@ use std::{collections::HashSet, mem::discriminant}; use plc_ast::{ - ast::{flatten_expression_list, AstStatement, DirectAccessType, Operator, ReferenceAccess}, + ast::{ + flatten_expression_list, AstNode, AstStatement, DirectAccess, DirectAccessType, Operator, + ReferenceAccess, + }, control_statements::{AstControlStatement, ConditionalBlock}, literals::{Array, AstLiteral, StringValue}, }; @@ -36,10 +39,10 @@ macro_rules! visit_all_statements { pub fn visit_statement( validator: &mut Validator, - statement: &AstStatement, + statement: &AstNode, context: &ValidationContext, ) { - match statement { + match statement.get_stmt() { // AstStatement::EmptyStatement { location, id } => (), // AstStatement::DefaultValue { location, id } => (), // AstStatement::LiteralInteger { value, location, id } => (), @@ -50,55 +53,62 @@ pub fn visit_statement( // AstStatement::LiteralReal { value, location, id } => (), // AstStatement::LiteralBool { value, location, id } => (), // AstStatement::LiteralString { value, is_wide, location, id } => (), - AstStatement::Literal { kind: AstLiteral::Array(Array { elements: Some(elements), .. }), .. } => { + AstStatement::Literal(AstLiteral::Array(Array { elements: Some(elements) })) => { visit_statement(validator, elements.as_ref(), context); } - AstStatement::CastStatement { target, type_name, location, .. } => { - if let AstStatement::Literal { kind: literal, .. } = target.as_ref() { - validate_cast_literal(validator, literal, statement, type_name, location, context); + AstStatement::CastStatement(data) => { + if let AstStatement::Literal(literal) = data.target.get_stmt() { + validate_cast_literal( + validator, + literal, + statement, + &data.type_name, + &statement.get_location(), + context, + ); } } - AstStatement::MultipliedStatement { element, .. } => { - visit_statement(validator, element, context); + AstStatement::MultipliedStatement(data) => { + visit_statement(validator, &data.element, context); } - AstStatement::ReferenceExpr { access, base, .. } => { - if let Some(base) = base { + AstStatement::ReferenceExpr(data) => { + if let Some(base) = &data.base { visit_statement(validator, base, context); } - validate_reference_expression(access, validator, context, statement, base); + validate_reference_expression(&data.access, validator, context, statement, &data.base); } - AstStatement::BinaryExpression { operator, left, right, .. } => { - visit_all_statements!(validator, context, left, right); - visit_binary_expression(validator, statement, operator, left, right, context); + AstStatement::BinaryExpression(data) => { + visit_all_statements!(validator, context, &data.left, &data.right); + visit_binary_expression(validator, statement, &data.operator, &data.left, &data.right, context); } - AstStatement::UnaryExpression { value, .. } => { - visit_statement(validator, value, context); + AstStatement::UnaryExpression(data) => { + visit_statement(validator, &data.value, context); } - AstStatement::ExpressionList { expressions, .. } => { + AstStatement::ExpressionList(expressions) => { expressions.iter().for_each(|element| visit_statement(validator, element, context)) } - AstStatement::RangeStatement { start, end, .. } => { - visit_all_statements!(validator, context, start, end); + AstStatement::RangeStatement(data) => { + visit_all_statements!(validator, context, &data.start, &data.end); } - AstStatement::Assignment { left, right, .. } => { - visit_statement(validator, left, context); - visit_statement(validator, right, context); + AstStatement::Assignment(data) => { + visit_statement(validator, &data.left, context); + visit_statement(validator, &data.right, context); - validate_assignment(validator, right, Some(left), &statement.get_location(), context); + validate_assignment(validator, &data.right, Some(&data.left), &statement.get_location(), context); validate_array_assignment(validator, context, Wrapper::Statement(statement)); } - AstStatement::OutputAssignment { left, right, .. } => { - visit_statement(validator, left, context); - visit_statement(validator, right, context); + AstStatement::OutputAssignment(data) => { + visit_statement(validator, &data.left, context); + visit_statement(validator, &data.right, context); - validate_assignment(validator, right, Some(left), &statement.get_location(), context); + validate_assignment(validator, &data.right, Some(&data.left), &statement.get_location(), context); } - AstStatement::CallStatement { operator, parameters, .. } => { - validate_call(validator, operator, parameters, &context.set_is_call()); + AstStatement::CallStatement(data) => { + validate_call(validator, &data.operator, data.parameters.as_deref(), &context.set_is_call()); } - AstStatement::ControlStatement { kind, .. } => validate_control_statement(validator, kind, context), - AstStatement::CaseCondition { condition, .. } => { + AstStatement::ControlStatement(kind) => validate_control_statement(validator, kind, context), + AstStatement::CaseCondition(condition) => { // if we get here, then a `CaseCondition` is used outside a `CaseStatement` // `CaseCondition` are used as a marker for `CaseStatements` and are not passed as such to the `CaseStatement.case_blocks` // see `control_parser` `parse_case_statement()` @@ -120,8 +130,8 @@ fn validate_reference_expression( access: &ReferenceAccess, validator: &mut Validator, context: &ValidationContext, - statement: &AstStatement, - base: &Option>, + statement: &AstNode, + base: &Option>, ) { match access { ReferenceAccess::Member(m) => { @@ -153,10 +163,8 @@ fn validate_reference_expression( visit_statement(validator, c.as_ref(), context); // see if we try to cast a literal - if let ( - AstStatement::Literal { kind: literal, .. }, - Some(StatementAnnotation::Type { type_name }), - ) = (c.as_ref(), base.as_ref().and_then(|it| context.annotations.get(it))) + if let (AstStatement::Literal(literal), Some(StatementAnnotation::Type { type_name })) = + (c.get_stmt(), base.as_ref().and_then(|it| context.annotations.get(it))) { validate_cast_literal( validator, @@ -191,28 +199,26 @@ fn validate_reference_expression( fn validate_address_of_expression( validator: &mut Validator, - target: &AstStatement, + target: &AstNode, location: SourceLocation, context: &ValidationContext, ) { let a = context.annotations.get(target); //TODO: resolver should also annotate information whether this results in an LValue or RValue // array-access results in a value, but it is an LValue :-( - if !matches!(a, Some(StatementAnnotation::Variable { .. })) - && !matches!(target, AstStatement::ReferenceExpr { access: ReferenceAccess::Index(_), .. }) - { + if !matches!(a, Some(StatementAnnotation::Variable { .. })) && !target.is_array_access() { validator.push_diagnostic(Diagnostic::invalid_operation("Invalid address-of operation", location)); } } fn validate_direct_access( - m: &AstStatement, - base: Option<&AstStatement>, + m: &AstNode, + base: Option<&AstNode>, context: &ValidationContext, validator: &mut Validator, ) { - if let (AstStatement::DirectAccess { access, index, .. }, Some(base_annotation)) = ( - m, + if let (AstStatement::DirectAccess(DirectAccess { access, index }), Some(base_annotation)) = ( + m.get_stmt(), // FIXME: should we consider the hint if one is available? base.and_then(|base| context.annotations.get(base)), ) { @@ -277,7 +283,7 @@ fn validate_cast_literal( // TODO: i feel like literal is misleading here. can be a reference aswell (INT#x) validator: &mut Validator, literal: &AstLiteral, - statement: &AstStatement, + statement: &AstNode, type_name: &str, location: &SourceLocation, context: &ValidationContext, @@ -333,13 +339,13 @@ fn validate_cast_literal( fn validate_access_index( validator: &mut Validator, context: &ValidationContext, - access_index: &AstStatement, + access_index: &AstNode, access_type: &DirectAccessType, target_type: &DataTypeInformation, location: &SourceLocation, ) { - match *access_index { - AstStatement::Literal { kind: AstLiteral::Integer(value), .. } => { + match *access_index.get_stmt() { + AstStatement::Literal(AstLiteral::Integer(value)) => { if !helper::is_in_range( access_type, value.try_into().unwrap_or_default(), @@ -354,7 +360,7 @@ fn validate_access_index( )) } } - AstStatement::ReferenceExpr { .. } => { + AstStatement::ReferenceExpr(_) => { let ref_type = context.annotations.get_type_or_void(access_index, context.index); if !ref_type.get_type_information().is_int() { validator.push_diagnostic(Diagnostic::incompatible_directaccess_variable( @@ -369,8 +375,8 @@ fn validate_access_index( fn validate_reference( validator: &mut Validator, - statement: &AstStatement, - base: Option<&AstStatement>, + statement: &AstNode, + base: Option<&AstNode>, ref_name: &str, location: &SourceLocation, context: &ValidationContext, @@ -418,15 +424,15 @@ fn validate_reference( fn visit_array_access( validator: &mut Validator, - reference: &AstStatement, - access: &AstStatement, + reference: &AstNode, + access: &AstNode, context: &ValidationContext, ) { let target_type = context.annotations.get_type_or_void(reference, context.index).get_type_information(); match target_type { - DataTypeInformation::Array { dimensions, .. } => match access { - AstStatement::ExpressionList { expressions, .. } => { + DataTypeInformation::Array { dimensions, .. } => match access.get_stmt() { + AstStatement::ExpressionList(expressions) => { validate_array_access_dimensions(dimensions.len(), expressions.len(), validator, access); for (i, exp) in expressions.iter().enumerate() { @@ -444,8 +450,8 @@ fn visit_array_access( source: StructSource::Internal(typesystem::InternalType::VariableLengthArray { ndims, .. }), .. } => { - let dims = match access { - AstStatement::ExpressionList { expressions, .. } => expressions.len(), + let dims = match access.get_stmt() { + AstStatement::ExpressionList(expressions) => expressions.len(), _ => 1, }; @@ -459,12 +465,7 @@ fn visit_array_access( } } -fn validate_array_access_dimensions( - ndims: usize, - dims: usize, - validator: &mut Validator, - access: &AstStatement, -) { +fn validate_array_access_dimensions(ndims: usize, dims: usize, validator: &mut Validator, access: &AstNode) { if ndims != dims { validator.push_diagnostic(Diagnostic::invalid_array_access(ndims, dims, access.get_location())) } @@ -472,12 +473,12 @@ fn validate_array_access_dimensions( fn validate_array_access( validator: &mut Validator, - access: &AstStatement, + access: &AstNode, dimensions: &[Dimension], dimension_index: usize, context: &ValidationContext, ) { - if let AstStatement::Literal { kind: AstLiteral::Integer(value), .. } = access { + if let AstStatement::Literal(AstLiteral::Integer(value)) = access.get_stmt() { if let Some(dimension) = dimensions.get(dimension_index) { if let Ok(range) = dimension.get_range(context.index) { if !(range.start as i128 <= *value && range.end as i128 >= *value) { @@ -501,10 +502,10 @@ fn validate_array_access( fn visit_binary_expression( validator: &mut Validator, - statement: &AstStatement, + statement: &AstNode, operator: &Operator, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, context: &ValidationContext, ) { match operator { @@ -529,10 +530,10 @@ fn visit_binary_expression( fn validate_binary_expression( validator: &mut Validator, - statement: &AstStatement, + statement: &AstNode, operator: &Operator, - left: &AstStatement, - right: &AstStatement, + left: &AstNode, + right: &AstNode, context: &ValidationContext, ) { let left_type = context.annotations.get_type_or_void(left, context.index).get_type_information(); @@ -610,20 +611,20 @@ fn compare_function_exists( /// Validates if an argument can be passed to a function with [`VariableType::Output`] and /// [`VariableType::InOut`] parameter types by checking if the argument is a reference (e.g. `foo(x)`) or /// an assignment (e.g. `foo(x := y)`, `foo(x => y)`). If neither is the case a diagnostic is generated. -fn validate_call_by_ref(validator: &mut Validator, param: &VariableIndexEntry, arg: &AstStatement) { +fn validate_call_by_ref(validator: &mut Validator, param: &VariableIndexEntry, arg: &AstNode) { let ty = param.argument_type.get_inner(); if !matches!(ty, VariableType::Output | VariableType::InOut) { return; } - match (arg.can_be_assigned_to(), arg) { + match (arg.can_be_assigned_to(), arg.get_stmt()) { (true, _) => (), // Output assignments are optional, e.g. `foo(bar => )` is considered valid - (false, AstStatement::EmptyStatement { .. }) if matches!(ty, VariableType::Output) => (), + (false, AstStatement::EmptyStatement(_)) if matches!(ty, VariableType::Output) => (), - (false, AstStatement::Assignment { right, .. } | AstStatement::OutputAssignment { right, .. }) => { - validate_call_by_ref(validator, param, right); + (false, AstStatement::Assignment(data) | AstStatement::OutputAssignment(data)) => { + validate_call_by_ref(validator, param, &data.right); } _ => validator.push_diagnostic(Diagnostic::invalid_argument_type( @@ -636,8 +637,8 @@ fn validate_call_by_ref(validator: &mut Validator, param: &VariableIndexEntry, a fn validate_assignment( validator: &mut Validator, - right: &AstStatement, - left: Option<&AstStatement>, + right: &AstNode, + left: Option<&AstNode>, location: &SourceLocation, context: &ValidationContext, ) { @@ -705,7 +706,7 @@ fn validate_assignment( left_type.get_type_information().get_name(), location.clone(), )); - } else if !matches!(right, AstStatement::Literal { .. }) { + } else if right.is_literal() { // TODO: See https://github.com/PLC-lang/rusty/issues/857 // validate_assignment_type_sizes(validator, left_type, right_type, location, context) } @@ -752,7 +753,7 @@ fn validate_variable_length_array_assignment( fn is_valid_assignment( left_type: &DataType, right_type: &DataType, - right: &AstStatement, + right: &AstNode, index: &Index, location: &SourceLocation, validator: &mut Validator, @@ -782,13 +783,13 @@ fn is_valid_assignment( fn is_valid_string_to_char_assignment( left_type: &DataTypeInformation, right_type: &DataTypeInformation, - right: &AstStatement, + right: &AstNode, location: &SourceLocation, validator: &mut Validator, ) -> bool { // TODO: casted literals and reference if left_type.is_compatible_char_and_string(right_type) { - if let AstStatement::Literal { kind: AstLiteral::String(StringValue { value, .. }), .. } = right { + if let AstStatement::Literal(AstLiteral::String(StringValue { value, .. })) = right.get_stmt() { if value.len() == 1 { return true; } else { @@ -871,8 +872,8 @@ fn is_aggregate_type_missmatch(left_type: &DataType, right_type: &DataType, inde fn validate_call( validator: &mut Validator, - operator: &AstStatement, - parameters: &Option, + operator: &AstNode, + parameters: Option<&AstNode>, context: &ValidationContext, ) { // visit called pou @@ -885,7 +886,7 @@ fn validate_call( } let declared_parameters = context.index.get_declared_parameters(pou.get_name()); - let passed_parameters = parameters.as_ref().map(flatten_expression_list).unwrap_or_default(); + let passed_parameters = parameters.map(flatten_expression_list).unwrap_or_default(); let mut are_implicit_parameters = true; let mut variable_location_in_parent = vec![]; @@ -950,9 +951,9 @@ fn validate_call( // selector, case_blocks, else_block fn validate_case_statement( validator: &mut Validator, - selector: &AstStatement, + selector: &AstNode, case_blocks: &[ConditionalBlock], - else_block: &[AstStatement], + else_block: &[AstNode], context: &ValidationContext, ) { visit_statement(validator, selector, context); @@ -962,7 +963,7 @@ fn validate_case_statement( let condition = b.condition.as_ref(); // invalid case conditions - if matches!(condition, AstStatement::Assignment { .. } | AstStatement::CallStatement { .. }) { + if matches!(condition.get_stmt(), AstStatement::Assignment(_) | AstStatement::CallStatement(_)) { validator.push_diagnostic(Diagnostic::invalid_case_condition(condition.get_location())); } @@ -978,7 +979,7 @@ fn validate_case_statement( }) .map(|v| { // check for duplicates if we got a value - if let Some(AstStatement::Literal { kind: AstLiteral::Integer(value), .. }) = v { + if let Some(AstNode { stmt: AstStatement::Literal(AstLiteral::Integer(value)), .. }) = v { if !cases.insert(value) { validator.push_diagnostic(Diagnostic::duplicate_case_condition( &value, @@ -1000,7 +1001,7 @@ fn validate_case_statement( /// statement fn validate_type_nature( validator: &mut Validator, - statement: &AstStatement, + statement: &AstNode, context: &ValidationContext, ) { if let Some(type_hint) = context diff --git a/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__array_initialization_validation.snap b/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__array_initialization_validation.snap index ad105d8d74..eb91d8a089 100644 --- a/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__array_initialization_validation.snap +++ b/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__array_initialization_validation.snap @@ -9,10 +9,10 @@ error: Array assignments must be surrounded with `[]` │ ^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` - ┌─ :6:41 + ┌─ :6:40 │ 6 │ arr3 : ARRAY[1..2] OF myStruct := ((var1 := 1), (var1 := 2, var2 := (1, 2))); // Missing `[` - │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` + │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` ┌─ :6:74 @@ -21,10 +21,10 @@ error: Array assignments must be surrounded with `[]` │ ^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` - ┌─ :7:41 + ┌─ :7:40 │ 7 │ arr4 : ARRAY[1..2] OF myStruct := ((var1 := 1), (var1 := 2, var2 := 1, 2)); // Missing `[` - │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` + │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` ┌─ :7:73 diff --git a/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__assignment_structs.snap b/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__assignment_structs.snap index d236e724e3..e0d7cbeceb 100644 --- a/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__assignment_structs.snap +++ b/src/validation/tests/snapshots/rusty__validation__tests__array_validation_test__assignment_structs.snap @@ -15,10 +15,10 @@ error: Array assignments must be surrounded with `[]` │ ^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` - ┌─ :18:49 + ┌─ :18:48 │ 18 │ foo_invalid_0 : FOO := (idx := 0, arr := ((arr := (1, 2, 3, 4, 5)), (arr := (1, 2, 3, 4, 5)))); - │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` + │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` ┌─ :19:57 @@ -33,9 +33,9 @@ error: Array assignments must be surrounded with `[]` │ ^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` error: Array assignments must be surrounded with `[]` - ┌─ :19:49 + ┌─ :19:48 │ 19 │ foo_invalid_1 : FOO := (idx := 0, arr := ((arr := (1, 2, 3, 4, 5)), (arr := (1, 2, 3, 4, 5)))); - │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` + │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Array assignments must be surrounded with `[]` diff --git a/src/validation/types.rs b/src/validation/types.rs index 7139133997..b85ad9a60b 100644 --- a/src/validation/types.rs +++ b/src/validation/types.rs @@ -1,4 +1,4 @@ -use plc_ast::ast::{AstStatement, DataType, DataTypeDeclaration, PouType, UserTypeDeclaration}; +use plc_ast::ast::{AstNode, AstStatement, DataType, DataTypeDeclaration, PouType, UserTypeDeclaration}; use plc_diagnostics::diagnostics::Diagnostic; use plc_source::source_location::SourceLocation; @@ -49,9 +49,10 @@ fn validate_data_type(validator: &mut Validator, data_type: &DataType, location: validator.push_diagnostic(Diagnostic::empty_variable_block(location.clone())); } } - DataType::EnumType { elements: AstStatement::ExpressionList { expressions, .. }, .. } - if expressions.is_empty() => - { + DataType::EnumType { + elements: AstNode { stmt: AstStatement::ExpressionList(expressions), .. }, + .. + } if expressions.is_empty() => { validator.push_diagnostic(Diagnostic::empty_variable_block(location.clone())); } DataType::VarArgs { referenced_type: None, sized: true } => {