-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathclassifier_ert.cpp
70 lines (65 loc) · 2.77 KB
/
classifier_ert.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include "classifier_ert.h"
/**
 * Creates an ERT classifier preloaded with its default training parameters.
 *
 * The defaults are handed to setParameters() so the constructor and the
 * setter populate exactly the same set of par_* members.
 */
ClassifierERT::ClassifierERT()
{
    setParameters(
        4,                    // maxDepth
        2,                    // minSampleCount
        0.f,                  // regressionAccuracy
        false,                // useSurrogates
        16,                   // maxCategories
        false,                // calcVarImportance
        1,                    // nactiveVars
        5,                    // maxNumTreesInForest
        0.f,                  // forestAccuracy
        rf_termcrit_idx[0]);  // termCritType: first entry of rf_termcrit_idx
}
/**
 * Trains the classifier from scratch on the given points and labels.
 *
 * The previously trained model is cleared, the inputs are handed to
 * loadData() (defined elsewhere; presumably it fills pData and lData), and
 * the forest is trained with the currently configured par_* parameters.
 *
 * isTrainedFlag is set from the result of the training call, so a failed
 * training run no longer leaves the object reporting itself as trained.
 *
 * @param data   Training points, passed on to loadData().
 * @param labels Labels for the training points, passed on to loadData().
 */
void ClassifierERT::trainData(const std::vector<cv::Point> &data, const std::vector<int> &labels)
{
    // The old model is gone as soon as clear() runs, so drop the flag first:
    // if anything below throws, the object must not still claim to be trained.
    isTrainedFlag = false;
    cls.clear();
    loadData(data, labels);

    // One type entry per column of pData plus one trailing entry: every
    // input column is ordered, the trailing entry is categorical.
    cv::Mat var_types(1, pData.cols + 1, CV_8UC1, cv::Scalar(CV_VAR_ORDERED));
    var_types.at<uchar>(pData.cols) = CV_VAR_CATEGORICAL;

    cv::RandomTreeParams params;
    params.max_depth           = par_MaxDepth;
    params.min_sample_count    = par_MinSampleCount;
    params.regression_accuracy = par_RegressionAccuracy;
    params.use_surrogates      = par_UseSurrogates;
    params.max_categories      = par_MaxCategories;
    params.calc_var_importance = par_CalcVarImportance;
    params.nactive_vars        = par_NactiveVars;
    params.term_crit           = cv::TermCriteria(par_TermCritType, par_MaxNumTreesInForest, par_ForestAccuracy);

    // Record the actual outcome instead of assuming success.
    isTrainedFlag = cls.train(pData, CV_ROW_SAMPLE, lData, cv::Mat(), cv::Mat(), var_types, cv::Mat(), params);
}
/**
 * Predicts the label for a single point.
 *
 * The coordinates are written into elements 0 and 1 of the member testSample,
 * which is then passed to the trained model.
 *
 * @param x First coordinate of the query point.
 * @param y Second coordinate of the query point.
 * @return The model's prediction rounded to an integer with cvRound().
 */
int ClassifierERT::classify(int x, int y)
{
    testSample.at<float>(0) = static_cast<float>(x);
    testSample.at<float>(1) = static_cast<float>(y);

    const double prediction = cls.predict(testSample);
    return cvRound(prediction);
}
/**
 * Renders the current parameter set as a single human-readable string.
 *
 * Every parameter is emitted as "key=value"; the trailing termCrit tuple is
 * (par_TermCritType, par_MaxNumTreesInForest, par_ForestAccuracy).
 *
 * @return A string of the form "ERT{max_depth=..., ..., termCrit=(...)}".
 */
QString ClassifierERT::toQString() const
{
    // The nactive_vars field previously used "==", unlike every other field;
    // it now uses a single "=" so all fields share the same key=value form.
    return QString("ERT{max_depth=%1, min_sample_count=%2, regression_accuracy=%3, "
                   "use_surrogates=%4, max_categories=%5, calc_var_importance=%6, "
                   "nactive_vars=%7, termCrit=(%8,%9,%10)}")
        .arg(par_MaxDepth)
        .arg(par_MinSampleCount)
        .arg(par_RegressionAccuracy)
        .arg(par_UseSurrogates)
        .arg(par_MaxCategories)
        .arg(par_CalcVarImportance)
        .arg(par_NactiveVars)
        .arg(par_TermCritType)
        .arg(par_MaxNumTreesInForest)
        .arg(par_ForestAccuracy);
}
void ClassifierERT::setParameters(int maxDepth, int minSampleCount, float regressionAccuracy, bool useSurrogates, int maxCategories, bool calcVarImportance, int nactiveVars, int maxNumTreesInForest, float forestAccuracy, int termCritType)
{
par_MaxDepth = maxDepth;
par_MinSampleCount = minSampleCount;
par_RegressionAccuracy = regressionAccuracy;
par_UseSurrogates = useSurrogates;
par_MaxCategories = maxCategories;
par_CalcVarImportance = calcVarImportance;
par_NactiveVars = nactiveVars;
par_MaxNumTreesInForest = maxNumTreesInForest;
par_ForestAccuracy = forestAccuracy;
par_TermCritType = termCritType;
}