Merge pull request #118 from adaptyvbio/split_fixes
Improve basic functions
elkoz authored Nov 10, 2023
2 parents 7e09671 + bdabe46 commit 7818f1f
Showing 4 changed files with 88 additions and 82 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ run_metrics.py
 *.csv
 esmfold_output/
 igfold_output/
+*egg-info
52 changes: 29 additions & 23 deletions proteinflow/__init__.py
@@ -520,34 +520,40 @@ def split_data(
     temp_folder = os.path.join(tempfile.gettempdir(), "proteinflow")
     if not os.path.exists(temp_folder):
         os.makedirs(temp_folder)
-    if exclude_chains_file is not None or exclude_chains is not None:
-        excluded_biounits = _get_excluded_files(
-            tag,
-            local_datasets_folder,
-            temp_folder,
-            exclude_chains,
-            exclude_chains_file,
-            exclude_threshold,
-        )
-    else:
-        excluded_biounits = []
-    if exclude_chains_without_ligands:
-        excluded_biounits += _exclude_files_with_no_ligand(
-            tag,
-            local_datasets_folder,
-        )
 
     output_folder = os.path.join(local_datasets_folder, f"proteinflow_{tag}")
     out_split_dict_folder = os.path.join(output_folder, "splits_dict")
-    exists = False
     if ignore_existing and os.path.exists(out_split_dict_folder):
         shutil.rmtree(out_split_dict_folder)
 
-    if os.path.exists(out_split_dict_folder):
-        if not ignore_existing:
-            warnings.warn(
-                f"Found an existing dictionary for tag {tag}. proteinflow will load it and ignore the parameters! Run with --ignore_existing to overwrite."
-            )
-            exists = True
-    if not exists:
+    if os.path.exists(os.path.join(output_folder, "splits_dict", "excluded.pickle")):
+        warnings.warn(
+            "Found an existing dictionary for excluded chains. proteinflow will load it and ignore the exclusion parameters! Run with --ignore_existing to overwrite the splitting."
+        )
+        excluded_biounits = None
+    else:
+        if exclude_chains_file is not None or exclude_chains is not None:
+            excluded_biounits = _get_excluded_files(
+                tag,
+                local_datasets_folder,
+                temp_folder,
+                exclude_chains,
+                exclude_chains_file,
+                exclude_threshold,
+            )
+        else:
+            excluded_biounits = []
+        if exclude_chains_without_ligands:
+            excluded_biounits += _exclude_files_with_no_ligand(
+                tag,
+                local_datasets_folder,
+            )
+
+    if os.path.exists(out_split_dict_folder):
+        warnings.warn(
+            f"Found an existing dictionary for tag {tag}. proteinflow will load it and ignore the parameters! Run with --ignore_existing to overwrite."
+        )
+    else:
         _check_mmseqs()
         random.seed(random_seed)
         np.random.seed(random_seed)
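Note on the hunk above: the exclusion computation now runs only when no cached splits_dict/excluded.pickle exists, and the mmseqs-based splitting only runs when no splits dictionary is found. A usage sketch under these assumptions (parameter names are taken from the diff; the chain identifier is purely illustrative):

import proteinflow

# First call: computes the excluded chains and writes them to
# splits_dict/excluded.pickle next to the dataset.
proteinflow.split_data(
    tag="20221110",
    exclude_chains=["1abc-A"],  # hypothetical chain identifier
)

# Second call with the same tag: finds the cached dictionary, warns, and
# reuses it; ignore_existing recomputes the splits from scratch.
proteinflow.split_data(tag="20221110", ignore_existing=True)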
16 changes: 1 addition & 15 deletions proteinflow/download/boto.py
@@ -136,6 +136,7 @@ def _download_dataset_from_s3(
         shutil.move(s3_path, dataset_path)
     with zipfile.ZipFile(local_zip_path, "r") as zip_ref:
         zip_ref.extractall(os.path.dirname(dataset_path))
+    os.remove(local_zip_path)
 
 
 def _get_s3_paths_from_tag(tag):
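The only functional change in this hunk is that the downloaded archive is deleted once extraction succeeds, so the dataset no longer occupies twice its size on disk. The same pattern in isolation (a sketch; the function name is illustrative, not the module's API):

import os
import zipfile

def extract_and_discard(local_zip_path, dataset_path):
    # Unpack next to the target folder, then drop the archive. os.remove only
    # runs if extractall completed without raising, so a failed extraction
    # leaves the zip in place for a retry.
    with zipfile.ZipFile(local_zip_path, "r") as zip_ref:
        zip_ref.extractall(os.path.dirname(dataset_path))
    os.remove(local_zip_path)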
@@ -145,21 +146,6 @@ def _get_s3_paths_from_tag(tag):
     return data_path, dict_path
 
 
-def _download_zip_dataset_from_s3(
-    dataset_path="./data/proteinflow_20221110/",
-    s3_path="s3://ml4-main-storage/proteinflow_20221110/",
-):
-    """Download the pre-processed files."""
-    if s3_path.startswith("s3"):
-        print("Downloading the dataset from s3...")
-        subprocess.run(
-            ["aws", "s3", "sync", "--no-sign-request", s3_path, dataset_path]
-        )
-        print("Done!")
-    else:
-        shutil.move(s3_path, dataset_path)
-
-
 async def _getobj(client, key):
     """Get an object from S3."""
     resp = await client.get_object(Bucket="pdbsnapshots", Key=key)
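The removed _download_zip_dataset_from_s3 shelled out to the aws CLI; downloads are left to _download_dataset_from_s3 (first hunk) and the async _getobj helper. For orientation, _getobj plugs into the standard aiobotocore pattern, roughly as below (a sketch under assumptions: the real module creates and configures the S3 client elsewhere, possibly with unsigned access since pdbsnapshots is a public bucket, and the key is hypothetical):

import asyncio
from aiobotocore.session import get_session

async def fetch_snapshot_object(key):
    session = get_session()
    async with session.create_client("s3") as client:
        resp = await client.get_object(Bucket="pdbsnapshots", Key=key)
        return await resp["Body"].read()

# data = asyncio.run(fetch_snapshot_object("pub/pdb/data/..."))  # hypothetical key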
101 changes: 57 additions & 44 deletions proteinflow/split/__init__.py
@@ -1381,7 +1381,7 @@ def _get_excluded_files(
 
 def _split_data(
     dataset_path="./data/proteinflow_20221110/",
-    excluded_files=None,
+    excluded_biounits=None,
     exclude_clusters=False,
     exclude_based_on_cdr=None,
 ):
@@ -1399,8 +1399,15 @@ def _split_data(
         If not `None`, exclude all files in a cluster if the cluster name does not end with `exclude_based_on_cdr`
     """
-    if excluded_files is None:
-        excluded_files = []
+    if excluded_biounits is None:
+        if os.path.exists(os.path.join(dataset_path, "splits_dict", "excluded.pickle")):
+            with open(
+                os.path.join(dataset_path, "splits_dict", "excluded.pickle"), "rb"
+            ) as f:
+                excluded_clusters_dict = pickle.load(f)
+            excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, [])
+        else:
+            excluded_biounits = []
 
     dict_folder = os.path.join(dataset_path, "splits_dict")
     with open(os.path.join(dict_folder, "train.pickle"), "rb") as f:
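This branch pairs with the writer at the end of _split_data, which dumps excluded_clusters_dict into splits_dict/excluded.pickle: the first split computes and saves the exclusions, and later runs rehydrate them from the cache instead of recomputing. A self-contained sketch of the reload step (the set comprehension stands in for _biounits_in_clusters_dict, assuming cluster values are (biounit_file, chain) pairs as elsewhere in the diff):

import os
import pickle

def load_cached_exclusions(dataset_path):
    # Return the cached excluded-biounit names, or None when no cache exists.
    path = os.path.join(dataset_path, "splits_dict", "excluded.pickle")
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        excluded_clusters_dict = pickle.load(f)
    return sorted(
        {entry[0] for chains in excluded_clusters_dict.values() for entry in chains}
    )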
@@ -1410,9 +1417,9 @@ def _split_data(
     with open(os.path.join(dict_folder, "test.pickle"), "rb") as f:
         test_clusters_dict = pickle.load(f)
 
-    train_biounits = _biounits_in_clusters_dict(train_clusters_dict, excluded_files)
-    valid_biounits = _biounits_in_clusters_dict(valid_clusters_dict, excluded_files)
-    test_biounits = _biounits_in_clusters_dict(test_clusters_dict, excluded_files)
+    train_biounits = _biounits_in_clusters_dict(train_clusters_dict, excluded_biounits)
+    valid_biounits = _biounits_in_clusters_dict(valid_clusters_dict, excluded_biounits)
+    test_biounits = _biounits_in_clusters_dict(test_clusters_dict, excluded_biounits)
     train_path = os.path.join(dataset_path, "train")
     valid_path = os.path.join(dataset_path, "valid")
     test_path = os.path.join(dataset_path, "test")
@@ -1427,43 +1434,49 @@ def _split_data(
     if not os.path.exists(test_path):
         os.makedirs(test_path)
 
-    if len(excluded_files) > 0:
-        set_to_exclude = set(excluded_files)
-        excluded_files = set()
-        excluded_clusters_dict = defaultdict(list)
-        for clusters_dict in [
-            train_clusters_dict,
-            valid_clusters_dict,
-            test_clusters_dict,
-        ]:
-            for cluster in list(clusters_dict.keys()):
-                idx_to_exclude = []
-                exclude_whole_cluster = False
-                for i, chain in enumerate(clusters_dict[cluster]):
-                    if chain[0] in set_to_exclude:
-                        if exclude_clusters:
-                            if exclude_based_on_cdr is not None and cluster.endswith(
-                                exclude_based_on_cdr
-                            ):
-                                exclude_whole_cluster = True
-                            elif exclude_based_on_cdr is None:
-                                exclude_whole_cluster = True
-                        if exclude_whole_cluster:
-                            break
-                        excluded_clusters_dict[cluster].append(chain)
-                        idx_to_exclude.append(i)
-                if exclude_whole_cluster:
-                    excluded_clusters_dict[cluster] = clusters_dict.pop(cluster)
-                else:
-                    clusters_dict[cluster] = [
-                        x
-                        for i, x in enumerate(clusters_dict[cluster])
-                        if i not in idx_to_exclude
-                    ]
-                    if len(clusters_dict[cluster]) == 0:
-                        clusters_dict.pop(cluster)
-        excluded_files.update(set_to_exclude)
-        excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()}
+    if len(excluded_biounits) > 0:
+        if not os.path.exists(
+            os.path.join(dataset_path, "splits_dict", "excluded.pickle")
+        ):
+            set_to_exclude = set(excluded_biounits)
+            excluded_biounits = set()
+            excluded_clusters_dict = defaultdict(list)
+            for clusters_dict in [
+                train_clusters_dict,
+                valid_clusters_dict,
+                test_clusters_dict,
+            ]:
+                for cluster in list(clusters_dict.keys()):
+                    idx_to_exclude = []
+                    exclude_whole_cluster = False
+                    for i, chain in enumerate(clusters_dict[cluster]):
+                        if chain[0] in set_to_exclude:
+                            if exclude_clusters:
+                                if (
+                                    exclude_based_on_cdr is not None
+                                    and cluster.endswith(exclude_based_on_cdr)
+                                ):
+                                    exclude_whole_cluster = True
+                                elif exclude_based_on_cdr is None:
+                                    exclude_whole_cluster = True
+                            if exclude_whole_cluster:
+                                break
+                            excluded_clusters_dict[cluster].append(chain)
+                            idx_to_exclude.append(i)
+                    if exclude_whole_cluster:
+                        excluded_clusters_dict[cluster] = clusters_dict.pop(cluster)
+                    else:
+                        clusters_dict[cluster] = [
+                            x
+                            for i, x in enumerate(clusters_dict[cluster])
+                            if i not in idx_to_exclude
+                        ]
+                        if len(clusters_dict[cluster]) == 0:
+                            clusters_dict.pop(cluster)
+            excluded_clusters_dict = {
+                k: list(v) for k, v in excluded_clusters_dict.items()
+            }
+            excluded_biounits = _biounits_in_clusters_dict(excluded_clusters_dict, [])
         excluded_path = os.path.join(dataset_path, "excluded")
         if not os.path.exists(excluded_path):
             os.makedirs(excluded_path)
@@ -1477,7 +1490,7 @@ def _split_data(
         with open(os.path.join(dict_folder, "excluded.pickle"), "wb") as f:
             pickle.dump(excluded_clusters_dict, f)
         print("Moving excluded files...")
-        for biounit in tqdm(excluded_files):
+        for biounit in tqdm(excluded_biounits):
             shutil.move(os.path.join(dataset_path, biounit), excluded_path)
     print("Moving files in the train set...")
     for biounit in tqdm(train_biounits):
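Apart from the rename and the extra cache check, the exclusion loop is unchanged: a hit on an excluded biounit either removes the whole cluster (when exclude_clusters is set, optionally gated on the cluster name ending with exclude_based_on_cdr) or removes just the matching entries, dropping clusters that end up empty. A toy illustration of the two modes (standalone; the CDR gating is omitted and the file names are made up):

from collections import defaultdict

def apply_exclusion(clusters_dict, set_to_exclude, exclude_clusters=False):
    # Simplified version of the loop above: mutates clusters_dict in place and
    # returns the excluded entries grouped by cluster.
    excluded = defaultdict(list)
    for cluster in list(clusters_dict.keys()):
        hits = [e for e in clusters_dict[cluster] if e[0] in set_to_exclude]
        if not hits:
            continue
        if exclude_clusters:
            excluded[cluster] = clusters_dict.pop(cluster)  # drop the whole cluster
        else:
            excluded[cluster].extend(hits)
            kept = [e for e in clusters_dict[cluster] if e[0] not in set_to_exclude]
            if kept:
                clusters_dict[cluster] = kept
            else:
                clusters_dict.pop(cluster)  # nothing left in this cluster
    return dict(excluded)

clusters = {
    "c1": [("1abc-1.pickle", "A"), ("2xyz-1.pickle", "B")],
    "c2": [("3def-1.pickle", "A")],
}
print(apply_exclusion(clusters, {"1abc-1.pickle"}))
# {'c1': [('1abc-1.pickle', 'A')]} -- clusters keeps the remaining entries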
