diff --git a/proteinflow/__init__.py b/proteinflow/__init__.py
index fc69cb8..58575c9 100644
--- a/proteinflow/__init__.py
+++ b/proteinflow/__init__.py
@@ -444,6 +444,7 @@ def split_data(
     random_seed=42,
     exclude_chains_without_ligands=False,
     tanimoto_clustering=False,
+    foldseek=False,
 ):
     """Split `proteinflow` entry files into training, test and validation.
@@ -496,6 +497,8 @@ def split_data(
         if `True`, exclude biounits that don't contain ligands
     tanimoto_clustering: bool, default False
         cluster chains based on the tanimoto similarity of their ligands
+    foldseek: bool, default False
+        if `True`, use FoldSeek to cluster chains based on their structure similarity
 
     Returns
     -------
@@ -545,6 +548,7 @@ def split_data(
         out_split_dict_folder=out_split_dict_folder,
         min_seq_id=min_seq_id,
         tanimoto_clustering=tanimoto_clustering,
+        foldseek=foldseek,
     )
     shutil.rmtree(temp_folder)
diff --git a/proteinflow/cli.py b/proteinflow/cli.py
index 2648397..c92468a 100644
--- a/proteinflow/cli.py
+++ b/proteinflow/cli.py
@@ -281,6 +281,11 @@ def generate(**kwargs):
     is_flag=True,
     help="Whether to use Tanimoto Clustering instead of MMSeqs2. Only works if the dataset contains ligands",
 )
+@click.option(
+    "--foldseek",
+    is_flag=True,
+    help="Whether to use FoldSeek to cluster the dataset",
+)
 @click.option(
     "--random_seed",
     default=42,
diff --git a/proteinflow/split/__init__.py b/proteinflow/split/__init__.py
index 737a42f..ef48c10 100644
--- a/proteinflow/split/__init__.py
+++ b/proteinflow/split/__init__.py
@@ -18,7 +18,7 @@ import numpy as np
 from tqdm import tqdm
 
-from proteinflow.data import PDBEntry
+from proteinflow.data import PDBEntry, ProteinEntry
 from proteinflow.ligand import (
     _load_smiles,
     _merge_chains_ligands,
@@ -71,9 +71,34 @@ def _run_mmseqs2(fasta_file, tmp_folder, min_seq_id, cdr=None):
     subprocess.run(args)
 
 
-def _read_clusters(tmp_folder, cdr=None):
+def _run_foldseek(data_folder, tmp_folder, min_seq_id):
+    """Run the FoldSeek command with the parameters we want.
+
+    Results are stored in the tmp_folder/MMSeqs2_results directory.
+    """
-    Read the output from MMSeqs2 and produces 2 dictionaries that store the clusters information.
+    folder = "MMSeqs2_results"
+    os.makedirs(os.path.join(tmp_folder, folder), exist_ok=True)
+    method = "easy-cluster"
+    args = [
+        "foldseek",
+        method,
+        data_folder,
+        os.path.join(tmp_folder, folder, "clusterRes"),
+        os.path.join(tmp_folder, folder, "tmp"),
+        "--min-seq-id",
+        str(min_seq_id),
+        "--chain-name-mode",
+        "1",
+        "-v",
+        "1",
+    ]
+    subprocess.run(args)
+    subprocess.run(["rm", "-r", os.path.join(tmp_folder, folder, "tmp")])
+
+
+def _read_clusters(tmp_folder, cdr=None):
+    """Read the output from MMSeqs2 and produce 2 dictionaries that store the clusters information.
 
     In cluster_dict, values are the full names (pdb + chains) whereas in cluster_pdb_dict, values are just the PDB ids (so less clusters but bigger).
""" @@ -95,11 +120,19 @@ def _read_clusters(tmp_folder, cdr=None): if line[0] == ">" and found_header: cluster_name = line[1:-1] sequence_name = line[1:-1] + cluster_name = "".join(cluster_name.split(".pdb")) + sequence_name = "".join(sequence_name.split(".pdb")) + if "-" in cluster_name: + cluster_name = cluster_name[:4] + cluster_name[6:] + sequence_name = sequence_name[:4] + sequence_name[6:] if cdr is not None: cluster_name += "__" + cdr elif line[0] == ">": sequence_name = line[1:-1] + sequence_name = "".join(sequence_name.split(".pdb")) + if "-" in sequence_name: + sequence_name = sequence_name[:4] + sequence_name[6:] found_header = True else: @@ -109,15 +142,17 @@ def _read_clusters(tmp_folder, cdr=None): for k in cluster_pdb_dict.keys(): cluster_pdb_dict[k] = np.unique(cluster_pdb_dict[k]) + print(f"{cluster_dict=}") + print(f"{cluster_pdb_dict=}") return cluster_dict, cluster_pdb_dict def _make_graph(cluster_pdb_dict): - """ - Produce a graph that relates clusters together. + """Produce a graph that relates clusters together. Connections represent a PDB shared by 2 clusters. The more shared PDBs, the stronger the connection. + """ keys = list(cluster_pdb_dict.keys()) keys_mapping = {length: k for length, k in enumerate(keys)} @@ -195,10 +230,10 @@ def _divide_according_to_chains_interactions(pdb_seqs_dict, dataset_dir): def _find_chains_in_graph( graph, clusters_dict, biounit_chains_array, pdbs_array, chains_array ): - """ - Find all the biounit chains present in a given graph or subgraph. + """Find all the biounit chains present in a given graph or subgraph. Return a dictionary for which each key is a cluster name (merged chains name) and the values are all the biounit chains contained in this cluster. + """ res_dict = {} for k, node in enumerate(graph): @@ -224,11 +259,11 @@ def _find_chains_in_graph( def _find_repartition(chains_dict, homomers, heteromers): - """ - Return a dictionary similar to the one created by find_chains_in_graph, with an additional level of classification for single chains, homomers and heteromers. + """Return a dictionary similar to the one created by find_chains_in_graph, with an additional level of classification for single chains, homomers and heteromers. Dictionary structure : `{'single_chains' : {cluster_name : [biounit chains]}, 'homomers' : {cluster_name : [biounit chains]}, 'heteromers' : {cluster_name : [biounit chains]}}`. Additionally return the number of chains in each class (single chains, ...). + """ classes_dict = { "single_chains": defaultdict(lambda: []), @@ -272,12 +307,12 @@ def _find_subgraphs_infos( homomers, heteromers, ): - """ - Given a list of subgraphs, return a list of dictionaries and an array of sizes of the same length. + """Given a list of subgraphs, return a list of dictionaries and an array of sizes of the same length. Dictionaries are the `chains_dict` and `classes_dict` corresponding to each subgraph, returned by the `find_chains_in_graph`. and `find_repartition` functions respectively. The array of sizes is of shape (len(subgraph), 3). It gives the number of single chains, homomers and heteromers present in each subgraph. + """ size_array = np.zeros((len(subgraphs), 3)) dict_list = [] @@ -301,12 +336,12 @@ def _find_subgraphs_infos( def _construct_dataset(dict_list, size_array, indices): - """ - Get a supergraph containing all subgraphs indicated by `indices`. + """Get a supergraph containing all subgraphs indicated by `indices`. 
     Given the `dict_list` and `size_array` returned by `find_subgraphs_info`, return the 2 dictionaries (`chains_dict` and `classes_dict`).
     corresponding to the graph encompassing all the subgraphs indicated by indices.
     Additionally return the number of single chains, homomers and heteromers in this supergraph.
+
     """
     dataset_clusters_dict = {}
     dataset_classes_dict = {"single_chains": {}, "homomers": {}, "heteromers": {}}
@@ -340,10 +375,10 @@ def _remove_elements_from_dataset(
     size_array,
     tolerance=0.2,
 ):
-    """
-    Remove values from indices until we get the required (`size_obj`) number of chains in the class of interest (`chain_class`).
+    """Remove values from indices until we get the required (`size_obj`) number of chains in the class of interest (`chain_class`).
 
     Parameter `chain_class` corresponds to the single chain (0), homomer (1) or heteromer (2) class.
+
     """
     sizes = [s[chain_class] for s in size_array[indices]]
     sorted_sizes_indices = np.argsort(sizes)[::-1]
@@ -396,10 +431,10 @@ def _add_elements_to_dataset(
     size_array,
     tolerance=0.2,
 ):
-    """
-    Add values to indices until we get the required (`size_obj`) number of chains in the class of interest (`chain_class`).
+    """Add values to indices until we get the required (`size_obj`) number of chains in the class of interest (`chain_class`).
 
     Parameter `chain_class` corresponds to the single chain (0), homomer (1) or heteromer (2) class.
+
     """
     sizes = [s[chain_class] for s in size_array[remaining_indices]]
     sorted_sizes_indices = np.argsort(sizes)[::-1]
@@ -446,11 +481,11 @@ def _adjust_dataset(
     ht_available,
     tolerance=0.2,
 ):
-    """
-    If required, remove and add values in indices so that the number of chains in each class correspond to the required numbers within a tolerance.
+    """If required, remove and add values in indices so that the number of chains in each class corresponds to the required numbers within a tolerance.
 
     First remove and then add (if necessary, for each class separately).
     In the end, we might end up with more chains than desired in the first 2 classes but for a reasonable tolerance (~10-20 %), this should not happen.
+
     """
     if single_chains_size > (1 + tolerance) * n_single_chains and sc_available:
         (
@@ -582,12 +617,12 @@ def _fill_dataset(
     n_max_iter=100,
     tolerance=0.2,
 ):
-    """
-    Construct a dataset from subgraphs indicated by `indices`.
+    """Construct a dataset from subgraphs indicated by `indices`.
 
     Given a list of indices to choose from (`remaining_indices`), choose a list of subgraphs to construct a dataset containing the required number of.
     biounits for each class (single chains, ...) within a tolerance.
     Return the same outputs as the construct_dataset function, as long as the list of remaining indices after selection.
+
     """
     single_chains_size, homomers_size, heteromers_size = 0, 0, 0
     sc_available, hm_available, ht_available = _test_availability(
@@ -721,10 +756,10 @@ def _get_subgraph_files(
     chain_arr,
     files_arr,
 ):
-    """
-    Given a list of subgraphs, return a dictionary.
+    """Given a list of subgraphs, return a dictionary.
+
+    Of the form `{cluster: [(filename, chain__cdr)]}`.
 
-    Of the form {cluster: [(filename, chain__cdr)]}.
     """
     out = {}  # cluster: [(file, chain__cdr)]
     for subgraph in subgraphs:
@@ -746,10 +781,10 @@ def _split_subgraphs(
     num_clusters_test,
     tolerance,
 ):
-    """
-    Split the list of subgraphs into three sets (train, valid, test).
+    """Split the list of subgraphs into three sets (train, valid, test).
 
     According to the number of biounits in each subgraph.
+ """ for _ in range(50): indices = np.random.permutation(np.arange(1, len(lengths))) @@ -844,6 +879,7 @@ def _split_dataset_with_graphs( the list of all biounit chains (string names) that are in a homomeric state (in their biounit) heteromers : list the list of all biounit chains (string names) that are in a heteromeric state (in their biounit) + """ sample_cluster = list(clusters_dict.keys())[0] sabdab = "__" in sample_cluster @@ -1044,14 +1080,16 @@ def _build_dataset_partition( min_seq_id=0.3, sabdab=False, tanimoto_clustering=False, + foldseek=False, ): - """ - Build training, validation and test sets from a curated dataset of biounit, using MMSeqs2 for clustering. + """Build training, validation and test sets from a curated dataset of biounit, using MMSeqs2 for clustering. Parameters ---------- dataset_dir : str the path to the dataset + tmp_folder : str + the path to a temporary folder to store temporary files valid_split : float in [0, 1], default 0.05 the validation split ratio test_split : float in [0, 1], default 0.05 @@ -1062,6 +1100,8 @@ def _build_dataset_partition( whether the dataset is the SAbDab dataset or not tanimoto_clustering: bool, default False whether to cluster chains based on Tanimoto Clustering + foldseek: bool, default False + whether to cluster chains based on FoldSeek Output ------ @@ -1092,31 +1132,48 @@ def _build_dataset_partition( merged_seqs_dict, min_seq_id, tmp_folder ) else: - cdrs = ["L1", "L2", "L3", "H1", "H2", "H3"] if sabdab else [None] - for cdr in cdrs: - if cdr is not None: - print(f"Clustering with MMSeqs2 for CDR {cdr}...") - else: - print("Clustering with MMSeqs2...") - # retrieve all sequences and create a merged_seqs_dict - merged_seqs_dict = _load_pdbs( - dataset_dir, cdr=cdr - ) # keys: pdb_id, values: list of chains and sequences - lengths = [] - for k, v in merged_seqs_dict.items(): - lengths += [len(x[1]) for x in v] - merged_seqs_dict = _merge_chains( - merged_seqs_dict - ) # remove redundant chains - # write sequences to a fasta file for clustering with MMSeqs2, run MMSeqs2 and delete the fasta file - fasta_file = os.path.join(tmp_folder, "all_seqs.fasta") - _write_fasta( - fasta_file, merged_seqs_dict - ) # write all sequences from merged_seqs_dict to fasta file - _run_mmseqs2( - fasta_file, tmp_folder, min_seq_id, cdr=cdr - ) # run MMSeqs2 on fasta file - subprocess.run(["rm", fasta_file]) + if foldseek: + print("Clustering with FoldSeek...") + if os.path.exists(os.path.join(tmp_folder, "pdbs")): + subprocess.run(["rm", "-r", os.path.join(tmp_folder, "pdbs")]) + os.mkdir(os.path.join(tmp_folder, "pdbs")) + for file in tqdm(os.listdir(dataset_dir)): + if not file.endswith(".pickle"): + continue + ProteinEntry.from_pickle(os.path.join(dataset_dir, file)).to_pdb( + os.path.join(tmp_folder, "pdbs", file.split(".")[0] + ".pdb") + ) + _run_foldseek( + os.path.join(tmp_folder, "pdbs"), tmp_folder, min_seq_id=min_seq_id + ) + cdrs = [None] + merged_seqs_dict = _load_pdbs(dataset_dir, cdr=None) + else: + cdrs = ["L1", "L2", "L3", "H1", "H2", "H3"] if sabdab else [None] + for cdr in cdrs: + if cdr is not None: + print(f"Clustering with MMSeqs2 for CDR {cdr}...") + else: + print("Clustering with MMSeqs2...") + # retrieve all sequences and create a merged_seqs_dict + merged_seqs_dict = _load_pdbs( + dataset_dir, cdr=cdr + ) # keys: pdb_id, values: list of chains and sequences + lengths = [] + for k, v in merged_seqs_dict.items(): + lengths += [len(x[1]) for x in v] + merged_seqs_dict = _merge_chains( + merged_seqs_dict + ) # remove redundant 
+                # write sequences to a fasta file for clustering with MMSeqs2, run MMSeqs2 and delete the fasta file
+                fasta_file = os.path.join(tmp_folder, "all_seqs.fasta")
+                _write_fasta(
+                    fasta_file, merged_seqs_dict
+                )  # write all sequences from merged_seqs_dict to fasta file
+                _run_mmseqs2(
+                    fasta_file, tmp_folder, min_seq_id, cdr=cdr
+                )  # run MMSeqs2 on fasta file
+                subprocess.run(["rm", fasta_file])
 
     # retrieve MMSeqs2 clusters and build a graph with these clusters
     clusters_dict = {}
@@ -1174,6 +1231,7 @@ def _get_split_dictionaries(
     out_split_dict_folder="./data/dataset_splits_dict",
     min_seq_id=0.3,
     tanimoto_clustering=False,
+    foldseek=False,
 ):
     """Split preprocessed data into training, validation and test.
@@ -1193,6 +1251,10 @@
         The folder where the dictionaries containing the train/validation/test splits information will be saved"
     min_seq_id : float in [0, 1], default 0.3
         minimum sequence identity for `mmseqs`
+    tanimoto_clustering: bool, default False
+        whether to cluster chains based on Tanimoto Clustering
+    foldseek: bool, default False
+        whether to cluster chains based on FoldSeek
 
     """
     if len([x for x in os.listdir(output_folder) if x.endswith(".pickle")]) == 0:
@@ -1201,6 +1263,11 @@
     ind = sample_file.split(".")[0].split("-")[1]
     sabdab = not ind.isnumeric()
 
+    if sabdab and tanimoto_clustering:
+        raise RuntimeError("Tanimoto Clustering cannot be used with SAbDab data")
+    if sabdab and foldseek:
+        raise RuntimeError("FoldSeek cannot be used with SAbDab data")
+
     os.makedirs(out_split_dict_folder, exist_ok=True)
     (
         train_clusters_dict,
@@ -1218,6 +1285,7 @@
     ) = _build_dataset_partition(
         min_seq_id=min_seq_id,
         sabdab=sabdab,
         tanimoto_clustering=tanimoto_clustering,
+        foldseek=foldseek,
     )
     classes_dict = train_classes_dict
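
Usage sketch (not part of the diff): the snippet below shows how the new structure-based clustering could be requested from the Python API once these changes are applied. The dataset `tag` keyword and its value are assumptions based on the existing `split_data` calling convention; only the `foldseek=True` flag comes from this diff. The `foldseek` binary must be available on the PATH, and the option is rejected for SAbDab datasets.

    import proteinflow

    # cluster chains by structure similarity with FoldSeek instead of
    # MMSeqs2 sequence clustering, then build the train/valid/test split
    proteinflow.split_data(
        tag="20230102_stable",  # hypothetical dataset tag
        foldseek=True,
    )

On the command line, the same behaviour is exposed through the new `--foldseek` flag added in proteinflow/cli.py.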