From b2ef752254012810918f85423ae0f2c68afc10e3 Mon Sep 17 00:00:00 2001 From: Abdulrahman Semrie Date: Tue, 16 Apr 2019 21:11:08 +0300 Subject: [PATCH] Added test score filter to post-process filtering --- crossval/post_process.py | 8 ++++++-- tests/data/session_folds/fold_2.csv | 20 ++++++++++---------- tests/test_post_process.py | 2 +- webserver/apimain.py | 3 ++- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/crossval/post_process.py b/crossval/post_process.py index 87e8348..3ed4ad9 100644 --- a/crossval/post_process.py +++ b/crossval/post_process.py @@ -11,7 +11,7 @@ class PostProcess: - def __init__(self, filter_type, filter_value, mnemonic): + def __init__(self, filter_type, filter_value, mnemonic, overfitness=1): self.filter_type = filter_type self.filter_value = filter_value self.mnemonic = mnemonic @@ -20,6 +20,7 @@ def __init__(self, filter_type, filter_value, mnemonic): self.session = None self.models = [] + self.overfitness = overfitness def _retrieve_folds(self, base_dir): swd = os.path.join(base_dir, f"session_{self.mnemonic}") @@ -45,7 +46,10 @@ def filter_models(self, base_dir=DATASET_DIR): self._folds_to_models() filtered_models = [] - filter_cls = loader.get_overfitness_filter(self.filter_type) + if self.overfitness == 1: + filter_cls = loader.get_overfitness_filter(self.filter_type) + else: + filter_cls = loader.get_score_filters(self.filter_type) if filter_cls is not None: filtered_models = filter_cls.filter_negatives(self.models) diff --git a/tests/data/session_folds/fold_2.csv b/tests/data/session_folds/fold_2.csv index 12e14d1..7ea934a 100644 --- a/tests/data/session_folds/fold_2.csv +++ b/tests/data/session_folds/fold_2.csv @@ -1,10 +1,10 @@ -model,complexity,recall_test,precision_test,accuracy_test,f1_test,p_value_test,recall_train,precision_train,accuracy_train,f1_train,p_value_train -and(or(and(!$AMN $APLNR) $AK4) $ACOX2),4,0.625,0.7692307692307693,0.7096774193548387,0.6896551724137931,0.09228515625000003,0.7941176470588235,0.7941176470588235,0.7971014492753623,0.7941176470588235,0.0011201348827110807 -and(!$AMN $APLNR),2,0.5,0.7272727272727273,0.6451612903225806,0.5925925925925926,0.22656250000000003,0.7058823529411765,0.8275862068965517,0.782608695652174,0.7619047619047619,0.0008302254277448881 -and(or($AK4 !$AMN) $ACOX2),3,0.625,0.7142857142857143,0.6774193548387096,0.6666666666666666,0.17956542968750003,0.7941176470588235,0.7714285714285715,0.782608695652174,0.782608695652174,0.0023457869795667934 -and(or($AK4 $APLNR) $ACOX2),3,0.6875,0.7857142857142857,0.7419354838709677,0.7333333333333334,0.05737304687500002,0.7941176470588235,0.7714285714285715,0.782608695652174,0.782608695652174,0.0023457869795667934 -or(and($ACOX2 $APLNR) and($AK4 !$AMN)),4,0.8125,0.6842105263157895,0.7096774193548387,0.742857142857143,0.16706848144531244,0.8823529411764706,0.8333333333333334,0.855072463768116,0.8571428571428571,0.00012641846373680533 -and(or($ACOX2 $AK4) !$AMN),3,0.6875,0.6875,0.6774193548387096,0.6875,0.21011352539062492,0.7647058823529411,0.896551724137931,0.8405797101449275,0.8253968253968255,4.402038910988803e-05 -or(and($ACOX2 $AK4) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.6842105263157895,0.7096774193548387,0.742857142857143,0.16706848144531244,0.9117647058823529,0.7948717948717948,0.8405797101449275,0.8493150684931507,0.000426982234989981 -or(and($ACOX2 !$AMN) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.65,0.6774193548387096,0.7222222222222223,0.2631759643554687,0.8823529411764706,0.8108108108108109,0.8405797101449275,0.8450704225352113,0.0002982932471420178 -or(and($ACOX2 $APLNR) and($AK4 !$AMN) and($AK4 $APLNR)),6,0.9375,0.6521739130434783,0.7096774193548387,0.7692307692307693,0.21003961563110343,0.8823529411764706,0.8108108108108109,0.8405797101449275,0.8450704225352113,0.0002982932471420178 +model,complexity,recall_test,precision_test,accuracy_test,f1_test,p_value_test,recall_train,precision_train,accuracy_train,f1_train,p_value_train +and(or(and(!$AMN $APLNR) $AK4) $ACOX2),4,0.625,0.769230769,0.709677419,0.689655172,0.092285156,0.794117647,0.794117647,0.797101449,0.794117647,0.001120135 +and(!$AMN $APLNR),2,0.5,0.727272727,0.64516129,0.592592593,0.2265625,0.705882353,0.827586207,0.782608696,0.761904762,0.000830225 +and(or($AK4 !$AMN) $ACOX2),3,0.625,0.714285714,0.677419355,0.666666667,0.17956543,0.794117647,0.771428571,0.782608696,0.782608696,0.002345787 +and(or($AK4 $APLNR) $ACOX2),3,0.6875,0.785714286,0.741935484,0.733333333,0.057373047,0.794117647,0.771428571,0.782608696,0.782608696,0.002345787 +or(and($ACOX2 $APLNR) and($AK4 !$AMN)),4,0.8125,0.684210526,0.709677419,0.742857143,0.167068481,0.882352941,0.833333333,0.855072464,0.857142857,0.000126418 +and(or($ACOX2 $AK4) !$AMN),3,0.6875,0.6875,0.677419355,0.6875,0.210113525,0.764705882,0.896551724,0.84057971,0.825396825,4.40E-05 +or(and($ACOX2 $AK4) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.684210526,0.709677419,0.742857143,0.167068481,0.911764706,0.794871795,0.84057971,0.849315068,0.000426982 +or(and($ACOX2 !$AMN) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.65,0.677419355,0.722222222,0.263175964,0.882352941,0.810810811,0.84057971,0.845070423,0.000298293 +or(and($ACOX2 $APLNR) and($AK4 !$AMN) and($AK4 $APLNR)),6,0.9375,0.652173913,0.92,0.769230769,0.210039616,0.882352941,0.810810811,0.84057971,0.845070423,0.000298293 \ No newline at end of file diff --git a/tests/test_post_process.py b/tests/test_post_process.py index ccd0b2e..859fd31 100644 --- a/tests/test_post_process.py +++ b/tests/test_post_process.py @@ -41,7 +41,7 @@ def test_fold_to_models(self, client): @patch("pymongo.MongoClient") def test_filter_models(self, client): mock_db = mongomock.MongoClient().db - post_process = PostProcess("accuracy", 0.2, self.mnemonic) + post_process = PostProcess("accuracy", 0.8, self.mnemonic, 0) post_process.db = mock_db result_models = post_process.filter_models(TEST_DATA_DIR) diff --git a/webserver/apimain.py b/webserver/apimain.py index dd39ae6..3b5d627 100644 --- a/webserver/apimain.py +++ b/webserver/apimain.py @@ -92,6 +92,7 @@ def send_result(mnemonic): def filter_models(mnemonic): filter_type = request.args.get("filter") filter_value = request.args.get("value") + overfitness = int(request.args.get("overfitness")) if filter_type is None or filter_value is None: return jsonify({"message":"Filter type and value required"}), 400 @@ -111,7 +112,7 @@ def filter_models(mnemonic): if client.exists(mnemonic) > 0: client.delete(mnemonic) - post_process_filter = PostProcess(filter_type, value, mnemonic) + post_process_filter = PostProcess(filter_type, value, mnemonic, overfitness) try: models = post_process_filter.filter_models() except AssertionError: