Fix bugs #121

Merged
merged 2 commits on Nov 14, 2023
6 changes: 3 additions & 3 deletions proteinflow/__init__.py
@@ -229,8 +229,8 @@ def generate_data(
missing_middle_thr=0.1,
not_filter_methods=False,
not_remove_redundancies=False,
skip_splitting=False,
redundancy_thr=0.9,
skip_splitting=False,
n=None,
force=False,
split_tolerance=0.2,
@@ -299,10 +299,10 @@ def generate_data(
If `False`, only files obtained with X-ray or EM will be processed
not_remove_redundancies : bool, default False
If 'False', removes biounits that are doubles of others sequence wise
redundancy_thr : float, default 0.9
The threshold upon which sequences are considered as one and the same (default: 0.9)
skip_splitting : bool, default False
if `True`, skip the split dictionary creation and the file moving steps
redundancy_thr : float, default 0.9
The threshold upon which sequences are considered as one and the same (default: 90%)
n : int, default None
The number of files to process (for debugging purposes)
force : bool, default False
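Note on the __init__.py change above: skip_splitting and redundancy_thr are keyword parameters, so reordering them in the signature and docstring does not affect call sites. A minimal sketch of a call touching the affected options (the tag value is hypothetical):

import proteinflow

# All arguments are passed by keyword, so the new ordering changes nothing here.
proteinflow.generate_data(
    tag="test",           # hypothetical dataset tag
    redundancy_thr=0.9,   # collapse sequences above 90% identity
    skip_splitting=False, # still build the train/valid/test split
    n=10,                 # process only 10 files, for debugging
)
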
73 changes: 34 additions & 39 deletions proteinflow/cli.py
@@ -107,16 +107,16 @@ def download(**kwargs):
help="The threshold upon which sequences are considered as one and the same (default: 90%)",
)
@click.option(
"--valid_split",
default=0.05,
type=float,
help="The percentage of chains to put in the validation set (default 5%)",
"--skip_splitting", is_flag=True, help="Use this flag to skip splitting the data"
)
@click.option(
"--test_split",
default=0.05,
type=float,
help="The percentage of chains to put in the test set (default 5%)",
"--n",
default=None,
type=int,
help="The number of files to process (for debugging purposes)",
)
@click.option(
"--force", is_flag=True, help="When `True`, rewrite the files if they already exist"
)
@click.option(
"--split_tolerance",
@@ -125,27 +125,22 @@ def download(**kwargs):
help="The tolerance on the split ratio (default 20%)",
)
@click.option(
"--n",
default=None,
type=int,
help="The number of files to process (for debugging purposes)",
"--test_split",
default=0.05,
type=float,
help="The percentage of chains to put in the test set (default 5%)",
)
@click.option(
"--force", is_flag=True, help="When `True`, rewrite the files if they already exist"
"--valid_split",
default=0.05,
type=float,
help="The percentage of chains to put in the validation set (default 5%)",
)
@click.option(
"--pdb_snapshot",
type=str,
help="The pdb snapshot folder to load",
)
@click.option(
"--skip_splitting", is_flag=True, help="Use this flag to skip splitting the data"
)
@click.option(
"--skip_processing",
is_flag=True,
help="Use this flag to skip downloading and processing the data",
)
@click.option(
"--load_live",
is_flag=True,
@@ -172,11 +167,6 @@ def download(**kwargs):
is_flag=True,
help="Use this flag to require that the SAbDab files contain an antigen",
)
@click.option(
"--load_ligands",
is_flag=True,
help="Whether or not to load ligands found in the pdbs example: data['A']['ligand'][0]['X']",
)
@click.option(
"--exclude_chains",
"-e",
@@ -205,6 +195,11 @@ def download(**kwargs):
type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]),
help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters",
)
@click.option(
"--load_ligands",
is_flag=True,
help="Whether or not to load ligands found in the pdbs example: data['A']['ligand'][0]['X']",
)
@click.option(
"--exclude_chains_without_ligands",
is_flag=True,
@@ -253,9 +248,10 @@ def generate(**kwargs):
help="The folder where proteinflow datasets, temporary files and logs will be stored",
)
@click.option(
"--ignore_existing",
is_flag=True,
help="Unless this flag is used, proteinflow will not overwrite existing split dictionaries for this tag and will load them instead",
"--split_tolerance",
default=0.2,
type=float,
help="The tolerance on the split ratio (default 20%)",
)
@click.option(
"--valid_split",
@@ -270,10 +266,9 @@ def generate(**kwargs):
help="The percentage of chains to put in the test set (default 5%)",
)
@click.option(
"--split_tolerance",
default=0.2,
type=float,
help="The tolerance on the split ratio (default 20%)",
"--ignore_existing",
is_flag=True,
help="Unless this flag is used, proteinflow will not overwrite existing split dictionaries for this tag and will load them instead",
)
@click.option(
"--min_seq_id",
@@ -309,6 +304,12 @@ def generate(**kwargs):
type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]),
help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters",
)
@click.option(
"--random_seed",
default=42,
type=int,
help="The random seed to use for splitting",
)
@click.option(
"--exclude_chains_without_ligands",
is_flag=True,
@@ -324,12 +325,6 @@ def generate(**kwargs):
is_flag=True,
help="Whether to use FoldSeek to cluster the dataset",
)
@click.option(
"--random_seed",
default=42,
type=int,
help="The random seed to use for splitting",
)
@cli.command(
"split",
help="Split an existing ProteinFlow dataset into training, validation and test subset according to MMseqs clustering and homomer/heteromer/single chain proportions",
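The cli.py changes only reorder and regroup the @click.option declarations, which changes the order of entries in --help output; flag behaviour is unchanged. A minimal sketch, assuming the Click group is importable as proteinflow.cli.cli and that --tag is the dataset-tag option, exercising the reordered commands through Click's test runner:

from click.testing import CliRunner
from proteinflow.cli import cli  # assumed import path for the Click group

runner = CliRunner()
# generate: --skip_splitting and --n now appear earlier in --help
result = runner.invoke(cli, ["generate", "--tag", "test", "--n", "5", "--skip_splitting"])
print(result.exit_code, result.output)
# split: --random_seed moved up next to the other splitting options
result = runner.invoke(cli, ["split", "--tag", "test", "--random_seed", "42"])
print(result.exit_code, result.output)
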
68 changes: 37 additions & 31 deletions proteinflow/data/torch.py
@@ -309,6 +309,7 @@ def __init__(
patch_around_mask=False,
initial_patch_size=128,
antigen_patch_size=128,
debug_verbose=False,
):
"""Initialize the dataset.

@@ -381,6 +382,8 @@ def __init__(
the size of the antigen patch (used if `patch_around_mask` is `True` and the dataset is SAbDab)

"""
self.debug = debug_verbose

alphabet = ALPHABET
self.alphabet_dict = defaultdict(lambda: 0)
for i, letter in enumerate(alphabet):
@@ -699,8 +702,6 @@ def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None
output_names = []
if self.cut_edges:
data_entry.cut_missing_edges()
if self.interpolate != "none":
data_entry.interpolate_coords(fill_ends=(self.interpolate == "all"))
for chains_i, chain_set in enumerate(chain_sets):
output_file = os.path.join(
self.features_folder, no_extension_name + f"_{chains_i}.pickle"
@@ -736,17 +737,49 @@ def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None
if length > 0
]
):
add_name = False
pass_set = True
add_name = False

if self.entry_type == "pair":
if not data_entry.is_valid_pair(*chain_set):
pass_set = True
add_name = False
out = {}
if add_name:
cdr_chain_set = set()
if data_entry.has_cdr():
out["cdr"] = torch.tensor(
data_entry.get_cdr(chain_set, encode=True)
)
chain_type_dict = data_entry.get_chain_type_dict(chain_set)
out["chain_type_dict"] = chain_type_dict
if "heavy" in chain_type_dict:
cdr_chain_set.update(
[
f"{chain_type_dict['heavy']}__{cdr}"
for cdr in ["H1", "H2", "H3"]
]
)
if "light" in chain_type_dict:
cdr_chain_set.update(
[
f"{chain_type_dict['light']}__{cdr}"
for cdr in ["L1", "L2", "L3"]
]
)
output_names.append(
(
os.path.basename(no_extension_name),
output_file,
chain_set if len(cdr_chain_set) == 0 else cdr_chain_set,
)
)
if pass_set:
continue

out = {}
if self.interpolate != "none":
data_entry.interpolate_coords(fill_ends=(self.interpolate == "all"))

out["pdb_id"] = no_extension_name.split("-")[0]
out["mask_original"] = torch.tensor(
data_entry.get_mask(chain_set, original=True)
@@ -767,39 +800,12 @@ def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None
out["ligand_smiles"],
out["ligand_chains"],
) = data_entry.get_ligand_features(ligands, chain_set)
cdr_chain_set = set()
if data_entry.has_cdr():
out["cdr"] = torch.tensor(data_entry.get_cdr(chain_set, encode=True))
chain_type_dict = data_entry.get_chain_type_dict(chain_set)
out["chain_type_dict"] = chain_type_dict
if "heavy" in chain_type_dict:
cdr_chain_set.update(
[
f"{chain_type_dict['heavy']}__{cdr}"
for cdr in ["H1", "H2", "H3"]
]
)
if "light" in chain_type_dict:
cdr_chain_set.update(
[
f"{chain_type_dict['light']}__{cdr}"
for cdr in ["L1", "L2", "L3"]
]
)
for name in self.feature_types:
if name not in self.feature_functions:
continue
func = self.feature_functions[name]
out[name] = torch.tensor(func(data_entry, chain_set))

if add_name:
output_names.append(
(
os.path.basename(no_extension_name),
output_file,
chain_set if len(cdr_chain_set) == 0 else cdr_chain_set,
)
)
with open(output_file, "wb") as f:
pickle.dump(out, f)

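Second, _process now rejects invalid chain sets before doing any heavy work: the CDR-length and pair-validity checks set pass_set and trigger a continue, and interpolate_coords moved from before the loop to after those checks, so coordinates are only interpolated for entries that will actually be written out. A standalone sketch of the pattern (not the PR's code):

def process_sets(chain_sets, is_valid, interpolate, extract_features):
    outputs = []
    for chain_set in chain_sets:
        if not is_valid(chain_set):
            continue  # reject early, before any expensive computation
        interpolate(chain_set)  # now runs per chain set, after validation
        outputs.append(extract_features(chain_set))
    return outputs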