Skip to content

Commit

Permalink
Added test score filter to post-process filtering
Browse files: browse the repository at this point in the history
  • Loading branch information
Habush committed Apr 16, 2019
1 parent 17d1f85 commit b2ef752
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 14 deletions.
8 changes: 6 additions & 2 deletions crossval/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class PostProcess:

def __init__(self, filter_type, filter_value, mnemonic):
def __init__(self, filter_type, filter_value, mnemonic, overfitness=1):
self.filter_type = filter_type
self.filter_value = filter_value
self.mnemonic = mnemonic
Expand All @@ -20,6 +20,7 @@ def __init__(self, filter_type, filter_value, mnemonic):

self.session = None
self.models = []
self.overfitness = overfitness

def _retrieve_folds(self, base_dir):
swd = os.path.join(base_dir, f"session_{self.mnemonic}")
Expand All @@ -45,7 +46,10 @@ def filter_models(self, base_dir=DATASET_DIR):
self._folds_to_models()

filtered_models = []
filter_cls = loader.get_overfitness_filter(self.filter_type)
if self.overfitness == 1:
filter_cls = loader.get_overfitness_filter(self.filter_type)
else:
filter_cls = loader.get_score_filters(self.filter_type)

if filter_cls is not None:
filtered_models = filter_cls.filter_negatives(self.models)
Expand Down
20 changes: 10 additions & 10 deletions tests/data/session_folds/fold_2.csv
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
model,complexity,recall_test,precision_test,accuracy_test,f1_test,p_value_test,recall_train,precision_train,accuracy_train,f1_train,p_value_train
and(or(and(!$AMN $APLNR) $AK4) $ACOX2),4,0.625,0.7692307692307693,0.7096774193548387,0.6896551724137931,0.09228515625000003,0.7941176470588235,0.7941176470588235,0.7971014492753623,0.7941176470588235,0.0011201348827110807
and(!$AMN $APLNR),2,0.5,0.7272727272727273,0.6451612903225806,0.5925925925925926,0.22656250000000003,0.7058823529411765,0.8275862068965517,0.782608695652174,0.7619047619047619,0.0008302254277448881
and(or($AK4 !$AMN) $ACOX2),3,0.625,0.7142857142857143,0.6774193548387096,0.6666666666666666,0.17956542968750003,0.7941176470588235,0.7714285714285715,0.782608695652174,0.782608695652174,0.0023457869795667934
and(or($AK4 $APLNR) $ACOX2),3,0.6875,0.7857142857142857,0.7419354838709677,0.7333333333333334,0.05737304687500002,0.7941176470588235,0.7714285714285715,0.782608695652174,0.782608695652174,0.0023457869795667934
or(and($ACOX2 $APLNR) and($AK4 !$AMN)),4,0.8125,0.6842105263157895,0.7096774193548387,0.742857142857143,0.16706848144531244,0.8823529411764706,0.8333333333333334,0.855072463768116,0.8571428571428571,0.00012641846373680533
and(or($ACOX2 $AK4) !$AMN),3,0.6875,0.6875,0.6774193548387096,0.6875,0.21011352539062492,0.7647058823529411,0.896551724137931,0.8405797101449275,0.8253968253968255,4.402038910988803e-05
or(and($ACOX2 $AK4) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.6842105263157895,0.7096774193548387,0.742857142857143,0.16706848144531244,0.9117647058823529,0.7948717948717948,0.8405797101449275,0.8493150684931507,0.000426982234989981
or(and($ACOX2 !$AMN) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.65,0.6774193548387096,0.7222222222222223,0.2631759643554687,0.8823529411764706,0.8108108108108109,0.8405797101449275,0.8450704225352113,0.0002982932471420178
or(and($ACOX2 $APLNR) and($AK4 !$AMN) and($AK4 $APLNR)),6,0.9375,0.6521739130434783,0.7096774193548387,0.7692307692307693,0.21003961563110343,0.8823529411764706,0.8108108108108109,0.8405797101449275,0.8450704225352113,0.0002982932471420178
model,complexity,recall_test,precision_test,accuracy_test,f1_test,p_value_test,recall_train,precision_train,accuracy_train,f1_train,p_value_train
and(or(and(!$AMN $APLNR) $AK4) $ACOX2),4,0.625,0.769230769,0.709677419,0.689655172,0.092285156,0.794117647,0.794117647,0.797101449,0.794117647,0.001120135
and(!$AMN $APLNR),2,0.5,0.727272727,0.64516129,0.592592593,0.2265625,0.705882353,0.827586207,0.782608696,0.761904762,0.000830225
and(or($AK4 !$AMN) $ACOX2),3,0.625,0.714285714,0.677419355,0.666666667,0.17956543,0.794117647,0.771428571,0.782608696,0.782608696,0.002345787
and(or($AK4 $APLNR) $ACOX2),3,0.6875,0.785714286,0.741935484,0.733333333,0.057373047,0.794117647,0.771428571,0.782608696,0.782608696,0.002345787
or(and($ACOX2 $APLNR) and($AK4 !$AMN)),4,0.8125,0.684210526,0.709677419,0.742857143,0.167068481,0.882352941,0.833333333,0.855072464,0.857142857,0.000126418
and(or($ACOX2 $AK4) !$AMN),3,0.6875,0.6875,0.677419355,0.6875,0.210113525,0.764705882,0.896551724,0.84057971,0.825396825,4.40E-05
or(and($ACOX2 $AK4) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.684210526,0.709677419,0.742857143,0.167068481,0.911764706,0.794871795,0.84057971,0.849315068,0.000426982
or(and($ACOX2 !$AMN) and($ACOX2 $APLNR) and($AK4 !$AMN)),6,0.8125,0.65,0.677419355,0.722222222,0.263175964,0.882352941,0.810810811,0.84057971,0.845070423,0.000298293
or(and($ACOX2 $APLNR) and($AK4 !$AMN) and($AK4 $APLNR)),6,0.9375,0.652173913,0.92,0.769230769,0.210039616,0.882352941,0.810810811,0.84057971,0.845070423,0.000298293
2 changes: 1 addition & 1 deletion tests/test_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_fold_to_models(self, client):
@patch("pymongo.MongoClient")
def test_filter_models(self, client):
mock_db = mongomock.MongoClient().db
post_process = PostProcess("accuracy", 0.2, self.mnemonic)
post_process = PostProcess("accuracy", 0.8, self.mnemonic, 0)
post_process.db = mock_db

result_models = post_process.filter_models(TEST_DATA_DIR)
Expand Down
3 changes: 2 additions & 1 deletion webserver/apimain.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def send_result(mnemonic):
def filter_models(mnemonic):
filter_type = request.args.get("filter")
filter_value = request.args.get("value")
overfitness = int(request.args.get("overfitness"))

if filter_type is None or filter_value is None:
return jsonify({"message":"Filter type and value required"}), 400
Expand All @@ -111,7 +112,7 @@ def filter_models(mnemonic):
if client.exists(mnemonic) > 0:
client.delete(mnemonic)

post_process_filter = PostProcess(filter_type, value, mnemonic)
post_process_filter = PostProcess(filter_type, value, mnemonic, overfitness)
try:
models = post_process_filter.filter_models()
except AssertionError:
Expand Down

0 comments on commit b2ef752

Please sign in to comment.