From e6c03d0205745143fb37c10b36b39111d148d0db Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Tue, 14 Nov 2023 11:26:29 +0000 Subject: [PATCH 1/2] fix: cli bug (+ sort cli arguments) --- proteinflow/__init__.py | 6 ++-- proteinflow/cli.py | 73 +++++++++++++++++++---------------------- 2 files changed, 37 insertions(+), 42 deletions(-) diff --git a/proteinflow/__init__.py b/proteinflow/__init__.py index 4509910..9344021 100644 --- a/proteinflow/__init__.py +++ b/proteinflow/__init__.py @@ -229,8 +229,8 @@ def generate_data( missing_middle_thr=0.1, not_filter_methods=False, not_remove_redundancies=False, - skip_splitting=False, redundancy_thr=0.9, + skip_splitting=False, n=None, force=False, split_tolerance=0.2, @@ -299,10 +299,10 @@ def generate_data( If `False`, only files obtained with X-ray or EM will be processed not_remove_redundancies : bool, default False If 'False', removes biounits that are doubles of others sequence wise + redundancy_thr : float, default 0.9 + The threshold upon which sequences are considered as one and the same (default: 0.9) skip_splitting : bool, default False if `True`, skip the split dictionary creation and the file moving steps - redundancy_thr : float, default 0.9 - The threshold upon which sequences are considered as one and the same (default: 90%) n : int, default None The number of files to process (for debugging purposes) force : bool, default False diff --git a/proteinflow/cli.py b/proteinflow/cli.py index f01cfe8..4a0bff7 100644 --- a/proteinflow/cli.py +++ b/proteinflow/cli.py @@ -107,16 +107,16 @@ def download(**kwargs): help="The threshold upon which sequences are considered as one and the same (default: 90%)", ) @click.option( - "--valid_split", - default=0.05, - type=float, - help="The percentage of chains to put in the validation set (default 5%)", + "--skip_splitting", is_flag=True, help="Use this flag to skip splitting the data" ) @click.option( - "--test_split", - default=0.05, - type=float, - help="The percentage of chains to put in the test set (default 5%)", + "--n", + default=None, + type=int, + help="The number of files to process (for debugging purposes)", +) +@click.option( + "--force", is_flag=True, help="When `True`, rewrite the files if they already exist" ) @click.option( "--split_tolerance", @@ -125,27 +125,22 @@ def download(**kwargs): help="The tolerance on the split ratio (default 20%)", ) @click.option( - "--n", - default=None, - type=int, - help="The number of files to process (for debugging purposes)", + "--test_split", + default=0.05, + type=float, + help="The percentage of chains to put in the test set (default 5%)", ) @click.option( - "--force", is_flag=True, help="When `True`, rewrite the files if they already exist" + "--valid_split", + default=0.05, + type=float, + help="The percentage of chains to put in the validation set (default 5%)", ) @click.option( "--pdb_snapshot", type=str, help="The pdb snapshot folder to load", ) -@click.option( - "--skip_splitting", is_flag=True, help="Use this flag to skip splitting the data" -) -@click.option( - "--skip_processing", - is_flag=True, - help="Use this flag to skip downloading and processing the data", -) @click.option( "--load_live", is_flag=True, @@ -172,11 +167,6 @@ def download(**kwargs): is_flag=True, help="Use this flag to require that the SAbDab files contain an antigen", ) -@click.option( - "--load_ligands", - is_flag=True, - help="Whether or not to load ligands found in the pdbs example: data['A']['ligand'][0]['X']", -) @click.option( "--exclude_chains", "-e", @@ -205,6 +195,11 @@ def download(**kwargs): type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]), help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters", ) +@click.option( + "--load_ligands", + is_flag=True, + help="Whether or not to load ligands found in the pdbs example: data['A']['ligand'][0]['X']", +) @click.option( "--exclude_chains_without_ligands", is_flag=True, @@ -253,9 +248,10 @@ def generate(**kwargs): help="The folder where proteinflow datasets, temporary files and logs will be stored", ) @click.option( - "--ignore_existing", - is_flag=True, - help="Unless this flag is used, proteinflow will not overwrite existing split dictionaries for this tag and will load them instead", + "--split_tolerance", + default=0.2, + type=float, + help="The tolerance on the split ratio (default 20%)", ) @click.option( "--valid_split", @@ -270,10 +266,9 @@ def generate(**kwargs): help="The percentage of chains to put in the test set (default 5%)", ) @click.option( - "--split_tolerance", - default=0.2, - type=float, - help="The tolerance on the split ratio (default 20%)", + "--ignore_existing", + is_flag=True, + help="Unless this flag is used, proteinflow will not overwrite existing split dictionaries for this tag and will load them instead", ) @click.option( "--min_seq_id", @@ -309,6 +304,12 @@ def generate(**kwargs): type=click.Choice(["L1", "L2", "L3", "H1", "H2", "H3"]), help="if given and exclude_clusters is true + the dataset is SAbDab, exclude files based on only the given CDR clusters", ) +@click.option( + "--random_seed", + default=42, + type=int, + help="The random seed to use for splitting", +) @click.option( "--exclude_chains_without_ligands", is_flag=True, @@ -324,12 +325,6 @@ def generate(**kwargs): is_flag=True, help="Whether to use FoldSeek to cluster the dataset", ) -@click.option( - "--random_seed", - default=42, - type=int, - help="The random seed to use for splitting", -) @cli.command( "split", help="Split an existing ProteinFlow dataset into training, validation and test subset according to MMseqs clustering and homomer/heteromer/single chain proportions", From d1d89f98b7801dd39efe359a0149332bda7eeaba Mon Sep 17 00:00:00 2001 From: Liza Kozlova Date: Tue, 14 Nov 2023 14:48:05 +0000 Subject: [PATCH 2/2] fix: dataset re-creation with force=False bug --- proteinflow/data/torch.py | 68 +++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 31 deletions(-) diff --git a/proteinflow/data/torch.py b/proteinflow/data/torch.py index 4dba6d2..ebd7692 100644 --- a/proteinflow/data/torch.py +++ b/proteinflow/data/torch.py @@ -309,6 +309,7 @@ def __init__( patch_around_mask=False, initial_patch_size=128, antigen_patch_size=128, + debug_verbose=False, ): """Initialize the dataset. @@ -381,6 +382,8 @@ def __init__( the size of the antigen patch (used if `patch_around_mask` is `True` and the dataset is SAbDab) """ + self.debug = debug_verbose + alphabet = ALPHABET self.alphabet_dict = defaultdict(lambda: 0) for i, letter in enumerate(alphabet): @@ -699,8 +702,6 @@ def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None output_names = [] if self.cut_edges: data_entry.cut_missing_edges() - if self.interpolate != "none": - data_entry.interpolate_coords(fill_ends=(self.interpolate == "all")) for chains_i, chain_set in enumerate(chain_sets): output_file = os.path.join( self.features_folder, no_extension_name + f"_{chains_i}.pickle" @@ -736,17 +737,49 @@ def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None if length > 0 ] ): - add_name = False pass_set = True + add_name = False if self.entry_type == "pair": if not data_entry.is_valid_pair(*chain_set): pass_set = True add_name = False + out = {} + if add_name: + cdr_chain_set = set() + if data_entry.has_cdr(): + out["cdr"] = torch.tensor( + data_entry.get_cdr(chain_set, encode=True) + ) + chain_type_dict = data_entry.get_chain_type_dict(chain_set) + out["chain_type_dict"] = chain_type_dict + if "heavy" in chain_type_dict: + cdr_chain_set.update( + [ + f"{chain_type_dict['heavy']}__{cdr}" + for cdr in ["H1", "H2", "H3"] + ] + ) + if "light" in chain_type_dict: + cdr_chain_set.update( + [ + f"{chain_type_dict['light']}__{cdr}" + for cdr in ["L1", "L2", "L3"] + ] + ) + output_names.append( + ( + os.path.basename(no_extension_name), + output_file, + chain_set if len(cdr_chain_set) == 0 else cdr_chain_set, + ) + ) if pass_set: continue - out = {} + if self.interpolate != "none": + data_entry.interpolate_coords(fill_ends=(self.interpolate == "all")) + out["pdb_id"] = no_extension_name.split("-")[0] out["mask_original"] = torch.tensor( data_entry.get_mask(chain_set, original=True) @@ -767,39 +800,12 @@ def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None out["ligand_smiles"], out["ligand_chains"], ) = data_entry.get_ligand_features(ligands, chain_set) - cdr_chain_set = set() - if data_entry.has_cdr(): - out["cdr"] = torch.tensor(data_entry.get_cdr(chain_set, encode=True)) - chain_type_dict = data_entry.get_chain_type_dict(chain_set) - out["chain_type_dict"] = chain_type_dict - if "heavy" in chain_type_dict: - cdr_chain_set.update( - [ - f"{chain_type_dict['heavy']}__{cdr}" - for cdr in ["H1", "H2", "H3"] - ] - ) - if "light" in chain_type_dict: - cdr_chain_set.update( - [ - f"{chain_type_dict['light']}__{cdr}" - for cdr in ["L1", "L2", "L3"] - ] - ) for name in self.feature_types: if name not in self.feature_functions: continue func = self.feature_functions[name] out[name] = torch.tensor(func(data_entry, chain_set)) - if add_name: - output_names.append( - ( - os.path.basename(no_extension_name), - output_file, - chain_set if len(cdr_chain_set) == 0 else cdr_chain_set, - ) - ) with open(output_file, "wb") as f: pickle.dump(out, f)