Skip to content

Commit

Permalink
Remove unhandled node types, add strict node deserialization, more wo…
Browse files Browse the repository at this point in the history
…rk on AST building code
  • Loading branch information
camden-smallwood committed Jun 8, 2023
1 parent b4f1621 commit 25ede88
Show file tree
Hide file tree
Showing 8 changed files with 428 additions and 153 deletions.
1 change: 1 addition & 0 deletions eth-lang-utils/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ pub enum NodeType {
RevertStatement,
ForStatement,
WhileStatement,
DoWhileStatement,
ModifierDefinition,
ModifierInvocation,
EnumDefinition,
Expand Down
2 changes: 2 additions & 0 deletions solidity/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,7 @@ edition = "2021"
[dependencies]
eth-lang-utils = { path = "../eth-lang-utils" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
simd-json = "0.7"
solang-parser = "0.3.0"
yul = { path = "../yul" }
205 changes: 154 additions & 51 deletions solidity/src/ast/builder.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use eth_lang_utils::ast::*;
use super::*;

#[derive(Default)]
Expand Down Expand Up @@ -403,19 +404,7 @@ impl AstBuilder {

FunctionDefinition {
base_functions: None, // TODO
body: input.body.as_ref()
.map(|body| {
match body {
solang_parser::pt::Statement::Block { loc, statements, .. } => Block {
statements: statements.iter()
.map(|stmt| self.build_statement(stmt))
.collect(),
src: self.loc_to_src(loc),
id: self.next_node_id(),
},
stmt => panic!("Invalid function body statement: {stmt:?}"),
}
}),
body: input.body.as_ref().map(|body| self.build_block(function_scope, body)),
documentation: None, // TODO
function_selector: None, // TODO
implemented: input.body.is_some(), // TODO: is this correct?
Expand Down Expand Up @@ -446,15 +435,7 @@ impl AstBuilder {
scope: function_scope,
state_variable: false,
storage_location: parameter.as_ref()
.map(|x| {
x.storage.as_ref()
.map(|x| match x {
solang_parser::pt::StorageLocation::Memory(_) => StorageLocation::Memory,
solang_parser::pt::StorageLocation::Storage(_) => StorageLocation::Storage,
solang_parser::pt::StorageLocation::Calldata(_) => StorageLocation::Calldata,
})
.unwrap_or_else(|| StorageLocation::Default)
})
.map(|x| self.build_storage_location(&x.storage))
.unwrap(),
type_descriptions: TypeDescriptions {
type_identifier: None, // TODO
Expand Down Expand Up @@ -487,15 +468,7 @@ impl AstBuilder {
scope: function_scope,
state_variable: false,
storage_location: parameter.as_ref()
.map(|x| {
x.storage.as_ref()
.map(|x| match x {
solang_parser::pt::StorageLocation::Memory(_) => StorageLocation::Memory,
solang_parser::pt::StorageLocation::Storage(_) => StorageLocation::Storage,
solang_parser::pt::StorageLocation::Calldata(_) => StorageLocation::Calldata,
})
.unwrap_or_else(|| StorageLocation::Default)
})
.map(|x| self.build_storage_location(&x.storage))
.unwrap(),
type_descriptions: TypeDescriptions {
type_identifier: None, // TODO
Expand Down Expand Up @@ -603,6 +576,16 @@ impl AstBuilder {
}
}

pub fn build_storage_location(&mut self, input: &Option<solang_parser::pt::StorageLocation>) -> StorageLocation {
input.as_ref()
.map(|x| match x {
solang_parser::pt::StorageLocation::Memory(_) => StorageLocation::Memory,
solang_parser::pt::StorageLocation::Storage(_) => StorageLocation::Storage,
solang_parser::pt::StorageLocation::Calldata(_) => StorageLocation::Calldata,
})
.unwrap_or_else(|| StorageLocation::Default)
}

pub fn build_type_name(&mut self, input: &solang_parser::pt::Expression) -> TypeName {
match input {
solang_parser::pt::Expression::Type(_loc, ty) => match ty {
Expand Down Expand Up @@ -781,15 +764,7 @@ impl AstBuilder {
scope: -1, // TODO
state_variable: false,
storage_location: parameter.as_ref()
.map(|x| {
x.storage.as_ref()
.map(|x| match x {
solang_parser::pt::StorageLocation::Memory(_) => StorageLocation::Memory,
solang_parser::pt::StorageLocation::Storage(_) => StorageLocation::Storage,
solang_parser::pt::StorageLocation::Calldata(_) => StorageLocation::Calldata,
})
.unwrap_or_else(|| StorageLocation::Default)
})
.map(|x| self.build_storage_location(&x.storage))
.unwrap(),
type_descriptions: TypeDescriptions {
type_identifier: None, // TODO
Expand Down Expand Up @@ -824,15 +799,7 @@ impl AstBuilder {
scope: -1, // TODO
state_variable: false,
storage_location: parameter.as_ref()
.map(|x| {
x.storage.as_ref()
.map(|x| match x {
solang_parser::pt::StorageLocation::Memory(_) => StorageLocation::Memory,
solang_parser::pt::StorageLocation::Storage(_) => StorageLocation::Storage,
solang_parser::pt::StorageLocation::Calldata(_) => StorageLocation::Calldata,
})
.unwrap_or_else(|| StorageLocation::Default)
})
.map(|x| self.build_storage_location(&x.storage))
.unwrap(),
type_descriptions: TypeDescriptions {
type_identifier: None, // TODO
Expand Down Expand Up @@ -872,8 +839,144 @@ impl AstBuilder {
}
}

pub fn build_statement(&mut self, input: &solang_parser::pt::Statement) -> Statement {
todo!()
pub fn build_block(&mut self, scope: i64, input: &solang_parser::pt::Statement) -> Block {
match input {
solang_parser::pt::Statement::Block { loc, statements, .. } => Block {
statements: statements.iter()
.map(|stmt| self.build_statement(scope, stmt))
.collect(),
src: self.loc_to_src(loc),
id: self.next_node_id(),
},
stmt => panic!("Invalid block statement: {stmt:?}"),
}
}

pub fn build_block_or_statement(&mut self, scope: i64, input: &solang_parser::pt::Statement) -> BlockOrStatement {
match input {
solang_parser::pt::Statement::Block { .. } => BlockOrStatement::Block(Box::new(self.build_block(scope, input))),
_ => BlockOrStatement::Statement(Box::new(self.build_statement(scope, input))),
}
}

pub fn build_statement(&mut self, scope: i64, input: &solang_parser::pt::Statement) -> Statement {
match input {
solang_parser::pt::Statement::Block { unchecked, .. } => {
if !*unchecked {
panic!("Generic block passed as statement: {input:#?}");
}

let unchecked_scope = self.next_scope();

Statement::UncheckedBlock(self.build_block(unchecked_scope, input))
}

solang_parser::pt::Statement::Assembly { loc, dialect, flags, block } => todo!(),

solang_parser::pt::Statement::Args(_, _) => todo!(),

solang_parser::pt::Statement::If(loc, condition, true_body, false_body) => {
let if_true_scope = self.next_scope();
let if_false_scope = self.next_scope();

Statement::IfStatement(IfStatement {
condition: self.build_expression(condition),
true_body: self.build_block_or_statement(if_true_scope, true_body),
false_body: false_body.as_ref().map(|x| self.build_block_or_statement(if_false_scope, x)),
src: self.loc_to_src(loc),
id: self.next_node_id(),
})
}

solang_parser::pt::Statement::While(loc, condition, body) => {
let while_scope = self.next_scope();

Statement::WhileStatement(WhileStatement {
condition: self.build_expression(condition),
body: self.build_block_or_statement(while_scope, body),
src: self.loc_to_src(loc),
id: self.next_node_id(),
})
}

solang_parser::pt::Statement::Expression(_loc, x) => {
Statement::ExpressionStatement(ExpressionStatement {
expression: self.build_expression(x),
})
}

solang_parser::pt::Statement::VariableDefinition(loc, variable, value) => {
Statement::VariableDeclarationStatement(VariableDeclarationStatement {
assignments: vec![], // TODO
declarations: vec![
Some(VariableDeclaration {
base_functions: None, // TODO
constant: false,
documentation: None,
function_selector: None, // TODO
indexed: None,
mutability: None, // TODO
name: variable.name.as_ref().map(|x| x.name.clone()).unwrap(),
name_location: variable.name.as_ref().map(|x| self.loc_to_src(&x.loc)),
overrides: None, // TODO
scope,
state_variable: false, // TODO
storage_location: self.build_storage_location(&variable.storage),
type_descriptions: TypeDescriptions {
type_identifier: None, // TODO
type_string: None, // TODO
},
type_name: Some(self.build_type_name(&variable.ty)),
value: None,
visibility: Visibility::Public, // TODO
src: self.loc_to_src(&variable.loc),
id: self.next_node_id(),
})
],
initial_value: value.as_ref().map(|x| self.build_expression(x)),
src: self.loc_to_src(loc),
id: self.next_node_id(),
})
}

solang_parser::pt::Statement::For(loc, init, condition, update, body) => {
let for_scope = self.next_scope();

Statement::ForStatement(ForStatement {
initialization_expression: init.as_ref().map(|x| Box::new(self.build_statement(for_scope, x))),
condition: condition.as_ref().map(|x| self.build_expression(x)),
loop_expression: update.as_ref().map(|x| Box::new(Statement::ExpressionStatement(ExpressionStatement {
expression: self.build_expression(x),
}))),
body: body.as_ref().map(|x| self.build_block_or_statement(for_scope, x)).unwrap(),
src: self.loc_to_src(loc),
id: self.next_node_id(),
})
}

solang_parser::pt::Statement::DoWhile(_loc, _body, _condition) => todo!(),

solang_parser::pt::Statement::Continue(loc) => {
Statement::Continue {
src: self.loc_to_src(loc),
id: self.next_node_id(),
}
}

solang_parser::pt::Statement::Break(loc) => {
Statement::Break {
src: self.loc_to_src(loc),
id: self.next_node_id(),
}
}

solang_parser::pt::Statement::Return(_, _) => todo!(),
solang_parser::pt::Statement::Revert(_, _, _) => todo!(),
solang_parser::pt::Statement::RevertNamedArgs(_, _, _) => todo!(),
solang_parser::pt::Statement::Emit(_, _) => todo!(),
solang_parser::pt::Statement::Try(_, _, _, _) => todo!(),
solang_parser::pt::Statement::Error(_) => todo!(),
}
}

pub fn build_literal(&mut self, input: &solang_parser::pt::Expression) -> Literal {
Expand Down
22 changes: 21 additions & 1 deletion solidity/src/ast/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ impl Display for ContractKind {
}
}

#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
#[derive(Clone, Debug, Eq, Serialize, PartialEq)]
#[serde(untagged)]
pub enum ContractDefinitionNode {
UsingForDirective(UsingForDirective),
Expand All @@ -31,6 +31,26 @@ pub enum ContractDefinitionNode {
UserDefinedValueTypeDefinition(UserDefinedValueTypeDefinition),
}

impl<'de> Deserialize<'de> for ContractDefinitionNode {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let json = serde_json::Value::deserialize(deserializer)?;
let node_type = json.get("nodeType").unwrap().as_str().unwrap();

match node_type {
"UsingForDirective" => Ok(ContractDefinitionNode::UsingForDirective(serde_json::from_value(json).unwrap())),
"StructDefinition" => Ok(ContractDefinitionNode::StructDefinition(serde_json::from_value(json).unwrap())),
"EnumDefinition" => Ok(ContractDefinitionNode::EnumDefinition(serde_json::from_value(json).unwrap())),
"VariableDeclaration" => Ok(ContractDefinitionNode::VariableDeclaration(serde_json::from_value(json).unwrap())),
"EventDefinition" => Ok(ContractDefinitionNode::EventDefinition(serde_json::from_value(json).unwrap())),
"FunctionDefinition" => Ok(ContractDefinitionNode::FunctionDefinition(serde_json::from_value(json).unwrap())),
"ModifierDefinition" => Ok(ContractDefinitionNode::ModifierDefinition(serde_json::from_value(json).unwrap())),
"ErrorDefinition" => Ok(ContractDefinitionNode::ErrorDefinition(serde_json::from_value(json).unwrap())),
"UserDefinedValueTypeDefinition" => Ok(ContractDefinitionNode::UserDefinedValueTypeDefinition(serde_json::from_value(json).unwrap())),
_ => panic!("Invalid contract definition node type: {node_type}"),
}
}
}

impl Display for ContractDefinitionNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Expand Down
35 changes: 25 additions & 10 deletions solidity/src/ast/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use eth_lang_utils::ast::*;
use serde::{Deserialize, Serialize};
use std::fmt::{Display, Write};

#[derive(Clone, Debug, Deserialize, Eq, Serialize, PartialEq)]
#[derive(Clone, Debug, Eq, Serialize, PartialEq)]
#[serde(untagged)]
pub enum Expression {
Literal(Literal),
Expand All @@ -20,13 +20,31 @@ pub enum Expression {
ElementaryTypeNameExpression(ElementaryTypeNameExpression),
TupleExpression(TupleExpression),
NewExpression(NewExpression),
}

#[serde(rename_all = "camelCase")]
UnhandledExpression {
node_type: NodeType,
src: Option<String>,
id: Option<NodeID>,
},
impl<'de> Deserialize<'de> for Expression {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let json = serde_json::Value::deserialize(deserializer)?;
let node_type = json.get("nodeType").unwrap().as_str().unwrap();

match node_type {
"Literal" => Ok(Expression::Literal(serde_json::from_value(json).unwrap())),
"Identifier" => Ok(Expression::Identifier(serde_json::from_value(json).unwrap())),
"UnaryOperation" => Ok(Expression::UnaryOperation(serde_json::from_value(json).unwrap())),
"BinaryOperation" => Ok(Expression::BinaryOperation(serde_json::from_value(json).unwrap())),
"Conditional" => Ok(Expression::Conditional(serde_json::from_value(json).unwrap())),
"Assignment" => Ok(Expression::Assignment(serde_json::from_value(json).unwrap())),
"FunctionCall" => Ok(Expression::FunctionCall(serde_json::from_value(json).unwrap())),
"FunctionCallOptions" => Ok(Expression::FunctionCallOptions(serde_json::from_value(json).unwrap())),
"IndexAccess" => Ok(Expression::IndexAccess(serde_json::from_value(json).unwrap())),
"IndexRangeAccess" => Ok(Expression::IndexRangeAccess(serde_json::from_value(json).unwrap())),
"MemberAccess" => Ok(Expression::MemberAccess(serde_json::from_value(json).unwrap())),
"ElementaryTypeNameExpression" => Ok(Expression::ElementaryTypeNameExpression(serde_json::from_value(json).unwrap())),
"TupleExpression" => Ok(Expression::TupleExpression(serde_json::from_value(json).unwrap())),
"NewExpression" => Ok(Expression::NewExpression(serde_json::from_value(json).unwrap())),
_ => panic!("Invalid expression node type: {node_type:?}"),
}
}
}

impl Expression {
Expand Down Expand Up @@ -123,7 +141,6 @@ impl Expression {
Expression::ElementaryTypeNameExpression(ElementaryTypeNameExpression { type_descriptions, .. }) => Some(type_descriptions),
Expression::TupleExpression(TupleExpression { type_descriptions, .. }) => Some(type_descriptions),
Expression::NewExpression(NewExpression { type_descriptions, .. }) => Some(type_descriptions),
Expression::UnhandledExpression { .. } => None
}
}

Expand All @@ -143,8 +160,6 @@ impl Expression {
Expression::ElementaryTypeNameExpression(ElementaryTypeNameExpression { src, .. }) => src.as_str(),
Expression::TupleExpression(TupleExpression { src, .. }) => src.as_str(),
Expression::NewExpression(NewExpression { src, .. }) => src.as_str(),
Expression::UnhandledExpression { src: Some(src), .. } => src.as_str(),
_ => return Err(std::io::Error::from(std::io::ErrorKind::NotFound))
})
}
}
Expand Down
Loading

0 comments on commit 25ede88

Please sign in to comment.