diff --git a/.gitattributes b/.gitattributes index c1417e6..8c8430e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,4 +1,3 @@ gigahorse-toolchain/* linguist-vendored -*.html linguist-language=Rust -*.py linguist-language=Rust \ No newline at end of file +*.html linguist-vendored \ No newline at end of file diff --git a/README.md b/README.md index 879c36d..32cb8e1 100644 --- a/README.md +++ b/README.md @@ -10,3 +10,13 @@ An Attacker Contract Identification Tool Implemented in Rust based on BlockWatch ```shell RUST_LOG=info cargo run -- ETH 0x10C509AA9ab291C76c45414e7CdBd375e1D5AcE8 ``` + +## Run + +### Local + +### Docker + +```shell +docker build -t lydia:v1.0 . +``` diff --git a/src/contract/data_structure.rs b/src/contract/data_structure.rs index 1af7816..e7603eb 100644 --- a/src/contract/data_structure.rs +++ b/src/contract/data_structure.rs @@ -163,7 +163,7 @@ impl From for SensitiveOpOfBadRandomnessAfterExternalCall { pub struct SensitiveOpOfDoSAfterExternalCall { pub func_sign: String, - pub call_stm: String, + pub call_stmt: String, pub call_ret_var: String, pub call_ret_index: String, pub sensitive_var: String, @@ -174,7 +174,7 @@ impl From for SensitiveOpOfDoSAfterExternalCall { // Parse and create SensitiveOpOfDoSAfterExternalCall SensitiveOpOfDoSAfterExternalCall { func_sign: record[0].to_string(), - call_stm: record[1].to_string(), + call_stmt: record[1].to_string(), call_ret_var: record[2].to_string(), call_ret_index: record[3].to_string(), sensitive_var: record[4].to_string(), @@ -199,3 +199,141 @@ impl From for TaintedCallArg { } } } + +pub struct FuncArgToSensitiveVar { + pub func_sign: String, + pub call_stmt: String, + pub func_arg: String, + pub func_arg_index: String, + pub sensitive_var: String, + pub call_func_sign: String, +} + +impl From for FuncArgToSensitiveVar { + fn from(record: StringRecord) -> Self { + // Parse and create FuncArgToSensitiveVar + FuncArgToSensitiveVar { + func_sign: record[0].to_string(), + call_stmt: record[1].to_string(), + func_arg: record[2].to_string(), + func_arg_index: record[3].to_string(), + sensitive_var: record[4].to_string(), + call_func_sign: record[5].to_string(), + } + } +} + +pub struct CallRetToFuncRet { + pub call_stmt: String, + pub call_ret: String, + pub call_ret_index: String, + pub func_sign: String, + pub func_ret_index: String, + pub func_ret: String, +} + +impl From for CallRetToFuncRet { + fn from(record: StringRecord) -> Self { + // Parse and create SpreadCallRetToFuncRet + CallRetToFuncRet { + call_stmt: record[0].to_string(), + call_ret: record[1].to_string(), + call_ret_index: record[2].to_string(), + func_sign: record[3].to_string(), + func_ret_index: record[4].to_string(), + func_ret: record[5].to_string(), + } + } +} + +pub struct CallRetToCallArg { + pub call_stmt1: String, + pub call_ret: String, + pub call_ret_index: String, + pub call_stmt2: String, + pub call_arg_index: String, + pub call_arg: String, +} + +impl From for CallRetToCallArg { + fn from(record: StringRecord) -> Self { + // Parse and create SpreadCallRetToCallArg + CallRetToCallArg { + call_stmt1: record[0].to_string(), + call_ret: record[1].to_string(), + call_ret_index: record[2].to_string(), + call_stmt2: record[3].to_string(), + call_arg_index: record[4].to_string(), + call_arg: record[5].to_string(), + } + } +} + +pub struct CallArgs { + pub call_stmt: String, + pub call_arg_index: String, +} + +pub struct FuncArgToCallArg { + pub func_sign: String, + pub func_arg_index: String, + pub func_arg: String, + pub call_stmt: String, + pub call_arg_index: String, + pub call_arg: String, +} + +impl From for FuncArgToCallArg { + fn from(record: StringRecord) -> Self { + // Parse and create SpreadFuncArgToCallArg + FuncArgToCallArg { + func_sign: record[0].to_string(), + func_arg_index: record[1].to_string(), + func_arg: record[2].to_string(), + call_stmt: record[3].to_string(), + call_arg_index: record[4].to_string(), + call_arg: record[5].to_string(), + } + } +} + +pub struct FuncArgToCallee { + pub func_sign: String, + pub func_arg_index: String, + pub func_arg: String, + pub call_stmt: String, + pub call_arg_index: String, +} +impl From for FuncArgToCallee { + fn from(record: StringRecord) -> Self { + // Parse and create SpreadFuncArgToCallee + FuncArgToCallee { + func_sign: record[0].to_string(), + func_arg_index: record[1].to_string(), + func_arg: record[2].to_string(), + call_stmt: record[3].to_string(), + call_arg_index: record[4].to_string(), + } + } +} + +pub struct FuncArgToFuncRet { + pub func_sign: String, + pub func_arg_index: String, + pub func_arg: String, + pub func_ret_index: String, + pub func_ret: String, +} + +impl From for FuncArgToFuncRet { + fn from(record: StringRecord) -> Self { + // Parse and create SpreadFuncArgToFuncRet + FuncArgToFuncRet { + func_sign: record[0].to_string(), + func_arg_index: record[1].to_string(), + func_arg: record[2].to_string(), + func_ret_index: record[3].to_string(), + func_ret: record[4].to_string(), + } + } +} diff --git a/src/flow/flow_analysis.rs b/src/flow/flow_analysis.rs index 27bd0dd..b5f65b6 100644 --- a/src/flow/flow_analysis.rs +++ b/src/flow/flow_analysis.rs @@ -1,35 +1,46 @@ -use std::cell::RefCell; use std::collections::HashSet; -use std::rc::Rc; use std::{collections::HashMap, error::Error, path::Path}; -use csv::{ReaderBuilder, StringRecord}; - use crate::contract::contract::Contract; -use crate::contract::data_structure::{self, ExternalCall, TaintedCallArg}; -use log::info; +use crate::contract::data_structure::{ + self, CallArgs, ExternalCall, FuncArgToSensitiveVar, TaintedCallArg, +}; +use csv::{ReaderBuilder, StringRecord}; +use log::{error, info}; const OUTPUT_PATH: &str = "./gigahorse-toolchain/"; const TEMP_PATH: &str = "./gigahorse-toolchain/.temp/"; -#[derive(Debug)] -struct ProgramPoint { - caller_addr: String, - call_site: String, - caller_func_sign: String, - target_contract_addr: String, - target_func_sign: String, - index: i32, - program_point_type: String, +#[derive(Debug, Clone)] +pub struct ProgramPoint { + pub caller_addr: String, + pub call_site: String, + pub caller_func_sign: String, + pub target_contract_addr: String, + pub target_func_sign: String, + pub index: String, + pub program_point_type: String, +} + +#[derive(Debug, PartialEq, Clone)] +struct ReachableSiteInfo { + caller: String, + caller_callback_func_sign: String, +} + +#[derive(Debug, PartialEq)] +struct ReenterInfo { + reenter_target: String, + reenter_func_sign: String, } pub struct FlowAnalysis<'a> { contracts: &'a HashMap, main_contract_sign_list: Vec, - external_call_in_func_sigature: HashSet, - visited_contracts: Vec, - visited_funcs: Vec, - intra_callsigs: Vec, - sensitive_callsigs: Vec, + external_call_in_func_signature: HashSet, + pub visited_contracts: HashSet, + pub visited_funcs: HashSet, + intra_callsigns: Vec, + sensitive_callsigns: Vec, attack_matrix: HashMap, victim_callback_info: HashMap>, attack_reenter_info: HashMap>, @@ -39,18 +50,18 @@ impl<'a> FlowAnalysis<'a> { pub fn new( contracts: &'a HashMap, main_contract_sign_list: Vec, - external_call_in_func_sigature: HashSet, - visited_contracts: Vec, - visited_funcs: Vec, + external_call_in_func_signature: HashSet, + visited_contracts: HashSet, + visited_funcs: HashSet, ) -> Self { FlowAnalysis { contracts: contracts, main_contract_sign_list: main_contract_sign_list, - external_call_in_func_sigature: external_call_in_func_sigature, + external_call_in_func_signature: external_call_in_func_signature, visited_contracts: visited_contracts, visited_funcs: visited_funcs, - intra_callsigs: Vec::new(), - sensitive_callsigs: Vec::new(), + intra_callsigns: Vec::new(), + sensitive_callsigns: Vec::new(), attack_matrix: HashMap::new(), victim_callback_info: HashMap::new(), attack_reenter_info: HashMap::new(), @@ -80,13 +91,17 @@ impl<'a> FlowAnalysis<'a> { let temp_address = key.split("_").collect::>()[2]; let temp_func_sign = key.split("_").collect::>()[3]; let mut br_analysis_df = Vec::new(); - self.read_csv::( - &format!( - "{}{}/out/Leslie_SensitiveOpOfBadRandomnessAfterExternalCall.csv", - TEMP_PATH, temp_address - ), - &mut br_analysis_df, - ); + if let Err(err) = self + .read_csv::( + &format!( + "{}{}/out/Leslie_SensitiveOpOfBadRandomnessAfterExternalCall.csv", + TEMP_PATH, temp_address + ), + &mut br_analysis_df, + ) + { + error!("Error reading CSV: {}", err); + } for br_analysis in br_analysis_df { if br_analysis.func_sign == temp_func_sign { return true; @@ -104,13 +119,17 @@ impl<'a> FlowAnalysis<'a> { let temp_address = key.split("_").collect::>()[2]; let temp_func_sign = key.split("_").collect::>()[3]; let mut dos_analysis_df = Vec::new(); - self.read_csv::( - &format!( - "{}{}/out/Leslie_SensitiveOpOfDoSAfterExternalCall.csv", - TEMP_PATH, temp_address - ), - &mut dos_analysis_df, - ); + if let Err(err) = self + .read_csv::( + &format!( + "{}{}/out/Leslie_SensitiveOpOfDoSAfterExternalCall.csv", + TEMP_PATH, temp_address + ), + &mut dos_analysis_df, + ) + { + error!("Error reading CSV: {}", err); + } for dos_analysis in dos_analysis_df { if dos_analysis.func_sign == temp_func_sign { return true; @@ -122,7 +141,249 @@ impl<'a> FlowAnalysis<'a> { return false; } - fn find_executed_pp( + // other intra analysis + pub fn op_multicreate_analysis(&self) -> bool { + for key in self.contracts.keys() { + if self.contracts[key].level == 0 { + let temp_address = key.split("_").collect::>()[2]; + let mut temp_func_sign = key.split("_").collect::>()[3]; + if key.contains("__function_selector__") { + temp_func_sign = "__function_selector__"; + } + let mut op_multicreate_analysis_df = Vec::new(); + if let Err(err) = self + .read_csv::( + &format!( + "{}{}/out/Leslie_Op_CreateInLoop.csv", + TEMP_PATH, temp_address + ), + &mut op_multicreate_analysis_df, + ) + { + error!("Error reading CSV: {}", err); + } + for op_multicreate_analysis in op_multicreate_analysis_df { + if op_multicreate_analysis.func_sign == temp_func_sign { + return true; + } + } + return false; + } + } + return false; + } + + pub fn op_solecreate_analysis(&self) -> bool { + for key in self.contracts.keys() { + if self.contracts[key].level == 0 { + let temp_address = key.split("_").collect::>()[2]; + let mut temp_func_sign = key.split("_").collect::>()[3]; + if key.contains("__function_selector__") { + temp_func_sign = "__function_selector__"; + } + let mut op_solecreate_analysis_df = Vec::new(); + if let Err(err) = self + .read_csv::( + &format!("{}{}/out/Leslie_Op_SoleCreate.csv", TEMP_PATH, temp_address), + &mut op_solecreate_analysis_df, + ) + { + error!("Error reading CSV: {}", err); + } + for op_solecreate_analysis in op_solecreate_analysis_df { + if op_solecreate_analysis.func_sign == temp_func_sign { + return true; + } + } + return false; + } + } + return false; + } + + pub fn op_selfdestruct_analysis(&self) -> bool { + for key in self.contracts.keys() { + if self.contracts[key].level == 0 { + let temp_address = key.split("_").collect::>()[2]; + let mut temp_func_sign = key.split("_").collect::>()[3]; + if key.contains("__function_selector__") { + temp_func_sign = "__function_selector__"; + } + let mut op_selfdestruct_analysis_df = Vec::new(); + if let Err(err) = self + .read_csv::( + &format!( + "{}{}/out/Leslie_Op_Selfdestruct.csv", + TEMP_PATH, temp_address + ), + &mut op_selfdestruct_analysis_df, + ) + { + error!("Error reading CSV: {}", err); + } + for op_selfdestruct_analysis in op_selfdestruct_analysis_df { + if op_selfdestruct_analysis.func_sign == temp_func_sign { + return true; + } + } + return false; + } + } + return false; + } + + fn spread_call_ret_func_ret( + &self, + contract_address: &str, + call_stmt: &str, + func_sign: &str, + ret_index: &str, + ) -> Vec { + let mut func_ret_index = Vec::new(); + let mut call_ret_func_ret_df = Vec::new(); + if let Err(err) = self.read_csv::( + &format!( + "{}{}/out/Leslie_Spread_CallRetToFuncRet.csv", + TEMP_PATH, contract_address + ), + &mut call_ret_func_ret_df, + ) { + error!("Error reading CSV: {}", err); + } + for call_ret_func_ret in call_ret_func_ret_df { + if call_ret_func_ret.func_sign == func_sign + && call_ret_func_ret.call_stmt == call_stmt + && call_ret_func_ret.call_ret_index == ret_index + { + func_ret_index.push(call_ret_func_ret.call_ret_index); + } + } + func_ret_index + } + + #[allow(unused_variables)] + fn spread_call_ret_call_arg( + &self, + contract_address: &str, + call_stmt: &str, + ret_index: &str, + ) -> Vec { + let mut call_args = Vec::new(); + let mut call_ret_call_arg_df = Vec::new(); + if let Err(err) = self.read_csv::( + &format!( + "{}{}/out/Leslie_Spread_CallRetToCallArg.csv", + TEMP_PATH, contract_address + ), + &mut call_ret_call_arg_df, + ) { + error!("Error reading CSV: {}", err); + } + for call_ret_call_arg in call_ret_call_arg_df { + if call_ret_call_arg.call_stmt1 == call_stmt + && call_ret_call_arg.call_ret_index == ret_index + { + call_args.push(CallArgs { + call_stmt: call_ret_call_arg.call_stmt2, + call_arg_index: call_ret_call_arg.call_arg_index, + }); + } + } + call_args + } + + #[allow(unused_variables)] + fn spread_func_arg_call_arg( + &self, + contract_address: &str, + func_sign: &str, + func_arg_index: &str, + ) -> Vec { + let mut call_args = Vec::new(); + let mut func_arg_call_arg_df = Vec::new(); + if let Err(err) = self.read_csv::( + &format!( + "{}{}/out/Leslie_Spread_FuncArgToCallArg.csv", + TEMP_PATH, contract_address + ), + &mut func_arg_call_arg_df, + ) { + error!("Error reading CSV: {}", err); + } + for func_arg_call_arg in func_arg_call_arg_df { + if func_arg_call_arg.func_sign == func_sign + && func_arg_call_arg.func_arg_index == func_arg_index + { + call_args.push(CallArgs { + call_stmt: func_arg_call_arg.call_stmt, + call_arg_index: func_arg_call_arg.call_arg_index, + }); + } + } + call_args + } + + #[allow(unused_variables)] + fn spread_func_arg_callee( + &self, + contract_address: &str, + func_sign: &str, + func_arg_index: &str, + ) -> Vec { + let mut call_args = Vec::new(); + let mut func_arg_callee_df = Vec::new(); + if let Err(err) = self.read_csv::( + &format!( + "{}{}/out/Leslie_Spread_FuncArgToCalleeVar.csv", + TEMP_PATH, contract_address + ), + &mut func_arg_callee_df, + ) { + error!("Error reading CSV: {}", err); + } + for func_arg_callee in func_arg_callee_df { + if func_arg_callee.func_sign == func_sign + && func_arg_callee.func_arg_index == func_arg_index + { + call_args.push(CallArgs { + call_stmt: func_arg_callee.call_stmt, + call_arg_index: func_arg_callee.func_arg_index, + }); + } + } + call_args + } + + #[allow(unused_variables)] + fn spread_func_arg_func_ret( + &self, + contract_address: &str, + func_sign: &str, + func_arg_index: &str, + ) -> Vec { + let mut func_ret_index = Vec::new(); + let mut func_arg_func_ret_df = Vec::new(); + if let Err(err) = self.read_csv::( + &format!( + "{}{}/out/Leslie_Spread_CallRetToFuncRet.csv", + TEMP_PATH, contract_address + ), + &mut func_arg_func_ret_df, + ) { + error!("Error reading CSV: {}", err); + } + for func_arg_func_ret in func_arg_func_ret_df { + if func_arg_func_ret.func_sign == func_sign + && func_arg_func_ret.func_arg_index == func_arg_index + { + func_ret_index.push(func_arg_func_ret.func_ret_index); + } + } + func_ret_index + } + + #[allow(unused_variables)] + fn find_executed_program_point( &self, caller: &str, call_site: &str, @@ -155,18 +416,23 @@ impl<'a> FlowAnalysis<'a> { call_site: &str, target_contract_addr: &str, target_func_sign: &str, - index: i32, + index: &str, caller_func_sign: &str, program_point_type: &str, ) -> ProgramPoint { - let addr = self.find_executed_pp(caller, call_site, target_contract_addr, target_func_sign); + let addr = self.find_executed_program_point( + caller, + call_site, + target_contract_addr, + target_func_sign, + ); ProgramPoint { caller_addr: caller.to_string(), call_site: call_site.to_string(), caller_func_sign: caller_func_sign.to_string(), target_contract_addr: addr, target_func_sign: target_func_sign.to_string(), - index: index, + index: index.to_string(), program_point_type: program_point_type.to_string(), } } @@ -188,6 +454,7 @@ impl<'a> FlowAnalysis<'a> { None } + #[allow(unused_variables)] fn get_call_args_flow_from_sources( &self, contract_addr: &str, @@ -195,33 +462,37 @@ impl<'a> FlowAnalysis<'a> { ) -> Vec { let mut call_args = Vec::new(); let mut tainted_call_arg_df = Vec::new(); - self.read_csv::( + if let Err(err) = self.read_csv::( &format!( "{}{}/out/Leslie_TaintedCallArg.csv", TEMP_PATH, contract_addr ), &mut tainted_call_arg_df, - ); + ) { + error!("Error reading CSV: {}", err); + } for result in tainted_call_arg_df { - let call_arg = TaintedCallArg { - call_stmt: result.call_stmt.clone(), - func_sign: result.func_sign.clone(), - call_arg_index: result.call_arg_index.clone(), - }; - call_args.push(call_arg); + if func_sign == result.func_sign { + let call_arg = TaintedCallArg { + call_stmt: result.call_stmt.clone(), + func_sign: result.func_sign.clone(), + call_arg_index: result.call_arg_index.clone(), + }; + call_args.push(call_arg); + } } // info!("call args of {}: {:?}", contract_addr, call_args); call_args } - fn get_pps_near_source(&self) -> Vec { + fn get_program_points_near_source(&self) -> Vec { let mut pps_near_source = Vec::new(); for (key, contract) in self.contracts.iter() { // info!("contract key: {}", key); let parts: Vec<&str> = key.split('_').collect(); - let (temp_caller, temp_callsite, temp_address, temp_func_sign) = + let (_temp_caller, _temp_callsite, temp_address, _temp_func_sign) = (parts[0], parts[1], parts[2], parts[3]); if contract.level == 0 { @@ -245,7 +516,7 @@ impl<'a> FlowAnalysis<'a> { &temp_call_arg.call_stmt, &temp_external_call_logic_addr, &temp_external_call_func_sign, - temp_call_arg.call_arg_index.parse::().unwrap(), + &temp_call_arg.call_arg_index, temp_caller_func_sign, "call_arg", )); @@ -258,9 +529,283 @@ impl<'a> FlowAnalysis<'a> { pps_near_source } + #[allow(unused_variables)] + fn get_func_args_flow_to_sink( + &self, + contract_addr: &str, + func_sign: &str, + ) -> (Vec, Vec) { + let mut func_args = Vec::new(); + let mut sensitive_call_signs = Vec::new(); + + let mut func_arg_to_sensitive_var_df: Vec = Vec::new(); + if let Err(err) = self.read_csv::( + &format!( + "{}{}/out/Leslie_FuncArgToSensitiveVar.csv", + TEMP_PATH, contract_addr + ), + &mut func_arg_to_sensitive_var_df, + ) { + error!("Error reading CSV: {}", err); + } + + for result in func_arg_to_sensitive_var_df { + if result.func_sign == func_sign { + let func_arg = FuncArgToSensitiveVar { + func_sign: result.func_sign.clone(), + call_stmt: result.call_stmt.clone(), + func_arg: result.func_arg.clone(), + func_arg_index: result.func_arg_index.clone(), + sensitive_var: result.sensitive_var.clone(), + call_func_sign: result.call_func_sign.clone(), + }; + func_args.push(func_arg); + sensitive_call_signs.push(result.call_func_sign.clone().replace( + "00000000000000000000000000000000000000000000000000000000", + "", + )); + } + } + // info!("call args of {}: {:?}", contract_addr, call_args); + (func_args, sensitive_call_signs) + } + + fn get_program_points_near_sink(&self) -> (Vec, Vec) { + let mut program_points_near_sink = Vec::new(); + let mut sensitive_callsigs = Vec::new(); + + for (key, _value) in self.contracts.iter() { + let parts: Vec<&str> = key.split('_').collect(); + let _temp_caller = parts[0]; + let _temp_callsite = parts[1]; + let temp_address = parts[2]; + let temp_func_sign = parts[3]; + let _temp_caller_func_sign = parts[4]; + // log information if needed + + let (temp_call_args, signs_func_arg) = + self.get_func_args_flow_to_sink(temp_address, temp_func_sign); + + if !temp_call_args.is_empty() { + for temp_call_arg in temp_call_args { + let ( + temp_external_call_caller, + temp_external_call_logic_addr, + temp_external_call_func_sign, + ) = self + .get_external_call_info( + &temp_call_arg.call_stmt, + &self.contracts[key].external_calls, + ) + .unwrap(); + + program_points_near_sink.push(self.get_new_program_point( + &temp_external_call_caller, + &temp_call_arg.call_stmt, + &temp_external_call_logic_addr, + &temp_external_call_func_sign, + &temp_call_arg.func_arg_index, + &self.contracts[key].func_sign, + "call_arg", + )); + } + // log information if needed + } + + for signs in signs_func_arg { + sensitive_callsigs.push(signs); + } + } + + (program_points_near_sink, sensitive_callsigs) + } + + fn find_parent( + &self, + logic_addr: &str, + func_sign: &str, + caller: &str, + call_site: &str, + ) -> Option<&Contract> { + for (_, contract) in self.contracts.iter() { + for external_call in &contract.external_calls { + if external_call.target_logic_addr == logic_addr + && external_call.target_func_sign == func_sign + && external_call.caller_addr == caller + && external_call.call_site == call_site + { + return Some(contract); + } + } + } + None + } + + fn find_contract( + &self, + caller: &str, + callsite: &str, + contract_addr: &str, + func_sign: &str, + caller_func_sign: &str, + ) -> Option<&Contract> { + let key = format!( + "{}_{}_{}_{}_{}", + caller, callsite, contract_addr, func_sign, caller_func_sign + ); + self.contracts.get(&key) + } + + fn is_same(&self, first: &ProgramPoint, second: &ProgramPoint) -> bool { + first.caller_addr == second.caller_addr + && first.call_site == second.call_site + && first.target_func_sign == second.target_func_sign + && first.index == second.index + && first.program_point_type == second.program_point_type + && first.caller_func_sign == second.caller_func_sign + } + + fn is_reachable(&self, first: &ProgramPoint, second: &ProgramPoint) -> bool { + if self.is_same(first, second) { + return true; + } + let mut pending = vec![first.clone()]; + while let Some(temp) = pending.pop() { + for program_point in self.transfer(&temp) { + if self.is_same(&program_point, second) { + return true; + } + pending.push(program_point); + } + } + false // Return false when the loop has zero elements to iterate on + } + + fn transfer(&self, program_point: &ProgramPoint) -> Vec { + let mut next_program_points = Vec::new(); + + // Assuming find_parent returns an Option<&Contract> + let parent_contract = self.find_parent( + &program_point.target_contract_addr, + &program_point.target_func_sign, + &program_point.caller_addr, + &program_point.call_site, + ); + let child_contract = match self.find_contract( + &program_point.caller_addr, + &program_point.call_site, + &program_point.target_contract_addr, + &program_point.target_func_sign, + &program_point.caller_func_sign, + ) { + Some(contract) => contract, + None => return next_program_points, + }; + + match program_point.program_point_type.as_str() { + "func_ret" => { + // Implement logic for "func_ret" + if let Some(parent) = parent_contract { + let indexes = self.spread_call_ret_func_ret( + &program_point.caller_addr, + &program_point.call_site, + &parent.func_sign, + &program_point.index, + ); + for index in indexes.iter() { + next_program_points.push(self.get_new_program_point( + &parent.caller, + &parent.call_site, + &parent.logic_addr, + &parent.func_sign, + index, + &program_point.caller_func_sign, + "func_ret", + )) + } + } + let call_args = self.spread_call_ret_call_arg( + &program_point.target_contract_addr, + &program_point.call_site, + &program_point.index, + ); + for call_arg in call_args.iter() { + let (temp_caller, temp_logic_addr, temp_func_sign) = self + .get_external_call_info(&call_arg.call_stmt, &child_contract.external_calls) + .unwrap(); + next_program_points.push(self.get_new_program_point( + &temp_caller, + &call_arg.call_stmt, + &temp_logic_addr, + &temp_func_sign, // temp func sign is the called function that lies in the attacker contract + &call_arg.call_arg_index, + &program_point.target_func_sign, // pp[func_sign] is the function that calls back to attacker contract + "call_arg", + )) + } + } + "call_arg" => { + let mut call_args: Vec = Vec::new(); + call_args.extend(self.spread_func_arg_call_arg( + &program_point.target_contract_addr, + &program_point.target_func_sign, + &program_point.index, + )); + call_args.extend(self.spread_func_arg_callee( + &program_point.target_contract_addr, + &program_point.target_func_sign, + &program_point.index, + )); + + for call_arg in call_args.iter() { + let temp_result = self.get_external_call_info( + &call_arg.call_stmt, + &child_contract.external_calls, + ); + + if temp_result != None { + let (_temp_caller, temp_logic_addr, temp_func_sign) = temp_result.unwrap(); + next_program_points.push(self.get_new_program_point( + &program_point.target_contract_addr, + &call_arg.call_stmt, + &temp_logic_addr, + &temp_func_sign, + &call_arg.call_arg_index, + &program_point.target_func_sign, + "call_arg", + )) + } + } + // the return index of the function call + let indexes = self.spread_func_arg_func_ret( + &program_point.target_contract_addr, + &program_point.target_func_sign, + &program_point.index, + ); + for index in indexes.iter() { + next_program_points.push(self.get_new_program_point( + &program_point.caller_addr, + &program_point.call_site, + &program_point.target_contract_addr, + &program_point.target_func_sign, + index, + &program_point.caller_func_sign, + "func_ret", + )); + } + } + _ => (), + } + + next_program_points.clone() + } + pub fn detect(&mut self) -> (bool, HashMap) { let mut cross_contract = false; - for (key, contract) in self.contracts.iter() { + self.attack_matrix.insert("br".to_string(), false); + self.attack_matrix.insert("dos".to_string(), false); + self.attack_matrix.insert("reentrancy".to_string(), false); + for (_key, contract) in self.contracts.iter() { if contract.level != 0 { cross_contract = true; break; @@ -276,75 +821,95 @@ impl<'a> FlowAnalysis<'a> { // Assuming intraprocedural_br_analysis and intraprocedural_dos_analysis are methods returning bool if self.intraprocedural_br_analysis() { self.attack_matrix.insert("br".to_string(), true); - } else { - self.attack_matrix.insert("br".to_string(), false); } if self.intraprocedural_dos_analysis() { self.attack_matrix.insert("dos".to_string(), true); - } else { - self.attack_matrix.insert("dos".to_string(), false); } - - let source = self.get_pps_near_source(); + let source = self.get_program_points_near_source(); // info!("pp near source: {:?}", source); - // Assuming get_pps_near_source and get_pps_near_sink return appropriate data structures - let mut reachable_site: HashMap = HashMap::new(); - - // Assuming is_same and is_reachable are methods - // for pp1 in &pps_near_source { - // for pp2 in &pps_near_sink { - // if self.is_same(pp1, pp2) || self.is_reachable(pp1, pp2) { - // reachable = true; - // let caller = pp2.caller.clone(); - // let caller_func_sign = pp2.caller_func_sign.clone(); - // reachable_site.insert(pp2.func_sign.clone(), (caller, caller_func_sign)); - // } - // } - // } - - // let mut victim_callback_info = HashMap::new(); - // let mut attack_reenter_info = HashMap::new(); - - // if reachable { - // let overlap: HashSet<_> = sensitive_callsigs - // .intersection(&self.external_call_in_func_signature) - // .collect(); - // if !overlap.is_empty() { - // for i in &overlap { - // victim_callback_info - // .entry(i.clone()) - // .or_insert_with(Vec::new); - // attack_reenter_info - // .entry(i.clone()) - // .or_insert_with(Vec::new); - - // if let Some(site) = reachable_site.get(i) { - // if !victim_callback_info.get_mut(i).unwrap().contains(site) { - // victim_callback_info.get_mut(i).unwrap().push(site.clone()); - // } - // } - // for (key, contract) in &self.contracts { - // if contract.func_sign == *i && contract.level == 0 { - // for ec in &contract.external_calls { - // let temp_target_address = ec.logic_addr.clone(); - // let temp_func_sign = ec.func_sign.clone(); - // let res = (temp_target_address.clone(), temp_func_sign.clone()); - // if !attack_reenter_info.get_mut(i).unwrap().contains(&res) - // && self.visited_contracts.contains(&temp_target_address) - // && self.visited_funcs.contains(&temp_func_sign) - // { - // attack_reenter_info.get_mut(i).unwrap().push(res); - // } - // } - // result = true; - // self.attack_matrix.insert("reentrancy".to_string(), true); - // } - // } - // } - // } - // } + let (sink, sensitive_call_signs) = self.get_program_points_near_sink(); + + self.sensitive_callsigns = sensitive_call_signs; + + let mut reachable: bool = false; + let mut reachable_site: HashMap = HashMap::new(); + + for program_point_source in &source { + for program_point_sink in &sink { + if self.is_same(program_point_source, program_point_sink) + || self.is_reachable(program_point_source, program_point_sink) + { + reachable = true; + result = true; + let caller = program_point_sink.caller_addr.clone(); + let caller_func_sign = program_point_sink.caller_func_sign.clone(); + reachable_site.insert( + program_point_sink.target_func_sign.clone(), + ReachableSiteInfo { + caller: caller, + caller_callback_func_sign: caller_func_sign, + }, + ); + } + } + } + + let mut victim_callback_info: HashMap> = HashMap::new(); + let mut attacker_reenter_info: HashMap> = HashMap::new(); + if reachable { + let sensitive_call_signs_set: HashSet<_> = + self.sensitive_callsigns.iter().cloned().collect(); + let overlap: HashSet<_> = sensitive_call_signs_set + .intersection( + &self + .external_call_in_func_signature + .iter() + .cloned() + .collect(), + ) + .cloned() + .collect(); + if !overlap.is_empty() { + for i in overlap { + // initialize + victim_callback_info + .entry(i.clone()) + .or_insert_with(Vec::new); + attacker_reenter_info + .entry(i.clone()) + .or_insert_with(Vec::new); + + if let Some(site) = reachable_site.get(&i) { + let entry = victim_callback_info.entry(i.clone()).or_default(); + if !entry.contains(site) { + entry.push(site.clone()); + } + } + for (_, contract) in self.contracts { + if contract.func_sign.eq(&i) && contract.level == 0 { + for ec in &contract.external_calls { + let res = ReenterInfo { + reenter_target: ec.target_logic_addr.clone(), + reenter_func_sign: ec.target_func_sign.clone(), + }; + + let entry = attacker_reenter_info.entry(i.clone()).or_default(); + if !entry.contains(&res) + && self.visited_contracts.contains(&res.reenter_target) + && self.visited_funcs.contains(&res.reenter_func_sign) + { + entry.push(res); + } + } + result = true; + self.attack_matrix.insert("reentrancy".to_string(), true); + } + } + } + } + } // self.victim_callback_info = victim_callback_info; // self.attack_reenter_info = attack_reenter_info; diff --git a/src/graph/call_graph.rs b/src/graph/call_graph.rs index 05bc7e7..01efb3f 100644 --- a/src/graph/call_graph.rs +++ b/src/graph/call_graph.rs @@ -1,13 +1,11 @@ use crate::{contract::contract::Contract, Source}; -use std::cell::RefCell; use std::collections::{HashMap, HashSet}; -use std::rc::Rc; #[allow(dead_code)] pub struct CallGraph<'a> { output: String, visited_contracts: HashSet, visited_funcs: HashSet, - max_level: i32, + pub max_level: i32, platform: String, contracts: &'a mut HashMap, } @@ -27,8 +25,12 @@ impl<'a> CallGraph<'a> { &self.output } - pub fn get_contracts(&self) -> &HashMap { - &self.contracts + pub fn get_visited_contracts(&self) -> &HashSet { + &self.visited_contracts + } + + pub fn get_visited_funcs(&self) -> &HashSet { + &self.visited_funcs } pub async fn construct_cross_contract_call_graph( @@ -74,6 +76,7 @@ impl<'a> CallGraph<'a> { continue; } self.visited_funcs.insert(temp.func_sign.clone()); + self.visited_contracts.insert(temp.logic_addr.clone()); let mut new_contract = Contract::new( temp.platform.clone(), diff --git a/src/main.rs b/src/main.rs index db6a8d3..9866e3d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,13 +11,10 @@ use crate::outputter::result_structure::{ #[allow(unused_imports)] use log::{debug, error, info, log_enabled, Level}; use serde_json; -use std::borrow::BorrowMut; -use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use std::env; use std::fs::File; use std::io::Write; -use std::rc::Rc; use std::time::Instant; const JSON_PATH: &str = "./output/"; @@ -79,17 +76,15 @@ async fn main() { let external_call_in_func_signature = external_call_in_func_signature.clone(); - let visited_contracts: Vec = Vec::new(); - let visited_funcs: Vec = Vec::new(); + let mut visited_contracts: HashSet = HashSet::new(); + let mut visited_funcs: HashSet = HashSet::new(); let mut call_path: Vec = Vec::new(); - // let m_call_depth: u32 = 0; - // let call_graph_str: String = String::new(); + let mut max_call_depth: i32 = 0; let mut contracts = HashMap::new(); if input_contract.is_createbin().clone() { } else { - let mut max_call_depth: u32 = 0; for func_sign in external_call_in_func_signature.clone().into_iter() { // let mut contracts_mut = contracts.borrow_mut(); println!("{}", func_sign); @@ -113,7 +108,14 @@ async fn main() { { eprintln!("An error occurred during call graph construction: {}", e); }; - let call_graph_str = cross_contract_call_graph.get_output(); + let call_graph_str: &str = cross_contract_call_graph.get_output(); + visited_contracts.extend(cross_contract_call_graph.get_visited_contracts().clone()); + visited_funcs.extend(cross_contract_call_graph.get_visited_funcs().clone()); + + if cross_contract_call_graph.max_level > max_call_depth { + max_call_depth = cross_contract_call_graph.max_level; + } + call_path.push(call_graph_str.to_string()); println!("{}", call_graph_str); } @@ -138,10 +140,6 @@ async fn main() { external_call: ExternalCall { externalcall_inhook: false, externalcall_infallback: false, - hooks_focused: vec![ - "tokensReceived".to_string(), - // ... add other strings here - ], }, call_paths: Vec::new(), visited_contracts: Vec::new(), @@ -171,6 +169,11 @@ async fn main() { result.is_attack = res_bool; result.attack_matrix = res; result.call_paths = call_path; + result.max_call_depth = max_call_depth as u32; + result.visited_contracts = detector.visited_contracts.clone().drain().collect(); + result.visited_contracts_num = result.visited_contracts.len(); + result.visited_funcs = detector.visited_funcs.clone().drain().collect(); + result.visited_funcs_num = result.visited_funcs.len(); let serialized = serde_json::to_string_pretty(&result).unwrap(); let mut file = File::create(format!("{}{}.json", JSON_PATH, logic_address)).unwrap(); diff --git a/src/outputter/result_structure.rs b/src/outputter/result_structure.rs index 277424d..1c22d5e 100644 --- a/src/outputter/result_structure.rs +++ b/src/outputter/result_structure.rs @@ -42,7 +42,7 @@ pub struct OpCreation { pub struct ExternalCall { pub externalcall_inhook: bool, pub externalcall_infallback: bool, - pub hooks_focused: Vec, + // pub hooks_focused: Vec, } #[derive(Serialize, Deserialize)] pub struct Overlap {