Skip to content

Commit

Permalink
Fix detector
Browse files Browse the repository at this point in the history
  • Loading branch information
jgcrosta committed Nov 5, 2024
1 parent 1200958 commit 9a472db
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 102 deletions.
131 changes: 36 additions & 95 deletions detectors/unsafe-unwrap/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ use rustc_hir::{
def::Res,
def_id::LocalDefId,
intravisit::{walk_expr, FnKind, Visitor},
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, LetStmt, PathSegment, QPath, UnOp,
BinOpKind, Body, Expr, ExprKind, FnDecl, HirId, PathSegment, QPath, UnOp,
};
use rustc_lint::{LateContext, LateLintPass};
use rustc_span::{sym, Span, Symbol};
use std::{collections::HashSet, hash::Hash};
use utils::{fn_returns, get_node_type_opt, match_type_to_str};
use utils::{fn_returns, get_node_type_opt, match_type_to_str, ConstantAnalyzer};

const LINT_MESSAGE: &str = "Unsafe usage of `unwrap`";
const PANIC_INDUCING_FUNCTIONS: [&str; 2] = ["panic", "bail"];
Expand Down Expand Up @@ -154,13 +154,14 @@ impl ConditionalChecker {
/// Main unsafe-unwrap visitor
struct UnsafeUnwrapVisitor<'a, 'tcx> {
cx: &'a LateContext<'tcx>,
constant_analyzer: ConstantAnalyzer<'a, 'tcx>,
conditional_checker: HashSet<ConditionalChecker>,
checked_exprs: HashSet<HirId>,
linted_spans: HashSet<Span>,
returns_result_or_option: bool,
}

impl UnsafeUnwrapVisitor<'_, '_> {
impl<'a, 'tcx> UnsafeUnwrapVisitor<'a, 'tcx> {
fn get_help_message(&self, unwrap_type: UnwrapType) -> &'static str {
match (self.returns_result_or_option, unwrap_type) {
(true, UnwrapType::Option) => "Consider using `ok_or` to convert Option to Result",
Expand All @@ -174,7 +175,7 @@ impl UnsafeUnwrapVisitor<'_, '_> {
}
}

fn determine_unwrap_type(&self, receiver: &Expr<'_>) -> UnwrapType {
fn determine_unwrap_type(&self, receiver: &Expr<'tcx>) -> UnwrapType {
let type_opt = get_node_type_opt(self.cx, &receiver.hir_id);
if let Some(type_) = type_opt {
if match_type_to_str(self.cx, type_, "Result") {
Expand All @@ -184,7 +185,7 @@ impl UnsafeUnwrapVisitor<'_, '_> {
UnwrapType::Option
}

fn is_panic_inducing_call(&self, func: &Expr<'_>) -> bool {
fn is_panic_inducing_call(&self, func: &Expr<'tcx>) -> bool {
if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind {
return PANIC_INDUCING_FUNCTIONS.iter().any(|&func| {
path.segments
Expand All @@ -195,7 +196,7 @@ impl UnsafeUnwrapVisitor<'_, '_> {
false
}

fn get_unwrap_info(&self, receiver: &Expr<'_>) -> Option<HirId> {
fn get_unwrap_info(&self, receiver: &Expr<'tcx>) -> Option<HirId> {
if_chain! {
if let ExprKind::Path(QPath::Resolved(_, path)) = &receiver.kind;
if let Res::Local(hir_id) = path.res;
Expand Down Expand Up @@ -235,110 +236,43 @@ impl UnsafeUnwrapVisitor<'_, '_> {
});
}

fn is_literal_or_composed_of_literals(&self, expr: &Expr<'_>) -> bool {
let mut stack = vec![expr];

while let Some(current_expr) = stack.pop() {
match current_expr.kind {
ExprKind::Lit(_) => continue, // A literal is fine, continue processing.
ExprKind::Tup(elements) | ExprKind::Array(elements) => {
stack.extend(elements);
}
ExprKind::Struct(_, fields, _) => {
for field in fields {
stack.push(field.expr);
}
}
ExprKind::Repeat(element, _) => {
stack.push(element);
}
_ => return false, // If any element is not a literal or a compound of literals, return false.
fn is_method_call_unsafe(&mut self, path_segment: &PathSegment, receiver: &Expr<'tcx>) -> bool {
if path_segment.ident.name == sym::unwrap {
if self.constant_analyzer.is_constant(receiver) {
return false;
}
}

true // If the stack is emptied without finding a non-literal, all elements are literals.
}

fn is_method_call_unsafe(
&self,
path_segment: &PathSegment,
receiver: &Expr,
args: &[Expr],
) -> bool {
if path_segment.ident.name == sym::unwrap {
return self
.get_unwrap_info(receiver)
.map_or(true, |id| !self.checked_exprs.contains(&id));
}

args.iter().any(|arg| self.contains_unsafe_method_call(arg))
|| self.contains_unsafe_method_call(receiver)
}

fn contains_unsafe_method_call(&self, expr: &Expr) -> bool {
match &expr.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
self.is_method_call_unsafe(path_segment, receiver, args)
}
_ => false,
}
false
}

fn check_expr_for_unsafe_unwrap(&mut self, expr: &Expr) {
match &expr.kind {
ExprKind::MethodCall(path_segment, receiver, args, _) => {
if self.is_method_call_unsafe(path_segment, receiver, args)
&& !self.linted_spans.contains(&expr.span)
{
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);
span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
expr.span,
LINT_MESSAGE,
None,
help_message,
);
self.linted_spans.insert(expr.span);
}
}
ExprKind::Call(func, args) => {
if let ExprKind::Path(QPath::Resolved(_, path)) = &func.kind {
let is_some_or_ok = path
.segments
.iter()
.any(|segment| matches!(segment.ident.name, sym::Some | sym::Ok));
let all_literals = args
.iter()
.all(|arg| self.is_literal_or_composed_of_literals(arg));
if is_some_or_ok && all_literals {
self.checked_exprs.insert(expr.hir_id);
return;
}
}
// Check arguments for unsafe expect
for arg in args.iter() {
self.check_expr_for_unsafe_unwrap(arg);
}
}
ExprKind::Tup(exprs) | ExprKind::Array(exprs) => {
for expr in exprs.iter() {
self.check_expr_for_unsafe_unwrap(expr);
}
fn check_expr_for_unsafe_unwrap(&mut self, expr: &Expr<'tcx>) {
if let ExprKind::MethodCall(path_segment, receiver, _, _) = &expr.kind {
if self.is_method_call_unsafe(path_segment, receiver)
&& !self.linted_spans.contains(&expr.span)
{
let unwrap_type = self.determine_unwrap_type(receiver);
let help_message = self.get_help_message(unwrap_type);

span_lint_and_help(
self.cx,
UNSAFE_UNWRAP,
expr.span,
LINT_MESSAGE,
None,
help_message,
);
self.linted_spans.insert(expr.span);
}
_ => {}
}
}
}

impl<'a, 'tcx> Visitor<'tcx> for UnsafeUnwrapVisitor<'a, 'tcx> {
fn visit_local(&mut self, local: &'tcx LetStmt<'tcx>) -> Self::Result {
if let Some(init) = local.init {
self.check_expr_for_unsafe_unwrap(init);
}
}

fn visit_expr(&mut self, expr: &'tcx Expr<'_>) {
// If we are inside an `if` or `if let` expression, we analyze its body
if !self.conditional_checker.is_empty() {
Expand Down Expand Up @@ -388,8 +322,15 @@ impl<'tcx> LateLintPass<'tcx> for UnsafeUnwrap {
return;
}

let mut constant_analyzer = ConstantAnalyzer {
cx,
constants: HashSet::new(),
};
constant_analyzer.visit_body(body);

let mut visitor = UnsafeUnwrapVisitor {
cx,
constant_analyzer,
checked_exprs: HashSet::new(),
conditional_checker: HashSet::new(),
linted_spans: HashSet::new(),
Expand Down
27 changes: 20 additions & 7 deletions utils/src/constant_analyzer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,23 @@ impl<'a, 'tcx> ConstantAnalyzer<'a, 'tcx> {
fn is_qpath_constant(&self, path: &QPath) -> bool {
if let QPath::Resolved(_, path) = path {
match path.res {
Res::Def(def_kind, _) => matches!(
def_kind,
DefKind::AnonConst
| DefKind::AssocConst
| DefKind::Const
| DefKind::InlineConst
),
Res::Def(def_kind, def_id) => {
matches!(
def_kind,
DefKind::AnonConst
| DefKind::AssocConst
| DefKind::Const
| DefKind::InlineConst
) || {
// Allow both Some and Ok variant constructors
if let DefKind::Ctor(..) = def_kind {
let def_path = self.cx.tcx.def_path_str(def_id);
def_path.ends_with("::Some") || def_path.ends_with("::Ok")
} else {
false
}
}
}
Res::Local(hir_id) => self.constants.contains(&hir_id),
_ => false,
}
Expand Down Expand Up @@ -61,6 +71,9 @@ impl<'a, 'tcx> ConstantAnalyzer<'a, 'tcx> {
ExprKind::Struct(_, expr_fields, _) => expr_fields
.iter()
.all(|field_expr| self.is_expr_constant(field_expr.expr)),
ExprKind::Call(func, args) => {
self.is_expr_constant(func) && args.iter().all(|arg| self.is_expr_constant(arg))
}
_ => false,
}
}
Expand Down

0 comments on commit 9a472db

Please sign in to comment.