From fc6b2091c7b1fbc2e4f8f6827580d6f267f050c1 Mon Sep 17 00:00:00 2001 From: Jose Garcia Crosta Date: Fri, 30 Aug 2024 12:38:43 -0300 Subject: [PATCH] Remove stuff that is in utils --- .../unprotected-mapping-operation/src/lib.rs | 92 ++++++------------- 1 file changed, 28 insertions(+), 64 deletions(-) diff --git a/detectors/unprotected-mapping-operation/src/lib.rs b/detectors/unprotected-mapping-operation/src/lib.rs index f5122e44..c4575530 100644 --- a/detectors/unprotected-mapping-operation/src/lib.rs +++ b/detectors/unprotected-mapping-operation/src/lib.rs @@ -1,17 +1,14 @@ #![feature(rustc_private)] -extern crate rustc_ast; extern crate rustc_hir; extern crate rustc_middle; extern crate rustc_span; -use std::collections::{HashMap, HashSet}; - use clippy_wrappers::span_lint_and_help; use if_chain::if_chain; use rustc_hir::{ intravisit::{walk_expr, FnKind, Visitor}, - Body, Expr, ExprKind, FnDecl, HirId, + Body, Expr, ExprKind, FnDecl, }; use rustc_lint::{LateContext, LateLintPass}; use rustc_middle::ty::{GenericArgKind, Ty, TyKind}; @@ -19,7 +16,10 @@ use rustc_span::{ def_id::{DefId, LocalDefId}, Span, Symbol, }; -use utils::FunctionCallVisitor; +use std::collections::{HashMap, HashSet}; +use utils::{ + get_node_type_opt, is_soroban_address, is_soroban_function, is_soroban_map, FunctionCallVisitor, +}; const LINT_MESSAGE: &str = "This mapping operation is called without access control on a different key than the caller's address"; @@ -37,9 +37,6 @@ dylint_linting::impl_late_lint! { } } -const SOROBAN_MAP: &str = "soroban_sdk::Map"; -const SOROBAN_ADDRESS: &str = "soroban_sdk::Address"; - #[derive(Default)] struct UnprotectedMappingOperation { function_call_graph: HashMap>, @@ -48,42 +45,17 @@ struct UnprotectedMappingOperation { unauthorized_mapping_calls: HashMap>, } -impl<'tcx> UnprotectedMappingOperation { - fn is_soroban_function(&self, cx: &LateContext<'tcx>, def_id: &DefId) -> bool { - let def_path_str = cx.tcx.def_path_str(*def_id); - let mut parts = def_path_str.rsplitn(2, "::"); - - let function_name = parts.next().unwrap(); - let contract_path = parts.next().unwrap_or(""); - - if contract_path.is_empty() { - return false; - } - - // Define the patterns to check against - let patterns = [ - format!("{}Client::<'a>::try_{}", contract_path, function_name), - format!("{}::{}", contract_path, function_name), - format!("{}::spec_xdr_{}", contract_path, function_name), - format!("{}Client::<'a>::{}", contract_path, function_name), - ]; - - patterns - .iter() - .all(|pattern| self.checked_functions.contains(pattern.as_str())) - } -} - impl<'tcx> LateLintPass<'tcx> for UnprotectedMappingOperation { fn check_crate_post(&mut self, cx: &LateContext<'tcx>) { for (callee_def_id, mapping_spans) in &self.unauthorized_mapping_calls { - let is_callee_soroban = self.is_soroban_function(cx, callee_def_id); + let is_callee_soroban = is_soroban_function(cx, &self.checked_functions, callee_def_id); let (is_called_by_soroban, is_soroban_caller_authed) = self .function_call_graph .iter() .fold((false, true), |acc, (caller, callees)| { if callees.contains(callee_def_id) { - let is_caller_soroban = self.is_soroban_function(cx, caller); + let is_caller_soroban = + is_soroban_function(cx, &self.checked_functions, caller); // Update if the caller is Soroban and check if it's authorized only if it's a Soroban caller ( acc.0 || is_caller_soroban, @@ -167,31 +139,22 @@ impl<'tcx> UnprotectedMappingOperationVisitor<'_, 'tcx> { // Check that the receiver expression is a field (e.g., accessing a struct's field). if let ExprKind::Field(..) = &receiver.kind; - // Verify that the type of the receiver is an ADT corresponding to 'soroban_sdk::Map'. - if let TyKind::Adt(map_adt_def, args) = receiver_type.kind(); - if self.cx.tcx.def_path_str(map_adt_def.did()) == SOROBAN_MAP; + // Verify that the type of the receiver is a 'soroban_sdk::Map'. + if is_soroban_map(self.cx, receiver_type); // Retrieve the first generic argument, ensure it exists and is of type Ty. + if let TyKind::Adt(_, args) = receiver_type.kind(); if let Some(first_arg) = args.first(); if let GenericArgKind::Type(first_type) = first_arg.unpack(); - // Verify that the type of the receiver is an ADT corresponding to 'soroban_sdk::Address'. - if let Some(address_adt_def) = first_type.ty_adt_def(); - if self.cx.tcx.def_path_str(address_adt_def.did()) == SOROBAN_ADDRESS; + // Verify that the type of the first argument is 'soroban_sdk::Address'. + if is_soroban_address(self.cx, first_type); then { return true; } } false } - - fn is_soroban_address(&self, type_: Ty<'tcx>) -> bool { - type_.to_string().contains(SOROBAN_ADDRESS) - } - - fn get_node_type(&self, hir_id: HirId) -> Ty<'tcx> { - self.cx.typeck_results().node_type(hir_id) - } } impl<'a, 'tcx> Visitor<'tcx> for UnprotectedMappingOperationVisitor<'a, 'tcx> { @@ -202,21 +165,22 @@ impl<'a, 'tcx> Visitor<'tcx> for UnprotectedMappingOperationVisitor<'a, 'tcx> { if let ExprKind::MethodCall(path_segment, receiver, _args, _) = &expr.kind { // Get the method expression type and check if it's a map with address - let receiver_type = self.get_node_type(receiver.hir_id); - - // Check if the method call is require_auth() on an address - if self.is_soroban_address(receiver_type) - && path_segment.ident.name == Symbol::intern("require_auth") - { - self.auth_found = true; - } + let receiver_type = get_node_type_opt(self.cx, &receiver.hir_id); + if let Some(type_) = receiver_type { + // Check if the method call is require_auth() on an address + if is_soroban_address(self.cx, type_) + && path_segment.ident.name == Symbol::intern("require_auth") + { + self.auth_found = true; + } - // Look for usage of soroban map with address - // Anything that looks like `soroban_sdk::Map::` is in our interest - if self.is_soroban_map_with_address(receiver, receiver_type) - && path_segment.ident.name == Symbol::intern("set") - { - self.mapping_spans.push(expr.span); + // Look for usage of soroban map with address + // Anything that looks like `soroban_sdk::Map::` is in our interest + if self.is_soroban_map_with_address(receiver, type_) + && path_segment.ident.name == Symbol::intern("set") + { + self.mapping_spans.push(expr.span); + } } }