Skip to content

Commit

Permalink
Merge pull request #131 from adaptyvbio/adjust_splitting
Browse files Browse the repository at this point in the history
Adjust splitting
  • Loading branch information
elkoz authored Dec 29, 2023
2 parents 93c76eb + 72ed6a6 commit 3728358
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 23 deletions.
8 changes: 4 additions & 4 deletions proteinflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ def generate_data(
the sequence similarity threshold for excluding chains
exclude_clusters : bool, default False
if `True`, exclude clusters that contain chains similar to chains in the `exclude_chains` list
exclude_based_on_cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters
exclude_based_on_cdr : list, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters (choose from "H1", "H2", "H3", "L1", "L2", "L3")
load_ligands : bool, default False
if `True`, load ligands from the PDB files
exclude_chains_without_ligands : bool, default False
Expand Down Expand Up @@ -506,8 +506,8 @@ def split_data(
the sequence similarity threshold for excluding chains
exclude_clusters : bool, default False
if `True`, exclude clusters that contain chains similar to chains in the `exclude_chains` list
exclude_based_on_cdr : {"H1", "H2", "H3", "L1", "L2", "L3"}, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters
exclude_based_on_cdr : list, optional
if given and `exclude_clusters` is `True` + the dataset is SAbDab, exclude files based on only the given CDR clusters (choose from "H1", "H2", "H3", "L1", "L2", "L3")
random_seed : int, default 42
random seed for reproducibility (set to `None` to use a random seed)
exclude_chains_without_ligands : bool, default False
Expand Down
2 changes: 2 additions & 0 deletions proteinflow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def download(**kwargs):
@click.option(
"--exclude_based_on_cdr",
type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]),
multiple=True,
help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters",
)
@click.option(
Expand Down Expand Up @@ -302,6 +303,7 @@ def generate(**kwargs):
@click.option(
"--exclude_based_on_cdr",
type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]),
multiple=True,
help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters",
)
@click.option(
Expand Down
31 changes: 12 additions & 19 deletions proteinflow/split/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,11 +1395,7 @@ def _exclude_biounits(
Since `proteinflow` assumes splitting at the level of biounits, when using `exclude_clusters` the dictionaries are adjusted to exclude
full biounits that the newly excluded chains / CDRs are part of. This is done by moving the full biounits to the excluded set
and removing the rest of the clusters they belong to from training / test / validation.
For example, if for antibody Ab CDR H1 is in cluster A, CDR H2 is in cluster B and CDR H3 is in cluster C, and cluster C is in
the excluded set, then clusters A and B are removed from the training / test / validation sets and added to the excluded set with only the Ab CDRs.
The files for the other biounits that are part of the excluded clusters are also moved to the excluded set but not added to the split dictionary.
and removing the corresponding entries from all training / test / validation clusters.
"""
set_to_exclude = set(excluded_biounits)
Expand All @@ -1416,8 +1412,9 @@ def _exclude_biounits(
for i, chain in enumerate(clusters_dict[cluster]):
if chain[0] in set_to_exclude:
if exclude_clusters:
if exclude_based_on_cdr is not None and cluster.endswith(
exclude_based_on_cdr
if (
exclude_based_on_cdr is not None
and cluster.split("__")[-1] in exclude_based_on_cdr
):
exclude_whole_cluster = True
elif exclude_based_on_cdr is None:
Expand All @@ -1444,19 +1441,15 @@ def _exclude_biounits(
test_clusters_dict,
]:
for cluster in list(clusters_dict.keys()):
excluded_biounit_in_cluster = False
if cluster in excluded_clusters_dict:
excluded_biounit_in_cluster = True
to_exclude = []
for i, (file, chain) in enumerate(clusters_dict[cluster]):
if file in excluded_biounits:
excluded_biounit_in_cluster = True
excluded_clusters_dict[cluster].append((file, chain))
excluded_biounits.add(file)
# remove cluster from training / validation / test set if at least one biounit in the cluster is excluded
if exclude_clusters and excluded_biounit_in_cluster:
chains = clusters_dict.pop(cluster)
excluded_biounits.update([x[0] for x in chains])
excluded_clusters_dict = {k: list(v) for k, v in excluded_clusters_dict.items()}
to_exclude.append(i)
clusters_dict[cluster] = [
x for i, x in enumerate(clusters_dict[cluster]) if i not in to_exclude
]
if len(clusters_dict[cluster]) == 0:
clusters_dict.pop(cluster)
return (
train_clusters_dict,
valid_clusters_dict,
Expand All @@ -1482,7 +1475,7 @@ def _split_data(
A list of files to exclude from the dataset
exclude_clusters : bool, default False
If True, exclude all files in a cluster if at least one file in the cluster is in `excluded_files`
exclude_based_on_cdr : str, optional
exclude_based_on_cdr : list, optional
If not `None`, exclude all files in a cluster if the cluster name does not end with `exclude_based_on_cdr`
"""
Expand Down

0 comments on commit 3728358

Please sign in to comment.