Skip to content

Commit

Permalink
伪标签函数增加参数p.
Browse files Browse the repository at this point in the history
  • Loading branch information
enjoysport2022 committed Mar 14, 2022
1 parent 9fce06e commit 540a992
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions autox/autox_competition/process_data/get_pseudo_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm

def get_pseudo_label(train, test, id_, target, used_cols):
def get_pseudo_label(train, test, id_, target, used_cols, p = 0.99):
assert 0.5 < p < 1
sub = test[[id_]]
sub[target] = 0

Expand All @@ -15,7 +16,7 @@ def get_pseudo_label(train, test, id_, target, used_cols):
pred = clf.predict_proba(test[used_cols])[:,1]
sub[target] = sub[target] + pred / skf.n_splits

pseudo_test = sub[(sub[target] <= 0.01) | (sub[target] >= 0.99)].copy()
pseudo_test = sub[(sub[target] <= (1-p)) | (sub[target] >= p)].copy()
pseudo_test.loc[pseudo_test[target] >= 0.5, target] = 1
pseudo_test.loc[pseudo_test[target] < 0.5, target] = 0
pseudo_test.index = range(len(pseudo_test))
Expand Down

0 comments on commit 540a992

Please sign in to comment.