Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust splitting #131

Merged
merged 1 commit into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions proteinflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ def generate_data(
the sequence similarity threshold for excluding chains
exclude_clusters : bool, default False
if `True`, exclude clusters that contain chains similar to chains in the `exclude_chains` list
exclude_based_on_cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters
exclude_based_on_cdr : list, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters (choose from "H1", "H2", "H3", "L1", "L2", "L3")
load_ligands : bool, default False
if `True`, load ligands from the PDB files
exclude_chains_without_ligands : bool, default False
Expand Down Expand Up @@ -506,8 +506,8 @@ def split_data(
the sequence similarity threshold for excluding chains
exclude_clusters : bool, default False
if `True`, exclude clusters that contain chains similar to chains in the `exclude_chains` list
exclude_based_on_cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters
exclude_based_on_cdr : list, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters (choose from "H1", "H2", "H3", "L1", "L2", "L3")
random_seed : int, default 42
random seed for reproducibility (set to `None` to use a random seed)
exclude_chains_without_ligands : bool, default False
Expand Down
2 changes: 2 additions & 0 deletions proteinflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def download(**kwargs):
@click.option(
"--exclude_based_on_cdr",
type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]),
multiple=True,
help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters",
)
@click.option(
Expand Down Expand Up @@ -302,6 +303,7 @@ def generate(**kwargs):
@click.option(
"--exclude_based_on_cdr",
type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]),
multiple=True,
help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters",
)
@click.option(
Expand Down
31 changes: 12 additions & 19 deletions proteinflow/split/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,11 +1395,7 @@ def _exclude_biounits(

Since `proteinflow` assumes splitting at the level of biounits, when using `exclude_clusters` the dictionaries are adjusted to exclude
full biounits that the newly excluded chains / CDRs are part of. This is done by moving the full biounits to the excluded set
and removing the rest of the clusters they belong to from training / test / validation.

For example, if for antibody Ab CDR H1 is in cluster A, CDR H2 is in cluster B and CDR H3 is in cluster C, and cluster C is in
the excluded set, then clusters A and B are removed from the training / test / validation sets and added to the excluded set with only the Ab CDRs.
The files for the other biounits that are part of the excluded clusters are also moved to the excluded set but not added to the split dictionary.
and removing the corresponding entries from all training / test / validation clusters.

"""
set_to_exclude = set(excluded_biounits)
Expand All @@ -1416,8 +1412,9 @@ def _exclude_biounits(
for i, chain in enumerate(clusters_dict[cluster]):
if chain[0] in set_to_exclude:
if exclude_clusters:
if exclude_based_on_cdr is not None and cluster.endswith(
exclude_based_on_cdr
if (
exclude_based_on_cdr is not None
and cluster.split("__")[-1] in exclude_based_on_cdr
):
exclude_whole_cluster = True
elif exclude_based_on_cdr is None:
Expand All @@ -1444,19 +1441,15 @@ def _exclude_biounits(
test_clusters_dict,
]:
for cluster in list(clusters_dict.keys()):
excluded_biounit_in_cluster = False
if cluster in excluded_clusters_dict:
excluded_biounit_in_cluster = True
to_exclude = []
for i, (file, chain) in enumerate(clusters_dict[cluster]):
if file in excluded_biounits:
excluded_biounit_in_cluster = True
excluded_clusters_dict[cluster].append((file, chain))
excluded_biounits.add(file)
# remove cluster from training / validation / test set if at least one biounit in the cluster is excluded
if exclude_clusters and excluded_biounit_in_cluster:
chains = clusters_dict.pop(cluster)
excluded_biounits.update([x[0] for x in chains])
excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()}
to_exclude.append(i)
clusters_dict[cluster] = [
x for i, x in enumerate(clusters_dict[cluster]) if i not in to_exclude
]
if len(clusters_dict[cluster]) == 0:
clusters_dict.pop(cluster)
return (
train_clusters_dict,
valid_clusters_dict,
Expand All @@ -1482,7 +1475,7 @@ def _split_data(
A list of files to exclude from the dataset
exclude_clusters : bool, default False
If True, exclude all files in a cluster if at least one file in the cluster is in `excluded_files`
exclude_based_on_cdr : str, optional
exclude_based_on_cdr : list, optional
If not `None`, exclude all files in a cluster if the cluster name does not end with `exclude_based_on_cdr`

"""
Expand Down
Loading