forked from ncbi/histonedb
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path test.py
executable file
·116 lines (96 loc) · 5.9 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import json
import os
import unittest

from Bio import SeqIO

from prediction_app.prediction.path_variables import *
class TestHistonedbTypeClassifier(unittest.TestCase):
    """Integration tests for ``HistonedbTypeClassifier``.

    ``setUp`` builds a fresh classifier from the ``'tree'`` entry of
    ``VARIANTS_JSON`` with the ``'Archaeal'`` and ``'Viral'`` subtrees removed.
    The tests build HMMs from ``DATA_DIRECTORY/draft_seeds``, run predictions
    on the FASTA file ``PREDICTION_DIRECTORY/test`` and write their outputs
    under ``PREDICTION_RESULTS_DIRECTORY``.
    """

    def setUp(self):
        from prediction.histonedb_classifier import HistonedbTypeClassifier

        with open(VARIANTS_JSON, encoding='utf-8') as f:
            variant_json = json.load(f)
        self.classification_tree = variant_json['tree']
        # Exclude these two subtrees from the tree handed to the classifier.
        self.classification_tree.pop('Archaeal')
        self.classification_tree.pop('Viral')
        self.classifier = HistonedbTypeClassifier(classification_tree=self.classification_tree)

    # -- helpers ---------------------------------------------------------------

    def _create_hmms(self):
        """Build the classifier's HMMs from the seeds in ``DATA_DIRECTORY/draft_seeds``."""
        self.classifier.create_hmms(seed_directory=os.path.join(DATA_DIRECTORY, "draft_seeds"))

    def _predict_test_file(self):
        """Run ``predict`` on ``PREDICTION_DIRECTORY/test`` and return its result."""
        return self.classifier.predict(sequences=os.path.join(PREDICTION_DIRECTORY, "test"))

    @staticmethod
    def _load_test_records():
        """Return the FASTA records of ``PREDICTION_DIRECTORY/test`` as a list."""
        with open(os.path.join(PREDICTION_DIRECTORY, "test")) as handle:
            return list(SeqIO.parse(handle, "fasta"))

    @staticmethod
    def _accession_ids(records):
        """Return a set of accession strings for *records*, for failure diagnostics.

        NOTE(review): the element type of ``predict()``'s result is not visible
        here. Dict records are read via their ``'accession'`` key, other objects
        via an ``accession`` attribute, falling back to the record itself.
        Values are stringified so the result is always hashable.
        """
        ids = set()
        for record in records:
            if isinstance(record, dict):
                ids.add(str(record.get('accession')))
            else:
                ids.add(str(getattr(record, 'accession', record)))
        return ids

    @staticmethod
    def _field_names(records):
        """Return, in order, the set of keys of each record in *records*.

        NOTE(review): assumes each record is a mapping (dict); confirm against
        ``prediction_info`` / ``dump_results``.
        """
        return [set(record) for record in records]

    # -- tests -----------------------------------------------------------------

    def test_create_hmms(self):
        self._create_hmms()
        self.assertEqual(self.classifier.combined_hmm_file,
                         os.path.join(COMBINED_HMM_DIRECTORY, 'types_combined.hmm'))

    def test_predict_type(self):
        self._create_hmms()
        res = self._predict_test_file()
        self.classifier.save_prediction_info(
            file_name=os.path.join(PREDICTION_RESULTS_DIRECTORY, "res_types.csv"))
        test_data = self._load_test_records()
        self.assertIsInstance(self.classifier.prediction_info, list)
        self.assertIsInstance(res, list)
        self.assertIsInstance(self.classifier.predicted_results, list)
        self.assertGreater(len(self.classifier.prediction_info), len(test_data))
        # One result fewer than the number of input records is expected.
        # NP_001295191.1 is presumably the input sequence that is not expected in
        # the results; confirm.
        expected = len(test_data) - 1
        try:
            self.assertEqual(len(res), expected)
            self.assertEqual(len(self.classifier.predicted_results), expected)
        except AssertionError:
            # Show which accessions differ between prediction and input, then
            # re-raise so the test still fails with the original assertion.
            difference = self._accession_ids(res) ^ {record.id for record in test_data}
            print(f"Symetric_difference is {difference}")
            raise

    def test_dump_results(self):
        import pickle

        self._create_hmms()
        self._predict_test_file()
        pickle_path = os.path.join(PREDICTION_RESULTS_DIRECTORY, "res_types.pickle")
        self.classifier.dump_results(file_name=pickle_path)
        # pickle.load is only acceptable here because the file was just written by this test.
        with open(pickle_path, 'rb') as f:
            data_new = pickle.load(f)
        self.assertEqual(len(data_new), len(self.classifier.prediction_info))
        # prediction_info is a list (asserted in test_predict_type), and lists have
        # no key accessor, so compare the field names of each record instead.
        self.assertEqual(self._field_names(data_new),
                         self._field_names(self.classifier.prediction_info))
class TestHistonedbVariantClassifier(unittest.TestCase):
    """Integration tests for ``HistonedbVariantClassifier``.

    ``setUp`` builds a fresh classifier from the ``'tree'`` entry of
    ``VARIANTS_JSON`` with the ``'Archaeal'`` and ``'Viral'`` subtrees removed.
    The tests build HMMs from ``DATA_DIRECTORY/draft_seeds`` and BLAST
    databases, run predictions on the FASTA file ``PREDICTION_DIRECTORY/test``
    and write their outputs under ``PREDICTION_RESULTS_DIRECTORY``.
    """

    def setUp(self):
        from prediction.histonedb_classifier import HistonedbVariantClassifier

        with open(VARIANTS_JSON, encoding='utf-8') as f:
            variant_json = json.load(f)
        self.classification_tree = variant_json['tree']
        # Exclude these two subtrees from the tree handed to the classifier.
        self.classification_tree.pop('Archaeal')
        self.classification_tree.pop('Viral')
        self.classifier = HistonedbVariantClassifier(classification_tree=self.classification_tree)

    # -- helpers ---------------------------------------------------------------

    def _create_hmms(self):
        """Build the classifier's HMMs from the seeds in ``DATA_DIRECTORY/draft_seeds``."""
        self.classifier.create_hmms(seed_directory=os.path.join(DATA_DIRECTORY, "draft_seeds"))

    def _predict_test_file(self):
        """Run ``predict`` on ``PREDICTION_DIRECTORY/test`` and return its result."""
        return self.classifier.predict(sequences=os.path.join(PREDICTION_DIRECTORY, "test"))

    @staticmethod
    def _load_test_records():
        """Return the FASTA records of ``PREDICTION_DIRECTORY/test`` as a list."""
        with open(os.path.join(PREDICTION_DIRECTORY, "test")) as handle:
            return list(SeqIO.parse(handle, "fasta"))

    @staticmethod
    def _accession_ids(records):
        """Return a set of accession strings for *records*, for failure diagnostics.

        NOTE(review): the element type of ``predict()``'s result is not visible
        here. Dict records are read via their ``'accession'`` key, other objects
        via an ``accession`` attribute, falling back to the record itself.
        Values are stringified so the result is always hashable.
        """
        ids = set()
        for record in records:
            if isinstance(record, dict):
                ids.add(str(record.get('accession')))
            else:
                ids.add(str(getattr(record, 'accession', record)))
        return ids

    @staticmethod
    def _field_names(records):
        """Return, in order, the set of keys of each record in *records*.

        NOTE(review): assumes each record is a mapping (dict); confirm against
        ``prediction_info`` / ``dump_results``.
        """
        return [set(record) for record in records]

    # -- tests -----------------------------------------------------------------

    def test_create_hmms(self):
        self._create_hmms()
        # NOTE(review): this expects the same 'types_combined.hmm' path as the
        # type-classifier test. Confirm the variant classifier really writes its
        # combined HMM there rather than to a variant-specific file.
        self.assertEqual(self.classifier.combined_hmm_file,
                         os.path.join(COMBINED_HMM_DIRECTORY, 'types_combined.hmm'))

    def test_create_blastdbs(self):
        # Smoke test: passes as long as create_blastdbs() does not raise.
        self.classifier.create_blastdbs()

    def test_predict_variant(self):
        # HMMs are built once, before the BLAST databases (same order as test_dump_results).
        self._create_hmms()
        self.classifier.create_blastdbs()
        res = self._predict_test_file()
        self.classifier.save_prediction_info(
            file_name=os.path.join(PREDICTION_RESULTS_DIRECTORY, "res_variants.csv"))
        test_data = self._load_test_records()
        info = self.classifier.prediction_info
        self.assertIsInstance(info, list)
        self.assertIsInstance(res, list)
        self.assertIsInstance(self.classifier.predicted_results, list)
        self.assertGreater(len(info), len(test_data))
        # Number of prediction_info entries whose 'best' field is truthy.
        best_count = sum(1 for d in info if d['best'])
        self.assertGreater(len(info), best_count)
        # One result fewer than the number of input records is expected.
        # NP_001295191.1 is presumably the input sequence that is not expected in
        # the results; confirm.
        expected = len(test_data) - 1
        try:
            self.assertEqual(best_count, expected)
            self.assertEqual(len(res), expected)
            self.assertEqual(len(self.classifier.predicted_results), expected)
        except AssertionError:
            # Show which accessions differ between prediction and input, then
            # re-raise so the test still fails with the original assertion.
            difference = self._accession_ids(res) ^ {record.id for record in test_data}
            print(f"Symetric_difference is {difference}")
            raise

    def test_dump_results(self):
        import pickle

        self._create_hmms()
        self.classifier.create_blastdbs()
        self._predict_test_file()
        pickle_path = os.path.join(PREDICTION_RESULTS_DIRECTORY, "res_variants.pickle")
        self.classifier.dump_results(file_name=pickle_path)
        # pickle.load is only acceptable here because the file was just written by this test.
        with open(pickle_path, 'rb') as f:
            data_new = pickle.load(f)
        self.assertEqual(len(data_new), len(self.classifier.prediction_info))
        # prediction_info is a list (asserted in test_predict_variant), and lists
        # have no key accessor, so compare the field names of each record instead.
        self.assertEqual(self._field_names(data_new),
                         self._field_names(self.classifier.prediction_info))
if __name__ == "__main__":
    # Executed as a script: discover and run every TestCase defined in this module.
    unittest.main()