-
Notifications
You must be signed in to change notification settings - Fork 1
/
f_trees_fast.cpp
138 lines (97 loc) · 4.91 KB
/
f_trees_fast.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include <cmath>
#include <Rcpp.h>
using namespace Rcpp;
#if !defined(WIN32) && !defined(__WIN32) && !defined(__WIN32__)
#include <unistd.h>
#include <Rinterface.h>
#endif
// [[Rcpp::plugins(cpp11)]]
// Position lookup used to map a tree's split-feature id to the matching entry
// of the observed feature values.
// Returns the 0-based index of the first element of `x` equal to `feat_name`,
// or -1 when no element matches (note: 0-based, unlike R's match()).
// [[Rcpp::export]]
int findElement(IntegerVector x,int feat_name){
  const int len = x.size();
  int idx = 0;
  // Advance until a match is found or the vector is exhausted.
  while (idx < len && x[idx] != feat_name) {
    ++idx;
  }
  return (idx < len) ? idx : -1;
}
// [[Rcpp::export]]
double f_trees_cpp(const IntegerVector &Features,const NumericVector &Feat_vals,const IntegerVector &feat_names,
const NumericVector &Prediction,const IntegerVector &Yes,const IntegerVector &No,
const NumericVector &Split,const NumericVector &Cover,int node){
int feat = Features[node];
if (feat == -1) {
double pred = Prediction[node];
return pred;
} else {
int pos = findElement(feat_names,feat);
if (pos >= 0) {
double feat_val = Feat_vals[pos];
double split_val = Split[node];
int NextNode = Yes[node];
if (feat_val > split_val) {
NextNode = No[node];
}
return f_trees_cpp(Features,Feat_vals,feat_names,Prediction,Yes,No,Split,Cover,NextNode);
} else {
int YesNode = Yes[node];
int NoNode = No[node];
return f_trees_cpp(Features,Feat_vals,feat_names,Prediction,Yes,No,Split,Cover,YesNode)*Cover[YesNode] + f_trees_cpp(Features,Feat_vals,feat_names,Prediction,Yes,No,Split,Cover,NoNode)*Cover[NoNode];
}
}
}
// Evaluates f_trees_cpp once per starting row in `nodes` (presumably the root
// row of each tree -- confirm against the caller) with the same observed
// features, and returns the results in the same order as `nodes`.
// [[Rcpp::export]]
NumericVector f_trees_cpp_vec(const IntegerVector &Features,const NumericVector &Feat_vals,const IntegerVector &feat_names,
                              const NumericVector &Prediction,const IntegerVector &Yes,const IntegerVector &No,
                              const NumericVector &Split,const NumericVector &Cover,const IntegerVector &nodes){
  const int n_roots = nodes.size();
  NumericVector out(n_roots);
  for (int k = 0; k < n_roots; ++k) {
    out[k] = f_trees_cpp(Features, Feat_vals, feat_names, Prediction,
                         Yes, No, Split, Cover, nodes[k]);
  }
  return out;
}
// [[Rcpp::export]]
double subSAGE_per_S_linreg(const IntegerVector &Features,const DataFrame &Feat_vals_S, const DataFrame &Feat_vals_SuK, const IntegerVector &feat_names_S,
const IntegerVector &feat_names_SuK,const NumericVector &Prediction,const IntegerVector &Yes,const IntegerVector &No,
const NumericVector &Split,const NumericVector &Cover,const IntegerVector &feature_trees,const IntegerVector &roots_at_bar_trees,
const NumericVector &response,int &n_inds){
NumericVector ret (n_inds);
for (int j=0; j < n_inds;j++){
NumericVector feat_vals_S_ind = Feat_vals_S[j];
NumericVector feat_vals_SuK_ind = Feat_vals_SuK[j];
double EfSuK = sum(f_trees_cpp_vec(Features,feat_vals_SuK_ind,feat_names_SuK ,Prediction,
Yes,No,Split,Cover,feature_trees));
double EfS = sum(f_trees_cpp_vec(Features,feat_vals_S_ind,feat_names_S ,Prediction,
Yes,No,Split,Cover,feature_trees));
double Ef_bartrees_S = sum(f_trees_cpp_vec(Features,feat_vals_S_ind,feat_names_S ,Prediction,
Yes,No,Split,Cover,roots_at_bar_trees));
double response_j = response[j];
ret[j] = 2*response_j*(EfSuK-EfS) + EfS*EfS - EfSuK*EfSuK + 2*Ef_bartrees_S*(EfSuK-EfS);
}
double MeanRet = mean(ret);
return MeanRet;
}
// [[Rcpp::export]]
double subSAGE_per_S_logreg(const IntegerVector &Features,const DataFrame &Feat_vals_S, const DataFrame &Feat_vals_SuK, const IntegerVector &feat_names_S,
const IntegerVector &feat_names_SuK,const NumericVector &Prediction,const IntegerVector &Yes,const IntegerVector &No,
const NumericVector &Split,const NumericVector &Cover,const IntegerVector &feature_trees,const IntegerVector &roots_at_bar_trees,
const NumericVector &response,int &n_inds){
NumericVector ret (n_inds);
for (int j=0; j < n_inds;j++){
NumericVector feat_vals_S_ind = Feat_vals_S[j];
NumericVector feat_vals_SuK_ind = Feat_vals_SuK[j];
double EfSuK = sum(f_trees_cpp_vec(Features,feat_vals_SuK_ind,feat_names_SuK ,Prediction,
Yes,No,Split,Cover,feature_trees));
double EfS = sum(f_trees_cpp_vec(Features,feat_vals_S_ind,feat_names_S ,Prediction,
Yes,No,Split,Cover,feature_trees));
double Ef_bartrees_S = sum(f_trees_cpp_vec(Features,feat_vals_S_ind,feat_names_S ,Prediction,
Yes,No,Split,Cover,roots_at_bar_trees));
double response_j = response[j];
ret[j] = (1-response_j)*(EfS-EfSuK) + log((1+exp(-EfS-Ef_bartrees_S))/(1+exp(-EfSuK-Ef_bartrees_S)));
}
double MeanRet = mean(ret);
return MeanRet;
}