-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCheckMalloc.cpp
122 lines (107 loc) · 3.42 KB
/
CheckMalloc.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#include <unordered_set>
#include "clang/Frontend/FrontendPluginRegistry.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/FileManager.h"
#include "clang/Frontend/CompilerInstance.h"
#include "llvm/Support/raw_ostream.h"
#include "clang/AST/Expr.h"
#include "clang/Lex/Lexer.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/SourceLocation.h"
#include "clang/Lex/Lexer.h"
#include <string>
using namespace clang;
/// Heuristic AST visitor that warns about pointers obtained from malloc()
/// being used without a preceding NULL check.
///
/// The analysis is sequential and purely textual. Two flags are updated in
/// RecursiveASTVisitor traversal order:
///   - isMallocCalled: the most recently visited call expression called a
///     function named "malloc";
///   - isPtrChecked:   the most recently visited `if` condition contained the
///     text "!= NULL".
/// Any statement whose source text contains '*' while a malloc is outstanding
/// and unchecked is reported. No data flow is tracked, so the visitor does not
/// know which pointer a call, check or use refers to.
class CheckMallocVisitor : public RecursiveASTVisitor<CheckMallocVisitor>
{
private:
    ASTContext *context;
    CompilerInstance &instance;
    DiagnosticsEngine &d;
    unsigned int warningID;

    /// Returns true if the `if` statement begins in a system header or in a
    /// file whose name ends in ".h" or ".hpp". Returns false when the location
    /// is invalid or does not belong to a file.
    /// NOTE(review): not called anywhere in the visible source — presumably
    /// meant to suppress reports coming from headers; confirm before use.
    bool isInHeader(IfStmt *decl)
    {
        auto floc = context->getFullLoc(decl->getBeginLoc());
        if (floc.isInvalid())
            return false;
        // A macro location has no FileEntry of its own; look at where the
        // macro was expanded instead.
        floc = floc.getExpansionLoc();
        if (floc.isInSystemHeader())
            return true;
        // Locations in built-in or scratch buffers have no FileEntry.
        auto *file = floc.getFileEntry();
        if (!file)
            return false;
        auto entry = file->getName();
        return entry.endswith(".h") || entry.endswith(".hpp");
    }

    /// Returns true when `name` is exactly "malloc". `S` is not inspected.
    bool checkMalloc(CallExpr *S, const std::string &name)
    {
        return name.compare("malloc") == 0;
    }

    /// Returns true when the source text of the `if` condition contains the
    /// literal substring "!= NULL". Other spellings such as `if (p)`,
    /// `!= nullptr` or `!=NULL` are not recognized as checks.
    bool checkIfStmt(IfStmt *s, ASTContext *context)
    {
        Expr *e = s->getCond();
        // Defensive: never lex a missing condition.
        if (!e)
            return false;
        CharSourceRange cha = Lexer::getAsCharRange(e->getSourceRange(), context->getSourceManager(), context->getLangOpts());
        llvm::StringRef text = Lexer::getSourceText(cha, context->getSourceManager(), context->getLangOpts());
        return text.contains("!= NULL");
    }

public:
    explicit CheckMallocVisitor(ASTContext *context, CompilerInstance &instance) : context(context), instance(instance), d(instance.getDiagnostics())
    {
        warningID = d.getCustomDiagID(DiagnosticsEngine::Warning,
                                      "Unchecked pointer: '%0'");
    }

    /// True when the most recently visited call expression called "malloc".
    bool isMallocCalled = false;
    /// True when the most recently visited `if` condition contained "!= NULL".
    bool isPtrChecked = false;

    /// Records whether this call targets a function named "malloc". Every
    /// visited call overwrites the flag, so a later unrelated call clears it.
    bool VisitCallExpr(CallExpr *S)
    {
        // getCalleeDecl() may be null (e.g. unresolved/dependent callees) and
        // getAsFunction() is null when the callee is not a function, e.g. a
        // call through a function-pointer variable. Both must be checked.
        auto *calleeDecl = S->getCalleeDecl();
        auto *callee = calleeDecl ? calleeDecl->getAsFunction() : nullptr;
        isMallocCalled = callee && checkMalloc(S, callee->getNameAsString());
        return true;
    }

    /// Records whether this `if` condition looks like a "!= NULL" check.
    /// Every visited `if` overwrites the flag.
    bool VisitIfStmt(IfStmt *S)
    {
        isPtrChecked = checkIfStmt(S, context);
        return true;
    }

    /// Warns at any statement whose source text contains '*' while a malloc
    /// call is outstanding and no "!= NULL" check has been seen, then clears
    /// both flags.
    bool VisitStmt(Stmt *S)
    {
        // Test the cheap flags first so the statement's source is only
        // re-lexed when a report is actually possible.
        if (!isMallocCalled || isPtrChecked)
            return true;
        CharSourceRange cha = Lexer::getAsCharRange(S->getSourceRange(), context->getSourceManager(), context->getLangOpts());
        llvm::StringRef text = Lexer::getSourceText(cha, context->getSourceManager(), context->getLangOpts());
        if (text.contains("*"))
        {
            auto loc = context->getFullLoc(S->getBeginLoc());
            d.Report(loc, warningID) << "Check pointer for NULL before using";
            isMallocCalled = false;
            isPtrChecked = false;
        }
        return true;
    }
};
class CheckMallocConsumer : public ASTConsumer
{
CompilerInstance &instance;
CheckMallocVisitor visitor;
public:
CheckMallocConsumer(CompilerInstance &instance)
: instance(instance), visitor(&instance.getASTContext(), instance) {}
virtual void HandleTranslationUnit(ASTContext &context) override
{
visitor.TraverseDecl(context.getTranslationUnitDecl());
}
};
class CheckMallocAction : public PluginASTAction
{
protected:
virtual std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &instance, llvm::StringRef) override
{
return std::make_unique<CheckMallocConsumer>(instance);
}
virtual bool ParseArgs(const CompilerInstance &Compiler, const std::vector<std::string> &) override
{
return true;
}
virtual PluginASTAction::ActionType getActionType() override
{
return PluginASTAction::AddAfterMainAction;
}
};
// Registers CheckMallocAction with Clang's frontend plugin registry under the
// name "CheckMalloc", with a one-line description of what it reports.
static FrontendPluginRegistry::Add<CheckMallocAction> checkMallocRegistration(
    "CheckMalloc",
    "Warn against unchecked pointers");