-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrandom_test_set_avoidleakage.py
121 lines (99 loc) · 3.5 KB
/
random_test_set_avoidleakage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import logging
from tracemalloc import start
from turtle import shape
logging.basicConfig(level=logging.INFO)
import os
# import sys
# import time
# import pickle
import torch
import torch.nn.functional as F
# from openfold.utils.seed import seed_everything
from pytorch_lightning.utilities.seed import seed_everything
import shutil
import debugger
# def gather_job(pdb_dir):
# pdb_paths = []
# for f_path in os.listdir(pdb_dir):
# if f_path.endswith('.pdb'):
# pdb_path = os.path.join(pdb_dir, f_path)
# pdb_paths.append(pdb_path)
# return pdb_paths
def gather_job(pdb_dir, name_length=5):
    """Collect the unique protein identifiers found in ``pdb_dir``.

    Every ``*.pdb`` file directly inside ``pdb_dir`` is reduced to the first
    ``name_length`` characters of its filename. All sampled structures of one
    protein (the caller expects files named ``<id>_sample0.pdb``,
    ``<id>_sample1.pdb``, ...) therefore collapse into a single job. That is
    what lets the train/test split keep a protein on one side only.

    Args:
        pdb_dir: Directory to scan (non-recursively).
        name_length: Number of leading filename characters that form the
            identifier. Defaults to 5.

    Returns:
        A sorted list of unique identifiers. Sorting removes the dependence on
        ``os.listdir`` ordering, which is arbitrary. Without it, a seeded
        random split of this list would differ between machines/filesystems.
    """
    # A set deduplicates in O(1) per file (the old list lookup was O(n)).
    names = {
        f_path[:name_length]
        for f_path in os.listdir(pdb_dir)
        if f_path.endswith('.pdb')
    }
    return sorted(names)
def bool_type(bool_str: str):
    """Parse a command-line flag value into a ``bool``.

    Matching is case-insensitive. ``true``/``t``/``yes``/``y``/``1`` map to
    ``True``; ``false``/``f``/``no``/``n``/``0`` map to ``False``.

    Raises:
        ValueError: if ``bool_str`` is none of the accepted spellings.
    """
    spellings = {
        **dict.fromkeys(('false', 'f', 'no', 'n', '0'), False),
        **dict.fromkeys(('true', 't', 'yes', 'y', '1'), True),
    }
    parsed = spellings.get(bool_str.lower())
    if parsed is None:
        raise ValueError(f'Cannot interpret {bool_str} as bool')
    return parsed
def _copy_samples(jobs, src_dir, dst_dir, suffixes):
    """Copy ``<job><suffix>.pdb`` from ``src_dir`` to ``dst_dir`` for every job and suffix."""
    for job in jobs:
        for suffix in suffixes:
            filename = f"{job}{suffix}.pdb"
            # A missing sample file raises FileNotFoundError, as before.
            shutil.copyfile(os.path.join(src_dir, filename),
                            os.path.join(dst_dir, filename))


def main(args):
    """Randomly split the proteins in ``args.pdb_path`` into test/ and train/.

    Up to 40 protein identifiers (as returned by ``gather_job``) are drawn at
    random for the test set; the rest go to the training set. All three sample
    files of each protein (``_sample0``/``_sample1``/``_sample2``) are copied
    into the same split. This is how the split avoids leaking a protein into
    both sets.

    Args:
        args: Parsed CLI namespace; reads ``seed``, ``output_dir`` and
            ``pdb_path``.
    """
    if args.seed is not None:
        seed_everything(args.seed)

    output_test = os.path.join(args.output_dir, "test")
    output_train = os.path.join(args.output_dir, "train")
    # makedirs creates args.output_dir as a parent; exist_ok replaces the
    # racy exists()-then-create check.
    os.makedirs(output_test, exist_ok=True)
    os.makedirs(output_train, exist_ok=True)

    jobs = gather_job(args.pdb_path)
    logging.info(f'got {len(jobs)} jobs...')
    if not jobs:
        logging.warning('no .pdb files found in %s; nothing to split', args.pdb_path)
        return

    top_k = 40
    if len(jobs) < top_k:
        # torch.topk raises when k exceeds the number of elements.
        logging.warning('only %d jobs available; putting all of them in the test set',
                        len(jobs))
        top_k = len(jobs)

    # Random scores + top-k is a uniform sample without replacement. It is
    # kept (rather than randperm) so seeded runs select the same proteins as before.
    prob = torch.rand(len(jobs))
    _, indexes = torch.topk(prob, top_k)
    topk_test = [jobs[i] for i in indexes.tolist()]
    test_set = set(topk_test)
    topk_train = [job for job in jobs if job not in test_set]

    post_string = ["_sample0", "_sample1", "_sample2"]
    _copy_samples(topk_test, args.pdb_path, output_test, post_string)
    _copy_samples(topk_train, args.pdb_path, output_train, post_string)
def build_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with the positional ``pdb_path`` and
        the ``--output_dir``, ``--seed`` and ``--name_length`` options.
    """
    parser = argparse.ArgumentParser(
        description="Randomly split sampled PDB files into test/ and train/ "
                    "folders, keeping all samples of a protein together.",
    )
    parser.add_argument(
        "pdb_path", type=str,
        help="Directory containing the <name>_sampleN.pdb files to split",
    )
    parser.add_argument(
        "--output_dir", type=str, default=os.getcwd(),
        help="Directory in which the test/ and train/ subfolders are created",
    )
    parser.add_argument(
        '--seed', type=int, default=42,
        help="Random seed"
    )
    # NOTE(review): parsed but never read by main(); the protein-name prefix
    # length is not taken from this option. Wire it through or drop it deliberately.
    parser.add_argument(
        "--name_length", type=int, default=13,
        help="how many characters are used to name the protein"
    )
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    main(args)
# usage
# python /home/Xcwang/scratch/beluga/JointProteinFolding/random_test_set_avoidleakage.py /nfs/work04/chuanrui/data/sampled_protein/sampling --output_dir /nfs/work04/chuanrui/data/subset_1200_noleakage