Skip to content

Commit

Permalink
Some cleanup to Yul AST code
Browse files Browse the repository at this point in the history
  • Loading branch information
camden-smallwood committed Oct 9, 2023
1 parent 9405f8f commit 284aeb3
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 93 deletions.
18 changes: 8 additions & 10 deletions solidity/src/ast/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1868,9 +1868,9 @@ impl AstBuilder {

solang_parser::pt::YulStatement::VariableDeclaration(_loc, variables, value) => {
YulStatement::YulVariableDeclaration(YulVariableDeclaration {
value: value.as_ref()
value: Some(value.as_ref()
.map(|x| self.build_yul_expression(x))
.unwrap(),
.unwrap()),
variables: variables.iter()
.map(|x| self.build_yul_typed_name(x))
.collect(),
Expand Down Expand Up @@ -1967,14 +1967,12 @@ impl AstBuilder {
match case {
solang_parser::pt::YulSwitchOptions::Case(_loc, expression, body) => YulCase {
body: self.build_yul_block(body),
value: self.build_yul_expression(expression),
value: Some(self.build_yul_expression(expression)),
},

solang_parser::pt::YulSwitchOptions::Default(_loc, body) => YulCase {
body: self.build_yul_block(body),
value: YulExpression::YulIdentifier(YulIdentifier {
name: "default".to_string(),
}),
value: None,
},
}
}
Expand All @@ -1991,12 +1989,12 @@ impl AstBuilder {
pub fn build_yul_function_definition(&mut self, function: &solang_parser::pt::YulFunctionDefinition) -> YulFunctionDefinition {
YulFunctionDefinition {
name: function.id.name.clone(),
parameters: function.params.iter()
parameters: Some(function.params.iter()
.map(|param| self.build_yul_typed_name(param))
.collect(),
return_parameters: function.returns.iter()
.collect()),
return_parameters: Some(function.returns.iter()
.map(|param| self.build_yul_typed_name(param))
.collect(),
.collect()),
body: self.build_yul_block(&function.body),
}
}
Expand Down
6 changes: 3 additions & 3 deletions solidity/src/ast/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ pub enum TypeName {
impl Display for TypeName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TypeName::ElementaryTypeName(elementary_type_name) => elementary_type_name.fmt(f),
TypeName::UserDefinedTypeName(user_defined_type_name) => user_defined_type_name.fmt(f),
TypeName::FunctionTypeName(function_type_name) => function_type_name.type_descriptions.type_string.as_ref().unwrap().fmt(f),
TypeName::ArrayTypeName(array_type_name) => array_type_name.fmt(f),
TypeName::Mapping(mapping) => mapping.fmt(f),
TypeName::UserDefinedTypeName(user_defined_type_name) => user_defined_type_name.fmt(f),
TypeName::ElementaryTypeName(elementary_type_name) => elementary_type_name.fmt(f),
TypeName::String(string) => string.fmt(f),
_ => unimplemented!(),
}
}
}
Expand Down
154 changes: 86 additions & 68 deletions solidity/src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,23 @@ impl AstVisitor for AstVisitorData<'_> {
self.leave_yul_function_definition(&mut context)?;
}

YulStatement::YulBlock(yul_block) => {
let mut context = YulBlockContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_block,
};

self.visit_yul_block(&mut context)?;
self.leave_yul_block(&mut context)?;
}

YulStatement::YulLeave => {
self.visit_yul_leave(context)?;
self.leave_yul_leave(context)?;
Expand Down Expand Up @@ -2589,21 +2606,23 @@ impl AstVisitor for AstVisitorData<'_> {
visitor.visit_yul_case(context)?;
}

let mut value_context = YulExpressionContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_expression: &context.yul_case.value,
};

self.visit_yul_expression(&mut value_context)?;
self.leave_yul_expression(&mut value_context)?;
if let Some(value) = context.yul_case.value.as_ref() {
let mut value_context = YulExpressionContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_expression: value,
};

self.visit_yul_expression(&mut value_context)?;
self.leave_yul_expression(&mut value_context)?;
}

let mut body_context = YulBlockContext {
source_units: context.source_units,
Expand Down Expand Up @@ -2687,21 +2706,23 @@ impl AstVisitor for AstVisitorData<'_> {
visitor.visit_yul_variable_declaration(context)?;
}

let mut value_context = YulExpressionContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_expression: &context.yul_variable_declaration.value,
};

self.visit_yul_expression(&mut value_context)?;
self.leave_yul_expression(&mut value_context)?;
if let Some(value) = context.yul_variable_declaration.value.as_ref() {
let mut value_context = YulExpressionContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_expression: value,
};

self.visit_yul_expression(&mut value_context)?;
self.leave_yul_expression(&mut value_context)?;
}

Ok(())
}
Expand Down Expand Up @@ -2767,40 +2788,44 @@ impl AstVisitor for AstVisitorData<'_> {
visitor.visit_yul_function_definition(context)?;
}

for parameter in context.yul_function_definition.parameters.iter() {
let mut context = YulTypedNameContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_typed_name: parameter,
};

self.visit_yul_typed_name(&mut context)?;
self.leave_yul_typed_name(&mut context)?;
if let Some(parameters) = context.yul_function_definition.parameters.as_ref() {
for parameter in parameters.iter() {
let mut context = YulTypedNameContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_typed_name: parameter,
};

self.visit_yul_typed_name(&mut context)?;
self.leave_yul_typed_name(&mut context)?;
}
}

for parameter in context.yul_function_definition.return_parameters.iter() {
let mut context = YulTypedNameContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_typed_name: parameter,
};
if let Some(return_parameters) = context.yul_function_definition.return_parameters.as_ref() {
for parameter in return_parameters.iter() {
let mut context = YulTypedNameContext {
source_units: context.source_units,
current_source_unit: context.current_source_unit,
contract_definition: context.contract_definition,
definition_node: context.definition_node,
blocks: context.blocks,
statement: context.statement,
inline_assembly: context.inline_assembly,
yul_blocks: context.yul_blocks,
yul_statement: Some(context.yul_statement),
yul_typed_name: parameter,
};

self.visit_yul_typed_name(&mut context)?;
self.leave_yul_typed_name(&mut context)?;
self.visit_yul_typed_name(&mut context)?;
self.leave_yul_typed_name(&mut context)?;
}
}

let mut context = YulBlockContext {
Expand Down Expand Up @@ -2939,13 +2964,6 @@ impl AstVisitor for AstVisitorData<'_> {
self.visit_yul_function_call(&mut function_call_context)?;
self.leave_yul_function_call(&mut function_call_context)?;
}

YulExpression::UnhandledYulExpression { node_type, src, id } => {
println!(
"WARNING: Unhandled yul expression: {:?} {:?} {:?}",
node_type, src, id
);
}
}

Ok(())
Expand Down
50 changes: 38 additions & 12 deletions yul/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,26 @@ pub struct ExternalReferenceData {
value_size: NodeID,
}

#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
#[derive(Clone, Debug, Eq, Serialize, PartialEq)]
#[serde(untagged)]
pub enum YulExpression {
YulLiteral(YulLiteral),
YulIdentifier(YulIdentifier),
YulFunctionCall(YulFunctionCall),
}

impl<'de> Deserialize<'de> for YulExpression {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let json = serde_json::Value::deserialize(deserializer)?;
let node_type = json.get("nodeType").unwrap().as_str().unwrap();

#[serde(rename_all = "camelCase")]
UnhandledYulExpression {
node_type: String,
src: Option<String>,
id: Option<NodeID>,
},
match node_type {
"YulLiteral" => Ok(YulExpression::YulLiteral(serde_json::from_value(json).unwrap())),
"YulIdentifier" => Ok(YulExpression::YulIdentifier(serde_json::from_value(json).unwrap())),
"YulFunctionCall" => Ok(YulExpression::YulFunctionCall(serde_json::from_value(json).unwrap())),
_ => panic!("Invalid yul expression node type: {node_type}"),
}
}
}

#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
Expand Down Expand Up @@ -86,6 +93,7 @@ pub enum YulStatement {
YulVariableDeclaration(YulVariableDeclaration),
YulExpressionStatement(YulExpressionStatement),
YulFunctionDefinition(YulFunctionDefinition),
YulBlock(YulBlock),
YulLeave,
YulBreak,
YulContinue,
Expand All @@ -104,6 +112,7 @@ impl<'de> Deserialize<'de> for YulStatement {
"YulVariableDeclaration" => Ok(YulStatement::YulVariableDeclaration(serde_json::from_value(json).unwrap())),
"YulExpressionStatement" => Ok(YulStatement::YulExpressionStatement(serde_json::from_value(json).unwrap())),
"YulFunctionDefinition" => Ok(YulStatement::YulFunctionDefinition(serde_json::from_value(json).unwrap())),
"YulBlock" => Ok(YulStatement::YulBlock(serde_json::from_value(json).unwrap())),
"YulLeave" => Ok(YulStatement::YulLeave),
"YulBreak" => Ok(YulStatement::YulBreak),
"YulContinue" => Ok(YulStatement::YulContinue),
Expand All @@ -126,11 +135,28 @@ pub struct YulSwitch {
pub expression: YulExpression,
}

#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
#[derive(Clone, Debug, Eq, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct YulCase {
pub body: YulBlock,
pub value: YulExpression,
pub value: Option<YulExpression>,
}

impl<'de> Deserialize<'de> for YulCase {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let json = serde_json::Value::deserialize(deserializer)?;
let body = json.get("body").unwrap();
let value = json.get("value").unwrap();

Ok(YulCase {
body: serde_json::from_value(body.clone()).unwrap(),
value: if matches!(value.as_str(), Some("default")) {
None
} else {
Some(serde_json::from_value(value.clone()).unwrap())
},
})
}
}

#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
Expand All @@ -152,7 +178,7 @@ pub struct YulAssignment {
#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct YulVariableDeclaration {
pub value: YulExpression,
pub value: Option<YulExpression>,
pub variables: Vec<YulTypedName>,
}

Expand All @@ -173,7 +199,7 @@ pub struct YulExpressionStatement {
#[serde(rename_all = "camelCase")]
pub struct YulFunctionDefinition {
pub name: String,
pub parameters: Vec<YulTypedName>,
pub return_parameters: Vec<YulTypedName>,
pub parameters: Option<Vec<YulTypedName>>,
pub return_parameters: Option<Vec<YulTypedName>>,
pub body: YulBlock,
}

0 comments on commit 284aeb3

Please sign in to comment.