diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..a9d6575
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..c067611
--- /dev/null
+++ b/README.md
@@ -0,0 +1,65 @@
+
Boltz-1:
+
+Democratizing Biomolecular Interaction Modeling
+
+
+
+
+Boltz-1 is an open-source model which predicts the 3D structure of proteins, rna, dna and small molecules; it handles modified residues, covalent ligands and glycans, as well as condition the generation on pocket residues.
+
+For more information about the model, see our [technical report](https://gcorso.github.io/assets/boltz1.pdf).
+
+## Installation
+Install boltz with PyPI (recommended):
+
+```
+pip install boltz
+```
+
+or directly from GitHub for daily updates:
+
+```
+git clone https://github.com/jwohlwend/boltz.git
+cd boltz; pip install -e .
+```
+> Note: we recommend installing boltz in a fresh python environment
+
+## Inference
+
+You can run inference using Boltz-1 with:
+
+```
+boltz predict input_path
+```
+
+Boltz currently accepts three input formats:
+
+1. Fasta file, for most use cases
+
+2. A comprehensive YAML schema, for more complex use cases
+
+3. A directory containing files of the above formats, for batched processing
+
+To see all available options: `boltz predict --help` and for more informaton on these input formats, see our [prediction instructions](docs/prediction.md).
+
+## Training
+
+If you're interested in retraining the model, see our [training instructions](docs/training.md).
+
+## Contributing
+
+We welcome external contributions and are eager to engage with the community. Connect with us on our [Slack channel](https://boltz-community.slack.com/archives/C0818M6DWH2) to discuss advancements, share insights, and foster collaboration around Boltz-1.
+
+## Coming very soon
+
+- [ ] Pocket conditioning support
+- [ ] More examples
+- [ ] Full data processing pipeline
+- [ ] Colab notebook for inference
+- [ ] Confidence model checkpoint
+- [ ] Support for custom paired MSA
+- [ ] Kernel integration
+
+## License
+
+Our model and code are released under MIT License, and can be freely used for both academic and commercial purposes.
diff --git a/docs/boltz1_pred_figure.png b/docs/boltz1_pred_figure.png
new file mode 100644
index 0000000..5e72667
Binary files /dev/null and b/docs/boltz1_pred_figure.png differ
diff --git a/docs/prediction.md b/docs/prediction.md
new file mode 100644
index 0000000..8026471
--- /dev/null
+++ b/docs/prediction.md
@@ -0,0 +1,140 @@
+# Prediction
+
+Once you have installed `boltz`, you can start making predictions by simply running:
+
+`boltz predict `
+
+where `` is a path to the input file or a directory. The input file can either be in fasta (enough for most use cases) or YAML format (for more complex inputs). If you specify a directory, `boltz` will run predictions on each `.yaml` or `.fasta` file in the directory.
+
+Before diving into more details about the input formats, here are the key differences in what they each support:
+
+| Feature | Fasta | YAML |
+| -------- |--------------------| ------- |
+| Polymers | :white_check_mark: | :white_check_mark: |
+| Smiles | :white_check_mark: | :white_check_mark: |
+| CCD code | :white_check_mark: | :white_check_mark: |
+| Custom MSA | :white_check_mark: | :white_check_mark: |
+| Modified Residues | :x: | :white_check_mark: |
+| Covalent bonds | :x: | :white_check_mark: |
+| Pocket conditioning | :x: | :white_check_mark: |
+
+
+
+## Fasta format
+
+The fasta format should contain entries as follows:
+
+```
+>CHAIN_ID|ENTITY_TYPE|MSA_PATH
+SEQUENCE
+```
+
+Where `CHAIN_ID` is a unique identifier for each input chain, `ENTITY_TYPE` can be one of `protein`, `dna`, `rna`, `smiles`, `ccd` and `MSA_PATH` is only specified for protein entities and is the path to the `.a3m` file containing a computed MSA for the sequence of the protein. Note that we support both smiles and CCD code for ligands.
+
+For each of these cases, the corresponding `SEQUENCE` will contain an amino acid sequence (e.g. `EFKEAFSLF`), a sequence of nucleotide bases (e.g. `ATCG`), a smiles string (e.g. `CC1=CC=CC=C1`), or a CCD code (e.g. `ATP`), depending on the entity.
+
+As an example:
+
+```yaml
+>A|protein|./examples/msa/seq1.a3m
+MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+>B|protein|./examples/msa/seq1.a3m
+MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+>C|ccd
+SAH
+>D|ccd
+SAH
+>E|smiles
+N[C@@H](Cc1ccc(O)cc1)C(=O)O
+>F|smiles
+N[C@@H](Cc1ccc(O)cc1)C(=O)O
+```
+
+
+## YAML format
+
+The YAML format is more flexible and allows for more complex inputs, particularly around covalent bonds. The schema of the YAML is the following:
+
+```yaml
+sequences:
+ - ENTITY_TYPE:
+ id: CHAIN_ID
+ sequence: SEQUENCE # only for protein, dna, rna
+ smiles: SMILES # only for ligand, exclusive with ccd
+ ccd: CCD # only for ligand, exclusive with smiles
+ msa: MSA_PATH # only for protein
+ modifications:
+ - position: RES_IDX # index of residue, starting from 1
+ ccd: CCD # CCD code of the modified residue
+
+ - ENTITY_TYPE:
+ id: [CHAIN_ID, CHAIN_ID] # multiple ids in case of multiple identical entities
+ ...
+constraints:
+ - bond:
+ atom1: [CHAIN_ID, RES_IDX, ATOM_NAME]
+ atom2: [CHAIN_ID, RES_IDX, ATOM_NAME]
+ - pocket:
+ binder: CHAIN_ID
+ contacts: [[CHAIN_ID, RES_IDX], [CHAIN_ID, RES_IDX]]
+```
+`sequences` has one entry for every unique chain/molecule in the input. Each polymer entity as a `ENTITY_TYPE` either `protein`, `dna` or`rna` and have a `sequence` attribute. Non-polymer entities are indicated by `ENTITY_TYPE` equal to `ligand` and have a `smiles` or `ccd` attribute. `CHAIN_ID` is the unique identifier for each chain/molecule, and it should be set as a list in case of multiple identical entities in the structure. Protein entities should also contain an `msa` attribute with `MSA_PATH` indicating the path to the `.a3m` file containing a computed MSA for the sequence of the protein.
+
+The `modifications` field is an optional field that allows you to specify modified residues in the polymer (`protein`, `dna` or`rna`). The `position` field specifies the index (starting from 1) of the residue, and `ccd` is the CCD code of the modified residue. This field is currently only supported for CCD ligands.
+
+`constraints` is an optional field that allows you to specify additional information about the input structure. Currently, we support just `bond`. The `bond` constraint specifies a covalent bonds between two atoms (`atom1` and `atom2`). It is currently only supported for CCD ligands and canonical residues, `CHAIN_ID` refers to the id of the residue set above, `RES_IDX` is the index (starting from 1) of the residue (1 for ligands), and `ATOM_NAME` is the standardized atom name (can be verified in CIF file of that component on the RCSB website).
+
+As an example:
+
+```yaml
+version: 1
+sequences:
+ - protein:
+ id: [A, B]
+ sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+ msa: ./examples/msa/seq1.a3m
+ - ligand:
+ id: [C, D]
+ ccd: SAH
+ - ligand:
+ id: [E, F]
+ smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O
+```
+
+
+## Options
+
+The following options are available for the `predict` command:
+
+ boltz predict [OPTIONS] input_path
+
+| **Option** | **Type** | **Default** | **Description** |
+|-----------------------------|-----------------|--------------------|---------------------------------------------------------------------------------|
+| `--out_dir PATH` | `PATH` | `./` | The path where to save the predictions. |
+| `--cache PATH` | `PATH` | `~/.boltz` | The directory where to download the data and model. |
+| `--checkpoint PATH` | `PATH` | None | An optional checkpoint. Uses the provided Boltz-1 model by default. |
+| `--devices INTEGER` | `INTEGER` | `1` | The number of devices to use for prediction. |
+| `--accelerator` | `[gpu,cpu,tpu]` | `gpu` | The accelerator to use for prediction. |
+| `--recycling_steps INTEGER` | `INTEGER` | `3` | The number of recycling steps to use for prediction. |
+| `--sampling_steps INTEGER` | `INTEGER` | `200` | The number of sampling steps to use for prediction. |
+| `--diffusion_samples INTEGER` | `INTEGER` | `1` | The number of diffusion samples to use for prediction. |
+| `--output_format` | `[pdb,mmcif]` | `mmcif` | The output format to use for the predictions. |
+| `--num_workers INTEGER` | `INTEGER` | `2` | The number of dataloader workers to use for prediction. |
+| `--override` | `FLAG` | `False` | Whether to override existing predictions if found. |
+
+## Output
+
+After running the model, the generated outputs are organized into the output directory following the structure below:
+```
+out_dir/
+├── lightning_logs/ # Logs generated during training or evaluation
+├── predictions/ # Contains the model's predictions
+ ├── [input_file1]/
+ ├── [input_file1]_model_0.cif # The predicted structure in CIF format
+ ...
+ └── [input_file1]_model_[diffusion_samples-1].cif # The predicted structure in CIF format
+ └── [input_file2]/
+ ...
+└── processed/ # Processed data used during execution
+```
+The `predictions` folder contains a unique folder for each input file. The input folders contain diffusion_samples predictions saved in the output_format. The `processed` folder contains the processed input files that are used by the model during inference.
diff --git a/docs/training.md b/docs/training.md
new file mode 100644
index 0000000..14faeb6
--- /dev/null
+++ b/docs/training.md
@@ -0,0 +1,47 @@
+# Training
+
+## Download processed data
+
+Instructions on how to download the processed dataset for training are coming soon, we are currently uploading the data to sharable storage and will update this page when ready.
+
+## Modify the configuration file
+
+The training script requires a configuration file to run. This file specifies the paths to the data, the output directory, and other parameters of the data, model and training process.
+
+We provide under `scripts/train/configs` a template configuration file analogous to the one we used for training the structure model (`structure.yaml`) and the confidence model (`confidence.yaml`).
+
+The following are the main parameters that you should modify in the configuration file to get the structure model to train:
+
+```yaml
+trainer:
+ devices: 1
+
+output: SET_PATH_HERE # Path to the output directory
+resume: PATH_TO_CHECKPOINT_FILE # Path to a checkpoint file to resume training from if any null otherwise
+
+data:
+ datasets:
+ - _target_: boltz.data.module.training.DatasetConfig
+ target_dir: PATH_TO_TARGETS_DIR # Path to the directory containing the processed structure files
+ msa_dir: PATH_TO_MSA_DIR # Path to the directory containing the processed MSA files
+
+ symmetries: PATH_TO_SYMMETRY_FILE # Path to the file containing molecule the symmetry information
+ max_tokens: 512 # Maximum number of tokens in the input sequence
+ max_atoms: 4608 # Maximum number of atoms in the input structure
+```
+
+`max_tokens` and `max_atoms` are the maximum number of tokens and atoms in the crop. Depending on the size of the GPUs you are using (as well as the training speed desired), you may want to adjust these values. Other recommended values are 256 and 2304, or 384 and 3456 respectively.
+
+## Run the training script
+
+Before running the full training, we recommend using the debug flag. This turns off DDP (sets single device) and set `num_workers` to 0 so everything is in a single process, as well as disabling wandb:
+
+ python scripts/train/train.py scripts/train/configs/structure.yaml debug=1
+
+Once that seems to run okay, you can kill it and launch the training run:
+
+ python scripts/train/train.py scripts/train/configs/structure.yaml
+
+We also provide a different configuration file to train the confidence model:
+
+ python scripts/train/train.py scripts/train/configs/confidence.yaml
\ No newline at end of file
diff --git a/examples/ligand.fasta b/examples/ligand.fasta
new file mode 100644
index 0000000..ea964b7
--- /dev/null
+++ b/examples/ligand.fasta
@@ -0,0 +1,12 @@
+>A|protein|./examples/msa/seq1.a3m
+MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+>B|protein|./examples/msa/seq1.a3m
+MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+>C|ccd
+SAH
+>D|ccd
+SAH
+>E|smiles
+N[C@@H](Cc1ccc(O)cc1)C(=O)O
+>F|smiles
+N[C@@H](Cc1ccc(O)cc1)C(=O)O
\ No newline at end of file
diff --git a/examples/ligand.yaml b/examples/ligand.yaml
new file mode 100644
index 0000000..a5f04f0
--- /dev/null
+++ b/examples/ligand.yaml
@@ -0,0 +1,12 @@
+version: 1 # Optional, defaults to 1
+sequences:
+ - protein:
+ id: [A, B]
+ sequence: MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+ msa: ./examples/msa/seq1.a3m
+ - ligand:
+ id: [C, D]
+ ccd: SAH
+ - ligand:
+ id: [E, F]
+ smiles: N[C@@H](Cc1ccc(O)cc1)C(=O)O
diff --git a/examples/msa/seq1.a3m b/examples/msa/seq1.a3m
new file mode 100644
index 0000000..55a5dce
--- /dev/null
+++ b/examples/msa/seq1.a3m
@@ -0,0 +1,498 @@
+>101
+MVTPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQASENLYFQ
+>UniRef100_A0A0D4WTP2 338 1.00 7.965E-99 2 375 384 1 374 375
+--TPEGNVSLVDESLLVGVTDEDRAVRSAHQFYERLIGLWAPAVMEAAHELGVFAALAEAPADSGELARRLDCDARAMRVLLDALYAYDVIDRIHDTNGFRYLLSAEARECLLPGTLFSLVGKFMHDINVAWPAWRNLAEVVRHGARDTSGAESPNGIAQEDYESLVGGINFWAPPIVTTLSRKLRASGRSGDATASVLDVGCGTGLYSQLLLREFPRWTATGLDVERIATLANAQALRLGVEERFATRAGDFWRGGWGTGYDLVLFANIFHLQTPASAVRLMRHAAACLAPDGLVAVVDQIVDADREPKTPQDRFALLFAASMTNTGGGDAYTFQEYEEWFTAAGLQRIETLDTPMHRILLARRATEPSAVPEGQ--------
+>UniRef100_UPI00199CC1A7 304 0.851 3.950E-87 8 371 384 1 364 369
+--------SLVDESLLAGVTDEDRTVRSAHQFYERLIGLWAPAVMEAAHELGVFVALAEEPVGSAEMARRLDCDPRAMRVLLDALYAYDVIGRIHDTNGFRYVMSPEAQECLLPGRLFSLVGKLAHDIDVAWPAWRNLASVVRHGARDTTGTDSPNGIAEEDYESLVGGINFWAPPIVAALTRKLHALGRSGESAASILDVGCGTGLYSQLLLREFPEWTATGLDVERIAALASAQSLRLGVAERFGTGVGDFWKGDWGTGYDIVLFVNIFHLQTPASAARLMRNAAASLAPDGLVAVVDQIIDADREPKTPQDRFALLFAASMTNTGGGDTYTFQEYEEWFTAAGLQRVETLDTPMHRILLARRVTETPAA------------
+>UniRef100_UPI001674A1EF 256 0.569 1.909E-70 8 367 384 30 389 395
+--------ELVDYNYLSTTTGERKTIRAAHQLYEHLISLWAPAIIETAHDLGVFARLAKGPGTVAELATDLDTDQRATRVLLGGLVAYGVLERSDDDGESRYVLPEEFRHALLPGGTFSLVGKMAHDRHVAWAAWRNLGDAVRHGTRDRSGNDRTNQISETNYEDLTFGINFWAPPIVDVLSSFLAESGWKKDQAVSVLDVGCGTGLYSQLLLERFPSWTAEGIDAPRIIPLATRQAEKIGVGARFTGTVRDFWQHGWGEIVDLILFANIFHLQTDDSVQKLMRSAADVLAPDGLICIADQIVVDEARPTTAQDRFALLFAASMLATGGGDAYALSTYDQWLAEAGLERVAVLEAPMHRLLLVGHAGR----------------
+>UniRef100_A0A229H607 251 0.571 5.884E-69 11 369 384 1 362 363
+-----------DFNFLSATTGEQKTVRAAHQIYEHLIGLWAPAVIEAAHDLGVFSWLANRPGTVEEMSAELETDQRATRVLLGGLLAYGVIERSELDGEVRYSLPSEFRQALLPGGTFSLVGKMLHDRHVAWAGWRNLGDAVRHGTRDQSGNDRTNQISEADYEDLTSGINFWAPPIVDVLSAALAETGWKKDEAVSVLDVGCGTGLYSQLLLERFPAWKAEGIDAPRIIRLATAQAERLGVGSRFTGTVRDFWKDGWGETVDLILFANIFHLQTPDSVQKLMRSAADVLAPDGLICIADQIVVDEARPVTAQDRFAMLFAASMLATGGGDAYTLSAYDQWLAEAGLERVAVLEAPMHRLLLighAGRHPLPA--------------
+>UniRef100_A0A6I4VXR1 248 0.510 7.112E-68 19 365 384 24 367 381
+-------------------TPAARQSATAHRVYEALIAMWSTGVIEAGHDLGLFERLATGPATVPELAADLGADPRATRVLCDALVVYGVLER---GDHGRFAMPADIAACLLADGLYSLAGKIFYDRTVAWDAWRGLADAVRRGPVDAHGDDQANQISDVDYEQLTGGINFWAPPIAELLAGWLRDAGWDARPGRTVLDVGCGTGIYSHLLLQAFPGATSTGLEAARIVPIADRQAGLLGVADRFTATACDFMSDPWPSGVDLALFVNIFHLQHPAAARRLLARSAAALAPDGVLCVVDHIVDREGPLDSPQDRFALLFAASMLATGGGGAHALADYDLWLAGAGLRRVALLDAPMHRVLFAARA------------------
+>UniRef100_UPI001678806F 245 0.509 6.291E-67 8 363 384 2 358 361
+--------RLVDSTKLIGDPRDSAVVRASHRVYEHLVAMFAPGLIEAAFDLGAFVALADGPAGAAELAERLDADPLGVRVLLDGLSCYEIVYRESAPEGGhRYRLADGMAECLLPGGLYSLAGRIRYDRAIGWDAWRDLAQHVRHPARDDSGAYRANQLSAEDYESVARGINFWAPPIVEALAGLLTDTGWKEETPRSMLDVGCGTGIYSQLLLQRFKELTATGLDDPRIVPIAEEQAQRLNVGARFSPVSQDFFQQPWPGGQDLVLLVNIFHLQTADGAQELMHRARQAVREDGVVAIVDHIVDDDSEPHSPHNRFFRLFSASMLVTGGGDSFSLAEYDQWLERAGLCRTALVDTPMHRILLAR--------------------
+>UniRef100_UPI0021B13CE5 245 0.569 6.291E-67 8 367 384 30 389 395
+--------SLVDYNYLSATTGERKTIRAAHQLYEHLISLWAPAIIETAHDLGVFAWLAERSGTAEQLADGLKTDRRATRVLLDGLFAYGVLERSAAGGEVRYTLPEDFRHALLPGGTFSLVGKMAHDRHVAWAAWNNLGDAVRHGTRDQSGNDRTNQISETDYEDLTSGINFWAPPIVDVLASYLAESGWKQNETASVLDVGCGTGLYSQLLLERFPSWTAEGIDAPRIIRLADRQAERLGVADRFTGTVRDFWQHGWGEIVDLILFANIFHLQTADSVQKLMRSAADVLAADGLICIADQIVVDEAHPTTAQDRFALLFAASMLATGGGDAYALSEYDQWLAEAGLERVAVLEAPMHRLLLVGHAGR----------------
+>UniRef100_A0A640S974 244 0.654 1.173E-66 9 364 384 6 359 361
+---------LVDTSLLPSATHEEKVIRTAHAFYEHLIGLWAPAIIEAAHETGIFAALADRPVTADDLAASLHADPRTTRVLLDALYAYDVIDRIRSTDSFLYVLSDAARECLLPGGVFSIAGKMVHDRRVAWSAWANLGEVVRQGTR--TGTENDNQISERDYESLVGGINFWAPPIVDVLTDELRRRGADGGTPATVLDVGCGTGLYSQLLLRAFPAWCAMGLDAERIAPLAAAQGQRLGVADRFLVRSGDFWTEDWGTGHDHLLFANIFHLLTPASGQRLMDLAARSVSATGTVVVIDQILDAEREAKTPQDRFALLFAASMANTGGGDAYTFQDYDDWFAGAGMKRVATLDAPMHRILLAQR-------------------
+>UniRef100_F2YRZ2 244 0.563 2.186E-66 27 362 384 2 338 342
+---------------------------SAHRIYEHLISLWAPGVIEAAHDLGVFAELSSGPRTSDQLAESCDSNQRAMRVLLDGLFAYDILDRIPsDSGPTVYRMPDEMRECLLPGGLFSLVGKIEYDRQLAWHSWRNLADAVRTDNRDASGDLQLNQITEHNYESLVRGINFWAPPIVAALADGFETFEWPTDRPASVLDIGCGTGLYSQLLLERFPKWRATGLEAPHIAPIAEAQAQRLGVADRFDVQVRDFWTESWGSDHDLLIFVNIFHLQTPESSLELLRKSKESLADGGLICIADHMVTDEQEAKPVQDRFAMLFAASMLATGGGDAFLLDEYDTWLAEAGLRRVAVLDTPMHRILLA---------------------
+>UniRef100_UPI000A38C1E5 242 0.540 7.592E-66 9 364 384 17 370 371
+---------LVDTSLLPADDDGARAT---HRVYEHLIGMWAPGVIEAAQDLGVFATLTDGPATAAGLAETLGTDLRATRVLLDGLHAYDIVGRERGGDGqAVYTLPASLRGVFAPDGLYSLAGKITHDRNVAWQAWRHLADNVRGGARGEHGGQQVNQISEEDYTSLARGINFWAPPVVSVLADALRERGWGDDTEAVLLDVGCGTGIYSHLLLQAFPQLTARGLDAPRITAIAAEQAERLGVSERFSPLTADFWNDDWGNGTDLALFVNIFHLQTPESAHELLLKTAKGLGEGGLIAIVDHIVDEEAGSGNVQNRFFRLFAASMLATGGGDAYTVHDYDQWLADAGLRRVALLDTPMHRVLLAGR-------------------
+>UniRef100_A0A7C3EEX2 241 0.311 1.415E-65 21 365 384 281 612 615
+---------------------EGRAAADAGRLMELAWGYAAPVVIATAVRYGLFGSIGHRGASIEELVRRTGLSERGLRILLQALVGLRLLRR----NGSRFELTPESATCLVPEQPEYRGGLFLHHVEHLLPRWLQLPEVVRTGWPVREPQCPA-----HRYAGFVESLFASNYPAAKALQRHLQLAGRK--EPFQVLDLGAGSGVWGIALAEGAPQVWVTAVDWPEVLLIARKKAAAYGVSDRFRWVEGSFFEVPLGRGYDLVVLGHVLHAEGVEGVRTLLRRSCDALRPGGLVAIQEFLPDDDRSGP----LLPLLFAVNMLVnTEAGDTYTLAELTGWLEEAGFEAVETLNVPaPSPMVLARKP------------------
+>UniRef100_UPI00167C15C9 241 0.504 1.931E-65 9 365 384 30 388 389
+---------LVDGSKLIGDPRDSAVVRASHRVYEHLIAMFAPGMIEAAFDLGVFVALADGPATPTDLAARLDADAHGLRVLLDGLYCYEIVQRVRAEDGEdLYRLVDGMAECLLPGGLYSLAGRIGYDRAIGWDVWRNLADHVKRPARGADGGYQANQLSAEDYEQVARGINFWAPPIVESLANLLaEEEGWKGEADRSMLDVGCGTGIYSQLLLQRFRGLRATGLDHPRIVPIARGQAERLGVRERFEPVERDFFTEEWNTGQDLVLLVNIFHLQTAEGSEELMRRAAQAVRQGGVVAIVDHIVDDETDDQSIQNRFFRLFAASMLVTGGGDSFSLADYDQWLDRAGLVRTALVDTPMHRILLARRP------------------
+>UniRef100_UPI001BA9CE47 241 0.559 2.636E-65 5 364 384 20 382 384
+-----GNVQpLVDTALLPAGDGQPRVVRAAHRLYEHLISLWAPGAIEAAFDLGVFDELAKGPATADELAKSLSTNAKATRVLLDGLNAYDLLVRTWDADGtVVYVLPDEVRDVLRPDGLFSLAGKIGYDRQMAWGAWRNLAETVRTGALADDGSQQHNQISASEYESLVRGINFWAPPVVHALSAQLKEHGWAGDKTAGMLDVGCGTGIYSQLLLQQFAGLTATGLDVERILPLAIAQSEQLDVADRFHPLRRDFWREDWGTGFDLVLFVNIFHLQTPEDARDLAIKANKALADGGLVAIVDqIVVEDEVQQNSTQNRFFRLFAASMMATGGGDAYTLDQYDEWLTGAGLRRVALIDTPMHRILLAGR-------------------
+>UniRef100_A0A1V2QL50 239 0.561 6.705E-65 9 365 384 18 375 376
+---------LVDTALLPGHGLQHDVVTAAHRVYEHLIAIWAPGVIEAAHDLGVFVELSAGPATAERLAERLDTEPRATRVLMDALYAYDIVERTTEaSAPPSYRLPAAMRECLLPGGMFSLVGKIAYDRRLAWRAWQDFAGAVRRGSRDGSGSDQLNQISVDEYESLVSGINFWAPPVVQVLRQGLRDLAWPCDRAVRMVDVGCGTGLYGQLLLREFPQWTAVGLDVARIAPLATSQAAELGVAARFEATVCDFWQDSWGQDVDLILLANIFHLQTPESAETLVRLAAEALAEDGMLCIVDHVVDDERTAKSAQDRFALLFAASMLATGGGDAYTLKDYDDWFVRYGLRRERILETPMHRILLVTRA------------------
+>UniRef100_A0A2M9IGJ8 239 0.552 6.705E-65 3 362 384 18 379 383
+---PELNVRpLVDTTLLPDWRGSGRVVHSAHRVYEHLISLWAPGVIEAAHDLGVFAELSAGPRTSDQLARACAANQRAMRVLMDGLYAYDIVDRVPTEDGpAVYRMPEEMRECLLPDGLFSLVGKIEYDRQLAWHSWRNLADAVRGDNRDETGGLQLNQISEHNYESLVRGINFWAPPIVEALRGGFETLEWPTDRPASVLDIGCGTGLYSQLLLRAFPRWRATGLEAPAIAPIATAQAERLGVADRFGVQVRDFWTESWGTGHDLLVFVNIFHLQTPESAQELLRKSKEALSRDGLICIADHLVTDEKDAKSVQDRFAMLFAASMLATGGGDAFLLDDYDQWLASTGLRRVAVLDTPMHRILLA---------------------
+>UniRef100_A0A0B5DK60 238 0.563 1.705E-64 27 362 384 43 379 383
+---------------------------SAHRIYEHLISLWAPGVIEAAHDLGVFAELSTGPKTSDQLATACDAEQRAMRVLMDGLYAYDIVDRIPaDAGPALYRMSEEMHECLLPGGLFSLVGKIEYDRQLAWHSWRNLADAVRRDNRDETGSLQLNQITEHNYESLVRGINFWAPPIVEALRGGFETLEWPTDRPASVLDIGCGTGLYSQLLLRAFPGWRATGLEAPNIAPIARAQAERLGVADRFDVQVRDFWTESWGSDHDLLVFVNIFHLQTPESAQELLRRSKEALSKDGLVCIADHLVTDEKDAKSIQDRFAMLFAASMLATGGGDAFLLDDYDRWLASAGLRRVAVLDTPMHRILLA---------------------
+>UniRef100_UPI00055F6ABE 235 0.547 2.803E-63 9 365 384 14 368 369
+---------LVDTELLPSPTGE---IRAAHRLYEHLIGMWATGVIEAAQDLGAFAALTVAPATASGLSELLGTDLRATRVLLDGLYAYDVVERSRGADGqAVYTLPAELHQVFAPDGLYSLAGKIGHDRNVAWHAWRNLADAVRSGARGEDGAQQLNQISESDYTSLVRGINFWAPPITSALADGLRELGWTDGEAATLLDVGCGTGIYSHLLLDEFPGLQARGLDAERIIPIATEQAARLGVADRFDPVVCDFWNDDWGTGVDLALFVNIFHLQTPESARELLLKSAKSLSEDGVIAIADHIVDEDGGVGSTQNRFFRLFAASMLATGGGDSFTVQDYDQWLADAGLRRVALIDTPMHRVLLARRA------------------
+>UniRef100_A0A3E0GTP3 234 0.521 3.825E-63 9 364 384 6 362 364
+---------LVDTGLLPADGANSDVVMAAHRVYEHLIALWAPGVIEAAHDLGVFDALGTAPARADELAEQLGTDTKATGVLLEALYAYEIVAREVADDGvVGYTLAPAMAEVLSPTGLFSLTGKIGYDRKLAWDAWRGLADAVRSGRYDASGSEQGNRISEYEYESLVTGINFWAPPIVRELGRALRELGWPTTESARMLDIGCGSGLYSHLLLQEFPGLSAVGIDVELILKIAVEQSLRLGVADRFATFDGDFTSDDLGRDFDLVLLVNIFHLQSGDSAGLLAKRVASALGDNGIVAIVDQIIDDRQGPRSTHNRFFRLFATSMLATGGGGAYTVDDYDAWLESAGLHRIALVDTPMHRVLLAKR-------------------
+>UniRef100_A0A918C4G1 234 0.536 3.825E-63 25 364 384 30 370 371
+-------------------------ARATHRVYEHLIGMWAPGVIEAAQDLGVFATLTAGPATAAGLAETLGTDLRATRVLLDGLHAYDIVQRERGGDGqAVYTLPASLHGVFAPDGLFSLAGKITHDRNVAWHAWRHLADNVRSGARSAHGGQQVNQISEEDYTALARGINFWAPPVVSVLADALRERGWGDETDALLLDVGCGTGIYSHLMLEAFPRLTARGLDAPRITAIASEQAARLGVHDRFEPLTADFWNDDWGNGTDLALFVNIFHLQTPESAHELLLKTAKGLTEGGLIAIVDHIVDEEAGGANVQNRFFRLFAASMLATGGGDAYTVQDYDQWLADAGLRRVALLDTPMHRVLLAGR-------------------
+>UniRef100_UPI0018D5B757 233 0.517 9.722E-63 9 371 384 14 375 376
+---------LVDTAALADL--DDAESRAHHHLYEHLIGLWAPGLIEACHDLGIFTALRRGPASATDVADAVGADPRAVRVLLDGLQAYGIVRRAESGDPhPVYLLPAELHQAFSSDGLYSLAGKISHDRGIAWDAWRRLADRARTDTRSDGAPPRPNQISEDDYTALVRGINFWAPPIVHRLAGALRESGWAARTAPTLLDVGCGTGIYSHLLLREFPELTAHGLDAERIIPIAERQAARLGLaPSRFRGRTGDFWNDDWGSGYDLVLFVNIFHLQTPELACALLAKAAGSLAADGVIAIADHIVD-DAEPDSPQNRFSRLFAVSMLATGGGDAFTVQEYDRWLASARLRRFRLVNTPMHRVLLARRAAGPAAA------------
+>UniRef100_A0A6G4X3C8 233 0.555 1.327E-62 22 362 384 31 372 376
+----------------------DGEVRAAHRLYEHLIGIWAPGVIEAAQDLGAFAALTEGPATAAALAETLGTDLRATRVLLDGLSAYDVVQRTRGADGqAVYTLPAELHGVFAPDGLYSLAGKIGHDRNVAWSAWRNLARNVRDGARTSDGAEQLNQISEEDYTSLVRGINFWAPPIVRPLAERLRTTGWGTGSGRTLLDVGCGTGIYSHLLLKEFPELSATGLDVGRIVPIAEAQAAQLGVADRFRCVTGDFWNDEWTGDTDLALFVNIFHLQTPESARDLLLKSAKALSDDGVIAIADHIVDEEEGEDSTQNRFFRLFAASMLATGGGDAFTVHDYDQWLSDAGLRRVGLLDTPMHRVLLA---------------------
+>UniRef100_UPI001CD37CE1 233 0.541 1.327E-62 3 362 384 18 379 383
+---PELNVRpLVDTTLLPDWRGAGKVVHSAHRVYEHLISLWAPGVIEAAHDLGVFAELSTGPKTGDQLARACAANPRAMRVLMDGLYAYDVVDRVPaEDGPAVYRMPEEMRECLLPDGLFSLVGKIEYDRQLAWHSWRNLADSVRGDNRDEAGRLQLNQITEHNYESLVRGINFWAPPIVEALRGGFETLEWPTDRPASVLDIGCGTGLYSQLLLRAFQQWRATGLEAPSIAPIAMAQAERLGVADRFDVQVRDFWTESWGSDHDLLVFVNIFHLQTPESAQELLRKSKEALSRDGLVCIADHLVTDEKDAKSVQDRFAMLFAASMLATGGGDAFLLNDYDQWLASAGLRRVAVLDTPMHRILLA---------------------
+>UniRef100_A0A2T7T4I2 233 0.530 1.327E-62 3 362 384 18 379 383
+---PEHNIRpLVDTALLSDWRGSGKVVHSAHRIYEHLISLWAPGAIEAAHDLGVFAKLSTGPMTGDQLAEACQANRRAMRVLMDGLYAYDIVDRSSTDDGpAVYRMPEEMRECLLPDGLFSLVGKIEYDRQLAWPAWRNLADAVRHDNRDEVGELQLNQINEHNYASLVRGINFWAPPVVEALRGGFETLDWPTDRPASVLDVGCGTGLYSQLLLRHFGQWRATGLEAPHIASIAEEQAERLGVAERFEVQVRDFWTESWGSGHDLLLFVNIFHLQTPESARELLHKSKQALSENGMICIADHLVTGEQDAKSIQDRFAMLFAVSMLATGGGDAFLVDEYDGWLAETGLRRLALLDTPMHRILLA---------------------
+>UniRef100_UPI001661D816 232 0.548 2.471E-62 2 361 384 7 364 369
+--TFDNHTPLVDTELLPG---RGSGVNAAHRMYEHLIGIWATGVIEAAHDLGAFTALIGAPATAGELSTRLGTDLRATRVLLDGLAAYDVVERSRAADGqAVYTLPPEMHDIFAPEGLYSLVGKIRHDRNVAWGAWRNLAGNVRTGARNSEGSQQLNQISEEDYTSLVRGINFWAPPIAATLATALREQGWTDGAGRTLLDVGCGTGIYSQLLLQEFSGLNARALDAERIIPIANAQAHRLGVAERFNPEVVDFWADDWGTGVDVALFVNIFHLQTPESARELLLRSAKALTEDGVIAIADHIVDEDSTDGNTQNRFFRLFAASMLATGGGDAFTVQDYDQWLADAGLRRVALLDTPMHRLLL----------------------
+>UniRef100_A0A7V1RGG8 231 0.309 3.371E-62 23 366 384 3 336 340
+-----------------------ELAPDPTALFELATGFWASATLLAAEEVGVFHVLTEAPRTASEAAQALGADRRALERLLDACSGLNLLVK----QGERYLLSPLAAAYLVPGAPGGLASG-IAWARDQYAAWGRLAETVRTGRPAVDPGDHLGG-DPEQARRFVLAMHERAAGIARAVVGSL-----NLDGVERLLDVGAGPGTYAVLLARRHPGLSATLLDLPPILDAARELVDACGVAERIALRPGDASSGQYgEEAFDAVLFSGVLHQMPPETIRRMLEGAFRALVPGGRVFLSDILADATHTRPV----FSALFSLQMLLTTeGGGVFSVEECRSWLEQAGFAEIEVQRLPaplPYTVVSALRPR-----------------
+>UniRef100_D7BZK8 231 0.566 6.277E-62 22 364 384 25 368 370
+----------------------DGEVRAAHRLYEHLIGIWAPGVIEAAQDLGAFAALTVGPATAAQLAEVLDTDLRATRVLLDGLYAYDVVERSRGEDGqAVYTLPAELHGVFAPDGLFSLAGKIGHDRNVAWNAWRRLAENVRSGARTAEGAQQLNQISEEDYTSLVRGINFWAPPITRSLAGALRELGWTTGRSRTLLDVGCGTGIYSHLLLREFPELTARGLDAERIIPIAARQAGQLGVAERFQGEVVDFWSEDWGSGTDLALFVNIFHLQTPESARELLLKAVKGLTEDGVIAIADHIVDEDGGEGSVQNKFFRLFAASMLATGGGDAFTVHDYDQWLADAGLRRIGLLDTPMHRVLLARR-------------------
+>UniRef100_A0A1R1S6N2 231 0.569 6.277E-62 22 364 384 25 368 370
+----------------------DGEVRAAHRLYEHLIGMWAPGVIEAAQDLGAFAALAVGPATAAQLAEILDTDLRATRVLLDGLYAYDVVQRtRGDDGQAVYTLPAELHGVFAPHGLFSLAGKIGHDRNVAWNAWRHLADNVRSGARAADGAQQLNQISEEDYTSLVRGINFWAPPITRALAGALRDLGWTTGRSANLLDVGCGTGIYSHLLLREFPELTARGLDAERIIPIAARQATQLGVAERFRGEVVDFWSEDWGSGTDLALFVNIFHLQTPESARELLLKATKGLTEDGVIAIADHIVDEDRGEGSVQNKFFRLFAASMLATGGGDAFTVHDYDQWLADAGLRRVGLLDTPMHRVLLARR-------------------
+>UniRef100_UPI00224D4A51 229 0.538 1.595E-61 9 362 384 17 369 381
+---------LVDYTKLSADGAAPSEIRAAHQVYEHLVSLWAPSIIEAAHDLGFFVELADGARTADEVAHARGTDRRATRVMLDALYAYGLVGKSCEGSvPHRYVLPDACRGALLPGGFFSLVGKMAHDRNVAWNAWSDLARTVRRGTCDESGEDLANGISETDYEDLVTGINFWAPPIVDTLANCLADSGWKAGEAVSVLDVGCGTGLYGQLLLQRFPQWRAEGIDAPRIVPLADAQAKRLGVEDRFTGTVQDIWRGGWGEGADLILLNNMIHLQTAESGRKLLRTAADSLAPDGLVCIADQVIVNDEES--PQDRFAMLFAASMLATGGGDAHSLDTCKEWFAAAGLEMVAVLDAPMHRVVIA---------------------
+>UniRef100_UPI0020BDB3A6 228 0.555 4.051E-61 9 365 384 17 371 372
+---------LVDTARLTGVEAESQA---AHHLYEHLIGLWAPGVIEAAQDLGAFSALTLGPATAVRLAEILGTDLRATRVLMDGLHAYDVVRRSHSADGqALYTLPPELHDVFSPHGLYSLVGKISHDRKLAWNAWRNLAENVRTGARDATGGERVNQISEEDYTSLVRGINFWAPPIVRTLADALRELGWTTGESARVLDVGCGTGIYSQLLLREFPALTASGLDTERITAIASRQAQELDVADRFEVVVKDFWNDDWGTDIELALFVNIFHLQTPESARELLLKSSKSLAQGGLVAIADHIVDDDDGAGSVQNKFSRLFAASMLATGGGDAYTLHDYDQWLADSELRRVALLDTPMHRVLLARRA------------------
+>UniRef100_A0A5D0QPB3 228 0.563 7.541E-61 9 362 384 7 358 362
+---------LVDRRLLPDVGAGHETVAAAHHVYEHLIALWAPGAIEAAFDLGVFAALADGPATAEALAGRLEVDQRGMRVLLDALSAYDLIDRGSSAGGVRYGLRAGLRECLLPDGLYSLAGKVRYDRMLAWTAWRNLAQAVRGDGSAVP---QHNQISTTEYESLVRGINFWAPPIVSILAGALRDRgWPAGPAAPAMLDVGCGTGLYSQLLLQQFPELTGVGFDVERIVSIARAQSERMDVGDRFQPLAIDFWQRDWGTGFDLVLFANIFHLQTPDSARELSIRASKALAGGGVVAIIDQIVDDRADADSVQDRFFRLFAASMLATGGGDAYPLSDYDEWLSVAGLRRAALVDTPMHRILLA---------------------
+>UniRef100_A0A1E7JVS4 227 0.558 1.404E-60 22 362 384 25 365 369
+----------------------DGEVRAAHRLYEHLVGIWAPGVVEAAQDLGAFAALTEGPATAAQLSERLGTDLRATRVLLDGLHAYDVLGRARGEDGqPVYSLPPEMHGVFAPGGLYSLAGKITHDRNVAWDAWRNLAENVRSGARTSGGAQQLNQISEEDYTALVRGINFWAPPITQVLAEGLRAHGWTSGADRRMIDVGCGTGIYSQLLLNEFPELRARGLDVERIVPIAQEQAKRLGVADRFRTEICDFWNDDWGNDSSLALFVNIFHLQTPESAHELLLKTSKSLAEDGVIAIADHIVDEDEDGST-QNKFSRLFAASMLATGGGDAFTVQDYDQWLADAGLRRIALLDAPMHRVLLA---------------------
+>UniRef100_UPI0018F88670 224 0.498 1.684E-59 22 364 384 12 356 364
+----------------------DAASRRAHLLYEQLVSLWTPAVIEAAHDVGLFSALSRGPATSDELAAALSVHPRGARILLDALFACDLVECDEQPGcAPIYTLPEDVKACVEPLGLFSLAGKMLYDRRFAWDAWRNFATAVREGGVdQSSKQCRQNQISPEEYRFLTRGINFFAPPIIHALGEGLAKIGWSTRRAISVLDVGCGTGIYSQLLLQRHATWRAVGMDCETMAALARAQSAELGVEDRFSCRASDLWRLPWGGDFDLILLCNMFHLQSPDGAARLMKLAGEAVSTAGIVCVIDQIRDEHRHVDTAQNRFALMFAASMLATGGGDTYTLEQYDEWLRDAGLERLIVLPAPMHRILIARR-------------------
+>UniRef100_I2Q4T4 223 0.314 3.133E-59 27 353 384 5 316 332
+---------------------------TPAALLEIAGGYWKTCALHAGVVLDVFTPLADVPLTAGELAARLGCDARALGMLLRALAAMELLSR----SGERYALTGEAREFLDARSPRYIGYAVRHH-HRLMPVWTRLPEAIRSGRSLREHM--GGDADPGDREDFLLGMFNIAMGVAPRLARTLDLSGR-----RRLLDLGGGPGTYAVHFCLAHPDMTATVFDLAGSREFAASVSERFGVADRVEFVAGDYLRDPVPGGYDVAWLSQILHAEDPAGCRTILGKAAGALSPGGLLFVHEFMLDDDAAGP----EFAALFSLNMlLGTDHGQSYPEGRIREMMEGAGLKNVRRLD------------------------------
+>UniRef100_A0A840IMV4 222 0.583 4.274E-59 44 362 384 1 314 318
+--------------------------------------------IEAAHDLGVFVELAGEPRSGAELARALDADPRAMSVLLDALCAYDLL---VEDGNARYALPAELRECLLPDGLFSLVGKIEYDRTLAWRAWRNLADTVRAGTRVEDGSDAPNQIGETEYRSLVHGINFWAPPIVNVLAGALAERGWTPRRPVSLVDVGCGSGIYSHLLLRRFPALTATGIDVRRIMPLALAQADRLGVADRFRPAVMDFWSEDWGLGYDLALFVNIFHLQTPASAEELLRRAAKSLAGGGLVAVVDQIVTGDAHS--AQNRFSRLFAASMLATGGGGAYRTEDYDRWLDGAGLTRVALLDTPMHRVLLA---------------------
+>UniRef100_UPI00210A57ED 218 0.529 9.522E-58 10 365 384 40 405 406
+----------VDRSRLAGT---DAQSSAAHRVYEHLVALWAPGVIEAAQDLGAFAALTSGPATAAELARTLDTDPRATRVLLDGLHAYDIVERAMTEDGEiRYTLPVELHDVFSPGGLYSLAGKINYDRSLAWNAWRDLAHNVRTGARDAEGGHQLNQISEEEYTSLVRGINFWAPPIVSALADALREQGWSTGDGAKVLDVGCGTGIYSQLLLREFPLLTARGLDVARITPIAIRQAKELGVADRFEPTVVDFFHDSWGGG-DLALFVNIFHLQTAESARELMLKAAKEINEDGVIAIADHIVVEDTHGAPgvsdgadaagtgsVQNRFFRLFAASMLATGGGDAYTVEEYDGWLADAGLRRIALIDTPMHRLLLAKRA------------------
+>UniRef100_A0A3C0VCK0 218 0.279 1.299E-57 33 364 384 9 323 326
+---------------------------------QMASGFMPARVLLTALELDVFTACGAGAATAEELARRTGARPAPLARLLNALAALGLLDKRGD----RYRATPPARTHLIAGRPGYL-GDIMRHRASMWERWSDLTAIVRTGR------VPPRAFTKERERRFIKGMANLGASAAPACARALR---RELAGARRLLDIGGGPAVYACELARAWPKLSVVVLDLPGPLAYARETIAAYGLARRVSVKAGDVCAArSFGRGFDVAFMSSLIHSFKPAVVAEVIRKAAGALRPGGFLAVKEFFIDPGRASPP----FTALFSINMLVAGAGDVYTRGEVEGWMRAAGVRPVRYVDLPQFSGIVVGR-------------------
+>UniRef100_A0A1F5AWX5 218 0.297 1.299E-57 28 365 384 2 329 330
+----------------------------PTRLVEMASAFYESSVLFAASDLGIFAKLAElGEADARTISAVSRLDPRGARLLLDACVALELLVKKGD----RYQNSPEAAAFLVPGAPADLSGAIRYNRD-VYGAWGKLKELVKTGKP-VERAELHLGEDPERTRTFVLAMHGRAMGIGQAVIPLLDLDGRKA-----VLDVGGGPGTYSILIARAFQQVRCTVIDLPEVARIADEIISQAGVGNRVRTLAGDYHTLPFPADQDTVVFFGVLHQEDPAAIQDLFRRAYGAMLPGGRIYVLDMMTDASH----ARPKFSALFGLNMaLTTPHGWVFSDDELKGWLKEAGLTDFNCRPLPppmPHWLATARKA------------------
+>UniRef100_G7Q8Y2 218 0.307 1.299E-57 27 353 384 5 316 332
+---------------------------TPAALLEIAGGYWKTCALHAGVVLDVFTPLGDGPLTAGELAVRLGCDARALGMLLRALAAMELLAR----SGEGYALAGEAREFLDARSPRYIGYAVRHH-HRLMPVWTRLPEAIRSGRSLREHM--GGDADPGDREDFLMGMYNIALSIAPRLAQSLDLSGR-----RRLLDLGGGPGTYAVHFCLAHPEMTATVFDLAGSREFAASVSERFGVADRVEFVAGDYLKDPVPGGHDVAWLSQILHAEDPAGCRTILGKAAGALSPGGLLFVHEFMLDDDAAGP----EFAALFSLNMlLGTDHGQSYPEGQIREMMEAAGLRDIRRLD------------------------------
+>UniRef100_A0A3C1EP29 216 0.289 6.124E-57 28 361 384 10 327 336
+----------------------------AAPIMALARGFMASRILLTAFDLDIFTALDRGPIDSVQAARRIRCDNRATDRLLNALVSLGLTRK----KGRLFSNTPLAARHLVRGRPEYLAG--LGHCVHLWDSWSTLTAAVRRGRSVLEPSVGRRGAAW--LSAFISAMHERARAQADAVVK-----GLDLGAVESVLDVGGGSGAYAMAFVRAKPGLRATVFDLPQVAPLTRRYISREGLSGRVAVRAGDYEKDPLPKGFELVLLSAIIHSNSPTANIRLLRKCRRSLNPGGRIVIQDFVMNPDRTAP----AFGAIFALNMLtATAAGDTYTESEIRSWLKQAGFGSIKRRDTPFASTLI----------------------
+>UniRef100_A0A933TFS2 216 0.306 8.351E-57 38 364 384 23 334 335
+--------------------------------------FMESRVLLTAWELGVFTALGRGARTAAQVARTARADPRAMDRLLDALVSVGLARK----SGRIFSNSPSAARYLVAGRPAYIGS--LGHMASLWESWSTLTQAVRAGRSVLKDDMPRRG--KEFFVPFIAAMHERSSLQGPSFARALP-----LAGVRRLLDVGGGSGAYSIAFARAHPALRATVFDLPQVVPLARGYIRAAGLQDRVEARVGDYDRNPLPDGYDLVFLSHILHSNSPARNRRLLRKCARSLNPGGLVVIQEFLVDEDRTGP----QFAALFALNMLVgTPAGDAYTEREIGSWLKGARLRGIRRKDTAFDSaLLIARR-------------------
+>UniRef100_A0A7V5CKK2 215 0.283 1.553E-56 27 356 384 3 316 329
+---------------------------SVDKVWETARAFQASRILLTGFELGVFATLGDNAMTSAEVASKIGADPRAADRLMDALVVLGLLTK----EEGKFRNSGEARETLVPGKPTYAGGALGHVIS-LWKSWSTLTDAVRKGTSVFKHEDEARA---EFVKPFIAAMHFNASNLAPIILKQI-----DLTGVRRVLDVGGGSGAYSIAFCKASPEITSVIFDLPDVVPLTNEYAAKAGVADRISTVTGDFNTDNLPVGFDLAFLSQILHSNSPDENERLMRKVGTALNPGGQIVVQEFVVDEDRISPPG----PVFFSLNMLVgTKAGDTYTEKEIGSWLDGAGFGEIKRIDPPG---------------------------
+>UniRef100_A0A1V4XPT3 214 0.272 2.117E-56 27 352 384 4 314 331
+---------------------------TPEAVLQLARQFMESRILLTAAELGLFSPLAKKPHTAEQLSGRLGCDTRALAILLDALAAMGLLEK----RDGAYRTPPAAAPFLCGGSPRSVIPMILH-AAHLWERWSDLTPIVRATSSSAAPASGARST--EELSAFIGAMHIAGLPLAEKIVAAIR-----PGQARNLLDVGGASGTYTIAFLRAAPGMKATLFDRPEVIPMARERLAEAGVLDRVRLEAGDFHRDELPGGHDLALLSAIIHQYSPQENRELFGKVSRALVPGGRIVIRDHIMDPDRT----QPRDGAIFAVNMLVnTRGGSTYTFEEVRAWLEGTGFANVRFL-------------------------------
+>UniRef100_UPI000A3C465B 214 0.454 2.887E-56 9 362 384 8 365 374
+---------LIDYAA-FGSTgaPEEDTVVAAHELYSTLIGLWAPAIIEAAAELGVYPLLRDEPVGSDEIAAELALDPAAVRILLDGLHACGMLRRGLTGGGvPRYRLEDRFAPLLLGTGEYHLLGKMAYDRTVAWPAWRGLAETIRSGGVAPGALPEKNQNSERDFVSLVSGINFWAPHAIESVRTALRaDLGWDLARPTSVLDVGCGTGIYSQLLLRGESTWTATGFDTPKVAEIATAQAARLGVGDRFDCEAVDFLAEDWGPPRDLVLLVNVCHLLPRHLVSELIARAAKAVRPGGCVCVVDHMhLDTKDEFDEPQDRFAALFAVSMLSTGGGDTHRVSDYRHWLTDAGLRPAVLRPTPMHRLLLA---------------------
+>UniRef100_A0A345XPI4 214 0.523 3.936E-56 9 362 384 14 359 363
+---------LVDTTLL--PAGGDGEVQAAHRVYEHLVGIWAPGVVEAAQDLGAFAVLTEGPATAAQIAERLDTDLRATRVLLDGLHAYDILGRvRGDDGQPVYSLPPELHGVFAPGGLYSLAGKITHDRKVAWNAWRNLADNVRSG------TQELNQISEEDYTSLVHGINFWAPPITQVLAKGLREHGWTSGAGRSMIDVGCGTGIYSQLLLNEFPELRARGLDVERIVPIAREQAERLGVADRFRPEICDFWNDDWGNDSSLALFVNIFHLQTAESAHELLLKTSKALAEDGVIAIADHIVDEDKDGST-QNKFSRLFAASMLATGGGDAFTVLDYDKWLADAGLRRIALLDAPMHRVLLA---------------------
+>UniRef100_A0A7C3MML6 213 0.302 7.316E-56 27 364 384 8 339 340
+---------------------------SPQQLMGLLQGFMGSAALKAGLDLELFTHIAHGADTAEKLAAVKKVPERAMRILCDALVAFGALTK----SGGHYSLPPASQAMLVKGSPAYFGAMAgIMCNPLMWNEAGRLADVVRAGHSlLDQGAEAPEHPFWEEFSRRSKQMATMGGPAVAELA----ASLFGAGEPARILDIAAGSGMYGFSALKRFPGARLVSVDWPNVLRLAEPTAKQMGLAERVEFRPGDIFKDDLGTGYDLVLAVNIYHHFGIERNTELSRRLHAATASGGRLIIVDAVPDENREH----ERFALVFALTMLiWTREGDTYTLSEYERMLKPAGYRDIELKAVPgpaPFQAIVARK-------------------
+>UniRef100_A0A4V2PBK1 212 0.452 1.360E-55 8 367 384 14 368 372
+--------SVVD----FEEIGTDRTSESAHAIYAALVAQWQPAMLETASSLGLFGALRAGPLRAEEIAAVTGTNTRAVKVLLDALVAYGWVTSIPDGENSRYSADPAVAASLSSDSIFSLTGKIGYNRGLSRSAWRTLDQSVRDGVRAADGI-GNNEITAHAYEDLVTGINFWAPPIVDKLIDWTTRTGWRREQSRKFLDIGCGSGIYSQLLLRHFSRAVAVGLDVESIGRLAVGQSVELGVDDRFRLRTANFWRDDWGTGHDAVLFANIFHLVNPAGALELLDKARDAVADDGFVFIVDNIAVGGTESDSPQDRFAALFAVSMLVTGGGSTYTLADYDQWLSTTGLERVALIDAPMHRIVVARRTEE----------------
+>UniRef100_A0A1G3USZ5 210 0.277 6.404E-55 28 364 384 3 329 331
+----------------------------PDRIIGMASAFYESCVLFTASDLGIFARLSEaGPADAQSLALTLKLDERGVRLLLDACVAMELLQK----EGSHYANTLESKAFLTPGSPGDLSGAIRYNRD-VYTAWGKLKDFVKSGRP-VESPESHLGQDPERTRTFVMAMHYRALGMGRAVIGEL-----DLSGSKTVLDVGGGPGTYSMLIAQANPDATCTVLDLPEVVAVADELIRQQALQGRVKTLSGDYRRISFPEGYDMVNFFGVLHQESPQSILLLLQKAYRALRPGGAVNVMDMMTDSTHT----KPKFSALFGVNMaLTSENGWVFSDLELKEWLKEAGFADCMVKPLPppmPHSFAAARR-------------------
+>UniRef100_UPI0018935592 208 0.407 3.014E-54 9 362 384 20 377 381
+---------VIDYGAFAPTgSDDEKTVIAAHELYTVLIGLWAPAIIEAAHDLGVYPQLSGAGVSSDQVADVLSLPGTASRILLDGLHACGIAERFRSDDGiVRYRLRERFAPLLLGGGAYHLLGKLSYDRSVAWSAWWRLPDSIRNGNPGPGESDGRNQNSEQDFVALVSGINFWAPHVVQQLrAGLAEDLGWDLSHPRSILDVGCGTGIYSQLLLRKQPEWTAVGIETPKVAHIAREQALRFAVADRFDCRETDFLEDGWDVSCDIVLLVNVVHLLPAATAAEFIERASRAVRPGGCLCVIDTILDDSKDTfDQPQDRFAAMFAVSMLATGGGDAHCVSDYRRWLHAAGLRPTAVRETPMHRVLIA---------------------
+>UniRef100_A0A950Y7R7 206 0.274 1.418E-53 32 365 384 0 321 323
+--------------------------------MQMAWSYAPPLIIEAAVRNGFFDALATKPMNASELAQATGSSERGVTAVMDALVGLALAARDRN---GRYVLTAESDTFLVSARPGSLGGFFRHISDLI-PAWLPLRDIVRTGEPARKVDSQETGAA--FFSNFVESLFPLGYPAALGVAKSL---GAPLDAPLQVLDLAAGSGVWSVAIAHTYPQARVTAVDWEGVLPVTKKVTARERVADRYEYIAGDILETDFGGGYDVATLGHILHSEGDARSQELLRKVGRSLKPGGAIVIAEFLANEERSGPP----QALIFSVNMLVnTSAGRAFTFGEIRAWLEEAGFIDARTLEIPAPSLIVARKA------------------
+>UniRef100_UPI000B2CDD0E 205 0.465 2.633E-53 28 370 384 18 361 362
+----------------------------AHQLYAALVAQWQPAMLESASALGIFDVLRSGAASSTAVAKSIGADERSVRVLLDALAAYGWVSGRDGVDGEpLYEVDESVAACLTAGSMYSLIGKIGYNRSVSGDAWRKLDRVVREGISGHDGEIENNGISAVAYEDLVTGINFWAPPIVDKITAWLRAAGWGAGEARDVLDIGCGSGVYGQLLLGDFPAATATGVDAPNILRIAAKQAAALGVGERFEARGADFWTSEWGTGRDLVIFANIFHLVNPSGAEKLLEKARESVADDGIICIVDNIQVGGAETDSPQDRFAALFAVSMMVTGGGATYRLAEYDEWLRVAELERVALLDAPMHRVILARPRREGSS-------------
+>UniRef100_A0A6G9ZAG5 205 0.447 3.588E-53 5 363 384 31 389 393
+-----GRTTLVDYEHFQRSGVGFDEICAAHKVYETLVGLWAPGVIEAADELGVFREIAKSSKTPAELAEVAGAGSHGMRILLDALCVYGLLNRDvDDSDGYKYSLKPFFGSVVTGHGSASLIGKFLYDRQLAWPAWVNFVDAVRN-SGDPDSGRQENQIPAGQYIHLTKGISFWAPPIVDVLCHRLEELGWSSSSEKHILDVGCGTGIYSHLLLRSFRGSQAIGLDVPEICRVAIESASEFGVDDRFATREVDFWSEGWPKNQDLVVIANIFQMLTPDSAKRLIDLAASSLSESGVVCIVDQIRIGKAEFDTAQDRFAAVFAASMLATGGGDTFHLNQYDDWLESSDMHRIDLLDTPMHRIILAR--------------------
+>UniRef100_A0A7V7EA79 204 0.311 6.664E-53 30 370 384 9 333 335
+------------------------------RINEISTGFKGSMILFAANDAGVF-ALLEEERSADELAAVAGWHPRAARMLLDALVALDLIGK----SEGRYRNTPIASACLVPGGKAYQGHIIKHQQNG-WDAWARLEVSLRSG---TAVERDAHERSPEELRAFILGMRDTARISARTMCDVV-----DLSTHRHMLDLGAGPATYAIVFTQRHPELRATVFDVPEVIPIAREQVAAAGLDERFAYIEGDMLADDLGSDYDLVLASNIIHMYGPVENRALMKRCYDALAPSGLLIVKDFLVDDGRSGP----AFGLLFALQMLiHTPCGDTYATSELSEWTNEAGFAEGRLIElTPQARLWLAGKPPARSA-------------
+>UniRef100_A0A7T1WXA6 203 0.299 1.686E-52 34 365 384 0 326 327
+----------------------------------MMSAYKETSVLKAGIKLGVFDELArEEPQDAESLARRLGSDPRGMRILLNALAALELIE----TDGRQYRLPPGAAELLSRDSDGYAGDMIHVIaSDYEWDALKNLDGAVRNGGTVL--DEHAETPEYSYWEDFAAFAPHVARPTARVLADALEPWARD-RESLDVLDLACGHGIYGYTVAQRFEQAAVWSLDWENVLDVAAKHAGSMGVRERTNFIAGDMFDVSFGGEYDLVLITNVLHHFSDERARELLSRAAAALRPGGKIGIVGF---TTSDAPPALDPAPHLFSVLMLVwTSEGEVHSERNYRRMFTDCGLEepSVHQVENLPFRVLLADRA------------------
+>UniRef100_A0A932TFF8 203 0.313 1.686E-52 28 365 384 10 335 336
+----------------------------PRSIMDILWSMVPIRVLTAAVKLQIFAPLEMSPQTAAEVARQVEADARGIRMLLDALVGIGFLTKAGD----HYALTPVARAHLVPGKTGYL-GDYVAGSARMAERWGGLAEAVRTGQPVM--AVDEQETAEEYFSTLVRALQVTNGPPAQRLAAHL----ASRRSAARVLDVACGSGVWGIYYALADPQARITAHDFPTLLELARQYIRNHGVEDRFEYLPGDVRTVDFGvEHYDVAILGNICHSEGEAGSRALLRQMARALRPGGTAAIIDMIPNEARTGPP----FPLLFALNMlLHTREGDTFTLAQYTAWAREAGLERVETVDIGSHsPAILATRP------------------
+>UniRef100_A0A349GW63 198 0.285 1.280E-50 38 365 384 1 318 319
+--------------------------------------FYDSCVLFTASDLGIFNHLAQHPdATAADLASACQLDLRGATLLLDGCVALDLLTKTGD----RYRNTPETACFLVPGAPGDLSKAIRYNRD-VYAAWQQLPAFVKTGKP-VERPEIHLGEDEARTRAFVHSMHGRALGIGRSVVPQLDLAGR-----TQLFDAGGGPGTYSVLIAQANPQIRCTFLDLPGIVKVANELVAAQGMADRVTSIPGDYHTTPFPDGNDVVIFFGVLHQESPASIQDLFRRAYASLVPGGSVYVLDMMTDATHT----QPRFSALFAVNMaLTTTNGWVFSDQEAIDWLTGAGFIGAACRPLPppmPHWLVTATKP------------------
+>UniRef100_A0A7K0ITR6 198 0.293 1.280E-50 28 353 384 2 312 328
+----------------------------PSELLQLSGGYWATCALHAAVKLDLFTCIAGSPATSSEVSRLTNTDHRSMTMLLNAVAAIGLLHF----DNGKYVATPFSAEYLSKNSDKYLGHIIMHHHNLM-PGWSNLDEAVKSGAAVR--SNSSRSDDAADRESFLMGMFNLACLIAPKIVPAIDLSGR-----RSLLDLGGGPGTYAIHFCLHNPELRAVIYDLPTTREFAEQTVQRFGLSDRISFSAGDIITDGIGSGYDVVWISHLLHSEGPAGAATMLDKAVRSSRPDGLVFVQEFILDDDRTAP----LFPALFSLNMlLGTQAGQSYSQQELTQMMINAGVENISRLP------------------------------
+>UniRef100_A0A8J3ZYP0 196 0.298 3.235E-50 25 365 384 0 334 335
+-------------------------MQASADIFHALLAYKKSAMLRTGIELGVFARLAERPATADEVARDLELAPRGSRLLLNALVAIDVLE----ETDGVYRLAPLAAETLDPNRDGYLGElSRILTSRWEWEAMGRLPEAVRRGGPV--IAENAEQLDYGYYEEFATHAGAVTRPTVARMTGTVHDWAAQRER-LNILDLACGHGMYGLTLAQQHPHARLWSVDSAKVLEIAQKNAARLGVADRMQTIAGDMFTLDLGGPYDLALITNVLHHHTPERATELMRRVAAVTRPGGKLVLVGITADDGPVRESPE---AHLFSLLMLVwTDNGEAHSAGSYERMLSAAGYRDMRlyRQDEIPMRVIVAERA------------------
+>UniRef100_UPI0018675D04 184 0.381 4.614E-46 28 366 384 23 356 358
+----------------------------AHRLYSALISSWETAIIEAAYNLGIFSCVASGPATLFEIAKRTACNEECLRILVDALVAYGwLFTNAMPGSDPTYHLPEEYSDVLTAvEGVNDLTGKIYYDQEIAWQYWRNLAHTVKTGSVRN-----VNGISTATYRQLVLGIRFWAPPIAAAIGKALDK-HHFLREDRLLVDIGCGSGIYSHLLLQQHHGLRAVGYDVPEIADIAHESAGKFGVSSRFRMVTGDFFESDW-AAADLYLFANIFHLFDPEKCKILLSKARAGMSDDGRVLIVDAIRASGGSPVTSQEKFAALFAVSMVASGGGNTYSLNTFDSWLAELGLYRIDYLNTPMHGVIVAGWLP-----------------
+>UniRef100_UPI001FC9146D 181 0.559 7.376E-45 111 364 384 1 254 255
+---------------------------------------------------------------------------------------------------------------FAPDGLYSLAGKITHDRNVAWQAWRHLADNVRGGARGEHGGQQVNQISEEDYTSLARGINFWAPPVVSVLADALRERGWGDDTEAVLLDVGCGTGIYSHLLLQAFPQLTARGLDAPRITAIAAEQAERLGVSERFSPLTADFWNDDWGNGTDLALFVNIFHLQTPESAHELLLKTAKGLGEGGLIAIVDHIVDEEAGSGNVQNRFFRLFAASMLATGGGDAYTVHDYDQWLADAGLRRVALLDTPMHRVLLAGR-------------------
+>UniRef100_A0A938CP11 167 0.283 3.444E-40 35 364 384 20 343 344
+-----------------------------------AFGILAAEAVLAGLRLGLIEEVAARPATARQLARKVGAKERGVRVLLDALVALGQLGK----EGEQYCLSASTQMLLsLPGVDaKSYCADALLHLSAFSDGLRQLANVVRTGRPPSADPADTE----RFLVALAGSLFPFNYPVARALCHRIR--GEFGRGPLAILDVAAGGAPWSMPFAQGNRQARVTAVDFPAVLEVARHYAQAAGVEGQYELLPGDIRKAPFGNGqFDLAILGHICHSEGPNRTPRLFRKVAQALKPGGVMLVLDFVADEHRTG-EGSGALALLFALNMLVSaTDGDTFTESQYRLWGVQAGFSGPERLELPaPYPALLFRK-------------------
+>UniRef100_A0A3M1IMN7 153 0.300 1.529E-35 138 365 384 0 218 221
+------------------------------------------------------------------------------------------------------------------------------------------PEVVRTGRPVPRDRRPA-----EEFARFVEALFAGNLPAAQALQAHL--GLRQTTAPCRVLDLGAGSGVWGIGLAEGAPQVRVTAVDWPEVLAVARRLAAEHGVAERFRWIEGDFFEVGLGNDYDLVVLGHILHSEGIERVRRLLERSHEALRPGGRVVIAEFLPADDRSGP----LQPLLFAVNMLVnTEAGDTYTLAELTAWLEEAGFEAVETLPVPaVSPLVLARKP------------------
+>UniRef100_A0A7C2IW83 139 0.282 4.702E-31 151 365 384 2 205 206
+-------------------------------------------------------------------------------------------------------------------------------------------------------PPGVEGRDPAWTEAFIAAMHRGALAAAPAMVATV-----GAAKVRRLIDLGGGSGAYSIAFARANPELRAEVLDLASVVPIAEKHIAEAGLGDRVKTRVGDLLKDEFGSGYDLALLSAICHMFSPEENRDLLRRTFRALVPGGRVVIRDFIVEPDKTAP----KWAVLFALNMLVaTRGGATYTEAEYSSWLEEAGFVSIER---PQADLIVARRP------------------
+>UniRef100_A0A1F5BUK4 139 0.309 8.616E-31 150 365 384 25 235 236
+------------------------------------------------------------------------------------------------------------------------------------------------------GLFAKLGEDSERTRTFVLAMHGRAMGIGQAVVPLLALSGRKA-----VLDVGGGPGTYSILIARAFPQITCTVLDLPEVVRIAEEIISQAGVGDRVQTLAGDYHTIAFPANQDAVIFFGVLHQEDPAAIRNLLRRAHGALRQGGSIAVLDMMTDASHT----QPKFSALFALNMaLTTPHGWVFSEDELKAWLVEAGFTDFNCRPLPspmPHWLATARKA------------------
+>UniRef100_A0A7C4C0Z0 134 0.333 2.399E-29 195 367 384 12 181 182
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RRMLDVGGGSGAYSIAFARANPELHADILDLPEVLAIAKRHISEAGLEERIATVAGDLRKDKLGENYDLVLLSAICHMLSVDENQDLIRRCFDALAPGGRIVIQDFILEADKTAP----RTAALFSINMLVgTRDGASYSEPEYVDWLAGAGFSDIRRVRLPGPAGLMAGVRPR----------------
+>UniRef100_A0A7C6B4B6 133 0.295 5.934E-29 192 365 384 4 175 177
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------EGVRRMLDLGGGSGAYTIAFAQASPELEADILDLAPVLRIARRHIEEAGLSDRIRTRPGDLHQRSYGAGYDLVFISAICHMLDPKQNRGMLRKSYRALKPGGRVVIQDFILEADKTAP----RAAALFSLNMLVgTRAGASYSEPEYRSWLQDTGFGGISRLHLPgPTSLMIGRRP------------------
+>UniRef100_A0A2V9K3Q8 133 0.331 8.024E-29 185 356 384 8 176 187
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RLVRLEPGRKHRVLDIAAGHGKFGIAFAREYPKVEIVAQDWPNVLEVARENARAAGVEDRFRTLPGSAFDVDYGSGYDLVLLTNFLHHFDPETCERLLRKARAALAPGARAATLEFVPNEDRVSPPVPATFSLMMLG---STPKGDAYTFSELERMFRNAGFARSELHALPP---------------------------
+>UniRef100_T0ZCY9 129 0.284 1.209E-27 167 357 384 20 205 215
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------RSLHVANRAPAERLAQALEIADRRPLA---VLDIGCGSGIWGIAIAESAPHARVTALDFPQILELTREYATRHGVQDRFEYLPGDLRTAALGtARFDLSILGNIVHSEGEKSSRALFRRLHRATRPGGQLAICDMVPNDERTGP----IYPLLFALNMLVnTTAGDTFTLGEYSAWLGEAGFTEVRTCEIGSH--------------------------
+>UniRef100_A0A831XJT3 127 0.308 7.352E-27 190 366 384 18 191 192
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DLSGFKHLLDLGGGPGTYAIACLNAHPQIRATLFDHANVVDIAREQVEAAGVSDRVTFVVGDALKDDLGDGYDVILMSNLIHAFDENENRRVVGKCFDALASGGRLIIKDFLVENDRSGPP----FALLFALHMFvHTQGGDTYTFAQVEEWTSAAGFSEGRALPLTPHTHVWLADKP-----------------
+>UniRef100_A0A838MY84 126 0.295 1.811E-26 196 364 384 4 169 170
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------KILDIAAGHGIFGITLARHNPQAEVFAVDWPNVLQIARENAETAGVAARYHLLPGSAFEVEFGDGYDLVLLTNFFHHFDQPTCESLMRKVHAALKDGGRAVTLEFVPDEDRVSPPAAATFAMVM---LASTPSGDAYTFSEYEQMFRNAGFTHSVGYPAPPGHIIVSQK-------------------
+>UniRef100_A0A1V5YR16 125 0.325 3.301E-26 197 364 384 0 165 166
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------MLDLAGGPGTYGITFQQQHPELHVTLFDRPEVVTIAREQVSEAGLSSRFSFIGGDCICDDLGNGYDLVFLSNIIHSFGTEENAGLMRRAYDALVPGGTLIIKDFILDNDRQGPAYGLMFAL---QMLVHTTAGNTYSFEEIQRWTDAAGFRQGESISlTPQTRLWIARK-------------------
+>UniRef100_A0A9D6DY32 123 0.266 1.479E-25 137 367 384 0 224 225
+-----------------------------------------------------------------------------------------------------------------------------------------MSEVVRTGRPVAAVNREKDG--GRFFSEFVEGLFPVSYPAAQALSEVLEISQSK--EPVHVLDLGAGSGVWGIAMAQKSSHVRVTAIDFADVLPVTRRVAQRFGLEEQFHYVAGDVLEADFGGGHNIAVLGHILHSEGEKRSRTLLRKTFDALAPDGTIAIADFIVNEERTGPPP----ALIFAVNMLVnTEHGDTFSFGEIKTWLDQAGFENARpVEANGASPLILATKPGR----------------
+>UniRef100_A0A924W8C4 122 0.328 2.694E-25 195 346 384 16 164 184
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GRVLDVAAGHGLFGIVIAQRNPGARMTALDWPKVLEVAKLHADRMGVGERLTTIAGDAFEVDLQGPYDLILLTNLLHHFDAQQCTTLLKRLRAALRPGGRLVTLEFIPNEDRVSPAMAATFPLVM---LATTARGDAYTFSELEHMLRAAGF-------------------------------------
+>UniRef100_A0A950V3E9 121 0.327 6.615E-25 194 355 384 3 161 173
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PMKVLDVAGGHGLFGIAFAKQNPNAEVTLLDWAAVAAVGTENARKAGVEKRFKVLAGSAFDVDYGTGYDVILLTNFLHHFDPATIDKLLKKVHAALKPAGRVVTLEFIPNEDRVTPPIAAAFPMLM---LCGTPSGDAYTVSEFQKMFRAAGFSNNIFIPLP----------------------------
+>UniRef100_A0A2W2EUZ9 120 0.308 1.204E-24 197 365 384 5 173 175
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VLDVACGHGLYGLTLAQRNPRANVWALDWPNVLTQVETHADQLGVRDRLHQVPGDMFQVPLGGPYDAILVTNVLHHFSEQRAGELLARLAPALKPDGKIVLVGFTLGDE---NPADDPAPHLFSILMLaWTYEGEVHSIAAYDRMLTAAGFTTGRRHDVPGlaFRVLVADKA------------------
+>UniRef100_A0A2V5V6K0 119 0.244 2.952E-24 194 364 384 8 176 178
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PVRVLDVAAGSGIWGIALAQQSPLVRVTAQDWPEMIPTTKRITQKFDVADRFSYVEGDVLEANFGTDYDIATLGHILHTEGKDRSRKLLKKVFGALKPGGTVAIGEWLVNDERTEPLPSLIFAV---NMLVHSERGDTFSFNEIKRWLEETGFKKVRKLEAPgPSPLILATK-------------------
+>UniRef100_UPI00215D97BC 119 0.288 3.979E-24 130 365 384 0 224 226
+----------------------------------------------------------------------------------------------------------------------------------MVKGWLRLPEAISGGHP------EPQGPDPEFFTHLTRGLLAVNWPEATELAGQL-----KSRGYQRLLDVGAGSCLWSAALLKELPSARAWAIDFPQVLDgSAQEIVRHLHLEDRFVFLPGNYWKISWGEGYDLIILGHICHSLGPEENVTLFKKARQSLARDGELVIIEFIPDEGRCSPLFPLIFAL---NMLLHTDSGDTYTASEYQDFLARAGLKISERlyLDQGHGSQVIVARP------------------
+>UniRef100_A0A2V6IMK1 117 0.269 9.748E-24 196 364 384 15 181 183
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVLDLASGSGIWGIALAQKSPRVQVTAVDWVGMIPTTKRITQKFGVGDRFKFIGGDLLKADFGDGYDVATLGHILHSEGRDRSRKLLKKTASALKSGGIIAIGEWLVNDERTEPLN----GLMFAVNMLVnTESGDTFSFNEIKRWLDEAGFKNARTLEAPgPSPLVLATK-------------------
+>UniRef100_A0A0F5VJF9 117 0.293 9.748E-24 138 363 384 0 218 219
+------------------------------------------------------------------------------------------------------------------------------------------PGVVRMGGPRAGG--ETEVADNPHGEGIVRAIPAVSVPAADA---GVDALGIADAGEISILDVGGGSGIYSSIWLKANPAARAVQLDWEPINVIARRLVGEQGVGDRFTTLDGDFHTTDFGTGlYDIAVYSHIAHQENAHSNIEVFTRLRKALKPGGALVVADYVVDEDRGAP----AFPLLFALEMlLKSNEGGTWRRSDYRDWLIKAGFEDVSFHAAPPATMVIAR--------------------
+>UniRef100_A0A7Y2F7Z1 117 0.269 1.314E-23 210 365 384 0 151 155
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ALLQKNPELTAVVVDRPEVLRVAEEMATEYGVIDRVELMPGDMFDEPLPSNADIVLLSNVLHDWDVPECQQLIYRCVESLAPAGRVVIHDVFLHDELDGPLP----IALYSAALFTLTQGRAYSQREYREWLEAAGLRTVPAVDTLIHCGIIVGQK------------------
+>UniRef100_A0A9E2XDP0 117 0.270 1.771E-23 158 361 384 0 197 202
+--------------------------------------------------------------------------------------------------------------------------------------------------------------DHPIWVKFARAMGPSRVPVAKIVASEL-----AVPSPRKVLDVAAGHGMFGIAIAQATTGAQITAIDWQAVLSVAQENAEAAGVSGRYHTLAGSAFDSDWGSGFDLVLMTNFLHQLDRDACVTLLRKARKSLVSGGRAVAVEFLPNEDRVSP----RFPAMFAFQMLgSTPQGDAYTAREFEEMGRAAGFGKViaKSLPPTPHSLIL----------------------
+>UniRef100_A0A354C2G1 114 0.267 1.426E-22 197 357 384 0 157 168
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------MLDVGGASGTYTIAFLRTVPDMRATLFDMPEVVEMARERLSKAGMLDRVTLVSGDFYQDEFPPGHDLAFVSAIIHQNSPAQNVDLYHKIFRSLDRGGRIVIRDHVMEPDRLHPKDGAIFAV---NMLLGTSGGGTYTYEEIKADLSQAGFTAVRLIKRGEH--------------------------
+>UniRef100_A0A2M7Z0A8 113 0.294 2.585E-22 190 361 384 11 180 194
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SLKAVRSLLDLGCGPGTYALAFLAQNPTLHATVMDRPAALDVARMLAEQSSSGTRLTYQAGDFLTEHISGTYDVVWYSNVLHIYSPADNLKIFKKVKRILNPGGRLLIQDTFL---HDPTELQPLEANLFAVSMLLyTERGNTYSVRDVREWLQRAGLTRSRVLHLKEGTGDW----------------------
+>UniRef100_A0A955WEZ4 110 0.248 2.783E-21 201 366 384 0 164 166
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GGGPGTYSALLAQANPELSAEVLDLPGVVAIAREIVGSMGVGDRVTCSPFDYYRDTLAGQYDAALISGVLHREQPAQVQAILANVARVVEPGGVLYISDVMLDDDRVGPV----FAAMFALNMRVLaHDGRCHSVAEQRAWLDEVGCKvtDVTHLPAPIHYTVIRAEKR-----------------
+>UniRef100_A0A931VA74 109 0.323 3.744E-21 198 364 384 0 166 167
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDIAAGSGVWSIAFAQEFRSATVTALDFPAVLKVARAYAGKFGVGHRFKYLSGDLRRLDFGkQQHDLIILGHICHSEGRANTIRLLRKSYAALRKDGQVLIADFLPNNRRTGPVMPLMFAL---NMLLNTTEGDVFSVAEYQKWLRAAGFKKIELLRsaPAPSPLILAAK-------------------
+>UniRef100_A0A2V6FEX5 108 0.250 9.108E-21 207 364 384 1 156 158
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------WGIVLAQKSPRVQVTAVDWAGMIPTTKRITQKFGVGDRFKFIEGDLLEADFGESYDIATLGHILHSEGEDRSRKLLKKTANALKSGGTIAIGEWLVNDERTEPLN----GLMFAVNMLVnTERGDTFSFNEIKRWLEEAGFKNVRTLEAPgPSPLVLATK-------------------
+>UniRef100_A0A7Y5FDI1 108 0.304 1.647E-20 195 354 384 21 177 189
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------KNMLDVGGGSAAFSMEIVKKNPSISAVVLDLPYVIPLTKKYVSGAGLSDKFNFIEGDYLTTELKDNYDLILLSAIVHINNYDQNKMLVKKCADVLNKSGMIIINDFVMNEDRT----QPRQSALFALNMLVgTENGDTYTEKEMREWFESAGLSKIERKNT-----------------------------
+>UniRef100_A0A938CHE3 107 0.308 2.976E-20 211 368 384 0 154 158
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------MLKRLPDATALIFDLPTVVAIARECAELAGVSDRVETRAGSYWDDELGEGFDLAIVSNILHSSGPEGCVTILQKTLRALAPGGRAVVHDFILGEDGTTPP----WAALFSLNMLNAGnEGRSYTRGELEEFAAEAGFEATEYRQCTEDTGVVVARKPVP---------------
+>UniRef100_A0A2V8H5I2 104 0.272 2.353E-19 215 364 384 0 149 151
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------NPQTRVTVVDLPGVVEtVTRRFVAREGLSERFAFWPGDLQQIDFGESaFDVIVLGHICHGEGAERTQELLHRAFRALRPGGQILIAEFVPDDDRNGP----LMPLLFALHMLVlTERGDTFTLGEFTEWLTTAGFVDIGTIAAPaPSPLIVATK-------------------
+>UniRef100_A0A3C1Z2Q6 104 0.270 2.353E-19 195 353 384 11 166 180
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RTMLDLGGGAGTNAIAFCRVYPGLSATVFDLATTLPLTTRTVKDAGLEDRIALKSGDFNRDALGGPYDVVLMSDILHYQNLATNAALVKKIHGHLSPGGRLVIKDRFLDPSGTSPAWTAAFAVHILVN---TEQGACYRTAEAMQWMHDGGYVSVEEIE------------------------------
+>UniRef100_A0A2S9FNB9 98 0.375 1.447E-17 195 306 384 0 111 126
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RSLLDIGGGHGWYSAQLCRRYPRLTATVFDLPGSAAIGREIIAGAGMADRVVHRDGDATTDDLGTGYDAVLCFNLLHHMTAEQTVHLFGRIHTALAPGGTLAVMDAFAEPGR-----------------------------------------------------------------------------
+>UniRef100_A0A7V9GF58 97 0.333 3.487E-17 197 365 384 2 169 171
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VLDLGAGACPWSRALAQADAEIRVTAVELPGVAAITRRSIADHGLGDRFRVVEGDLFRAEVGTGFDLVLIAGVCRLFGPTANALLARRAAALVRPGGEVAIVDALPDADRSDGRSNALYALGLA---LRTSTGGVHHLSAYASWLYDAGLAGIElvELDRPELSLVRATRP------------------
+>UniRef100_A0A7V9QC46 96 0.280 8.394E-17 192 361 384 2 161 168
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PHAREMLDLGGGPGTFARAFARR--GLHATLLDRPEVIELVAERYDLRSIPE-LSLQSGDFLDDSPAGEFDIILLANITHIYDPATNTRLIGSLVPQLRPGGVLAILDFVR--------GLSEFAPLFAITMlLNTEQGGTYALEEYTRWLEEAGLGEVRCTSIDLDRQLV----------------------
+>UniRef100_A0A7W0Q6F8 96 0.327 8.394E-17 194 366 384 39 209 639
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GGSLLDIGGGAGTYTAALLDAHPAATATLVDEAAVIALARQTLARFG--DRVTYVEGDAREVALGDRHAAVLLANVLHLHPPAVCAELCAIAAAAVMPGGQVIVKDLRVDVDHAGPLE----GLMFALNMAVyTDGGDVHDTVQLRNWLATAGLVDIiehRQEAAPDGIVVIGRRPR-----------------
+>UniRef100_UPI00227925D2 93 0.336 6.492E-16 42 151 384 8 113 129
+------------------------------------------KVLHSAVALGVFGALADGPADADQVAAATGLHERMAPDFLDALAGLGLLERTGD----RYGNSPLAEAYLVPGTATYLGGFVELTNETLYGTWGRLTEALTTGPRSTSTP----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A955TCB0 93 0.250 6.492E-16 215 353 384 0 135 145
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------HPGLSATVFDLPQTLRVTKDHVDRAGLGDRIHLQAGNFHVDAFQGSYDLALMSDILHYQDGSTNAALVKKVFACLTEGGRLIIKDRFLDPAKTSPAWTTAFAVHILVN---TECGECFTIQDSRQWMEQAGFRIVEELE------------------------------
+>UniRef100_A0A7X0B2G6 91 0.270 2.788E-15 172 299 384 59 195 245
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------WSSARGKIRAREIILDRIDWRGDERVIDIGCGRGLFTIAAARRVPRGHVIGIDIWQTEDLsgngpgaVIANAAREGVSGRVECRSADMRNIPFPdDSFDVVISSAAIHNLyDPADRARAIREIARVLAPDGRLVISD------------------------------------------------------------------------------------
+>UniRef100_A0A6I4PU19 90 0.333 8.927E-15 47 154 384 52 156 211
+-----------------------------------------------GVRAGVFARLADGPATRAELGEGLGLKPPALHDFLDALVALGLLER---RDGGRYANTAESDFYLVPGKRYYMGHYLTFVDNFMRPTWDGLAEMLRTGKPPAPQARRP-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A2V6KNM4 89 0.236 1.194E-14 223 365 384 1 141 142
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VDWAGMIPTTKRITRKFGVADRFQFIEGDLLEANFGNGYNIATLGHILHSEGEERSRQLLKKTFRALKSGGTIAIAEWLVNDERTEPLPSLMFAV---QMLVNTEKGDTFSFNEIKGWLEEARFKRVRKLEAPgPSPLILATKP------------------
+>UniRef100_UPI002021C978 88 0.282 3.813E-14 193 305 384 43 155 242
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPTRVLDLGGGPGVVAARMAERWPGTRVTLIDIDPVLL----TLARDGVPPSVTVLDADLGEPGWTEaagtGYDLVTAVMTVHYLRPESIRALYRHCRQAMSPGGLLVVADLIPDDN------------------------------------------------------------------------------
+>UniRef100_A0A933SQF1 87 0.295 9.097E-14 226 357 384 1 129 141
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPVLEAARQYIQRYGLEDRVGVRPGDFLTDDMGSGYDLVLLANVVHMYGAENSSALIKKSAAALASGGRIIIHGFCVDGDGTGPMEDVLFNLNIG---MLTDAGRAHPVEEITGWLERAGISRVRHFRIEGH--------------------------
+>UniRef100_A0A560H420 87 0.262 9.097E-14 172 299 384 59 195 245
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------WSSARGKIREREIILDRIDWRGDERVIDIGCGRGLFTVAAARRVPRGHVVGIDIWQVEDLsgngpgaVIANAAREGVSGRVECRSADMREIPFPdNSFDVALSSAAIHNLyEAADRARAIREIARVLAPDGRLVISD------------------------------------------------------------------------------------
+>UniRef100_A0A1Z4QIT9 86 0.264 1.215E-13 194 351 384 42 209 225
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PLKVLDLGAGTGLYSGMVQAVFPNAEFTLLDLaPEMLEKAKSRFSKMGKSPKI--LIGDYVETDLGGSYDLVISALSIHHLSDVDKKRLYQQVYHVLSPGGMFVNADQVLgkTPDLEKLYRQNWLDSVIAKGISQEDLKAAQKRMEYDrmtpldiqlAWLDAAGFQDVDC--------------------------------
+>UniRef100_A0A163GAS4 86 0.238 1.623E-13 193 299 384 49 155 223
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PGQRILDLGCGTGTLTVQLKQSYPESEVTGLDIDPdVLRMAEAKAAQRHLS--IKFDQGNSYELPYPDhSFDRVVTSLMFHHLTTTNKLQTLKEIFRVLKPEGELHIAD------------------------------------------------------------------------------------
+>UniRef100_UPI0021A33580 85 0.240 2.168E-13 194 299 384 50 155 223
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GQRILDLGCGTGTLTVQLKQSYPESKVTGLDIDPdVLRIAEAKAAQRHLD--IKFVQGNSYELPYPDhSFDRVVTSLMFHHLTTTNKLQTLKEIFRVLKPEGELHIAD------------------------------------------------------------------------------------
+>UniRef100_A0A1F8NFU5 85 0.286 3.866E-13 188 299 384 38 150 193
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LLPFEGTARIrgLDLGAGTGVLAEGILRRYPLAEVTVFDLSdNMLAAARERLRK--FENRITFLKGDFSKDEFGIGYDLILSGLSIHHLTNPHKQQLFRRIYLALNPGGVFLNRD------------------------------------------------------------------------------------
+>UniRef100_UPI0015517974 85 0.272 3.866E-13 192 299 384 45 152 218
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DSGQRVLDLGCGTGTLTLLLKQVYPKAEVTGLDIDPnVLRIAEKKAVDMGMD--IVFNQGMSFELPYPDhSFDRVVTSLMFHHLTLENKLRTLKEIFRVLKPQGELHIAD------------------------------------------------------------------------------------
+>UniRef100_A0A4R4UIL7 84 0.298 5.161E-13 190 300 384 76 189 236
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RLKGGERLLDLGCGRGAVLTLAARRLDEGHVTGLDqRSKGAPRARANAEREGVADRVSLVVGDLRDLPFeDGAFDVVVTDQAIHtITRPAGREQAVREALRVLRPGGLMLIADP-----------------------------------------------------------------------------------
+>UniRef100_UPI001942DF3D 84 0.285 5.161E-13 190 304 384 40 154 244
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------HGGAPARCLDLGGGPGVLAERMAARWPSCRVAMVDLDPVLL----TLARAGVPDTVAVIDADLGSGSWahcaGRGHDLITSVMTVHYLPPSGIRRLYRECRDALAPGGLLVVADLMPDD-------------------------------------------------------------------------------
+>UniRef100_UPI0018F5EC24 84 0.275 5.161E-13 193 304 384 47 158 264
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPRRVLDLGGGPGVLAERMARRWPSAAVSLLDLDPVLLALARSA----LPGRVSVLDGDLASAGWtalaGGGHDLITVVMTLHYLPAERARAVYAHARRCLAPGGVLIVADLMPDD-------------------------------------------------------------------------------
+>UniRef100_A0A6L4ZQ25 83 0.237 9.196E-13 213 364 384 0 151 152
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------KQYPEANVIALDFPGVLSVTRQYVQQLDAEAQYDYLSGDLNTLSFGKNcYDLVILGHICHSEGERRARKLIKKSAQALRDGGTLLIAEVLPNEDLSSP----LLAMLFSLNMLVfTSEGDVFPASQYQKWMAEVGLKEFEVLDkiPSPFPLLLATK-------------------
+>UniRef100_A0A941ADV1 83 0.338 9.196E-13 24 153 384 451 574 587
+------------------------PVRNVDDLMTVGHGYQRSMVLLAALRLGLFRALAGGAAVAGVLARRVGADAKKLSILLDALAALGLVEK----RGRRYRNAKPARDLLLPG-PHSKESILLHHLDG-WGEWGRLPSTIRAGRNPRAGAQG--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A388RG35 82 0.292 2.185E-12 20 108 384 18 102 108
+--------------------DPARPVLTPERLLQLGMGFWPAKTLLSAVELGVFTRLADGPLDAPTLTEALGLHPRSALDFLDALVALQVLER----DDGKYRNAPDTA-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A938ND12 82 0.369 2.915E-12 247 356 384 2 108 119
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------FLAGDLREGRFGEGYDLALVSAICHMLSPADNAALLGALRRALAPGALLVISDFILDENR----ATPSFASLFGINMLVgTGGGDSYAESDYRAWLAAAGFAEVRCLALPG---------------------------
+>UniRef100_A0A4R4NF76 81 0.298 3.888E-12 190 300 384 76 189 236
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RLEGGERLLDLGCGRGAVLTLAAERLPEGHATGLDQHaREAPRASANAEREGVADRVSLVVGDLRDLPFENdAFDVVVSDQAIHAIARRQgREQAVREALRVLRPGGLILIADP-----------------------------------------------------------------------------------
+>UniRef100_A0A7W0ZUZ6 81 0.269 3.888E-12 197 298 384 44 143 272
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ILDVGCGTGEITSRLAGEFPRATIVGVDiIEPHLALARTRYPE--LADRVTFREADAFELPFaAGSFDLVVCRHMLQAIPHPE--RVLAEMVRVAKPGGVLHII-------------------------------------------------------------------------------------
+>UniRef100_A0A533ZC60 79 0.313 2.185E-11 252 353 384 0 98 108
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------FHRDDLGGPYDAVLMSDILHYQDPDANAALVRKVHRALAPSGRLVIKDRFLDDGRTSPAWTAVFAVHILVN---TDKGRCYTMAEAVQWLKDAGFTSVDELD------------------------------
+>UniRef100_A0A968BQZ9 79 0.261 2.185E-11 195 299 384 20 124 172
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GRIIDVGCGSGATNLVLAERFPRAEIVGIDLSdPLLRLAREATANTSFGDRVAFERADVQQIPYdDDSFDVAISTNMVHIV--EHPLRMLGEIERILAPDGHLFIVD------------------------------------------------------------------------------------
+>UniRef100_A0A923PGD4 78 0.308 3.880E-11 182 300 384 44 160 217
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RALQQLPISPAANISILEVGCGTGHNLVALAKYFPNAYVTGIDLSEdMLAIAAKKVARFG--GRVTLEEGAFGVVPLEEKYDLIVFSYCLTMVNP-DWDKLLEVARKSLPPTGMLTVVDF-----------------------------------------------------------------------------------
+>UniRef100_A0A4R7V7R0 78 0.282 3.880E-11 198 305 384 50 165 260
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDLACGPGAISDRLLRRLPKARSVAVDVDPVL-LAIGQAALGDVAGRLRWVRADLRDQDWtdalgadgaDGTFDAVLSSTALHWLDPATLVATYRRAYRLLRPGGVLLNADYLPHPE------------------------------------------------------------------------------
+>UniRef100_A0A401FX31 78 0.261 5.170E-11 195 299 384 47 151 201
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GKILDVGCGFGAVAIELAKTFPDAEITGIDLgEPLLRLGESEARKAGVADRIHLLKGDVRKTEFPtDAYDVVTNTFMLHIV--ENPIAMLNEIERVTKPEGKIMITD------------------------------------------------------------------------------------
+>UniRef100_A0A957TIY1 78 0.280 5.170E-11 193 304 384 42 153 205
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPGRALDLGCGFGRACRYLAQH--GWQCDGVDFvEQAIVTARQRAADAGVADRITFHVGSVGELDFlQPPYDLAIDVGCFHAQPEAVCVQYAKHVARLLKPGGLFLLFAHLRDE-------------------------------------------------------------------------------
+>UniRef100_UPI00030C1C4B 77 0.264 9.176E-11 192 301 384 98 214 269
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DGKGKLLDIGCGSGAMSIKAAKKFPDVVVTGMDywgagWDYSKTLCESNAKIEGVAERITFQKGDAAKLDFsDGTFDAAISNFVFHEvMSQPDKFALVREALRVVKPGGYFVFEDIF----------------------------------------------------------------------------------
+>UniRef100_A0A532DC31 77 0.274 1.222E-10 241 352 384 0 109 128
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------MGHRLSYLPLDFIKHAIPGRYDVVWLSNVLHIYSPAENRKLLRNIARVLAPGGRLLIQEALLHDRHD---LAPLGANLFAVTMlLFTDRGNTYSVREATDWLMCSGFQRVSLL-------------------------------
+>UniRef100_A0A2W2GWE1 77 0.294 1.222E-10 196 306 384 79 186 243
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RALDVGCSEGAFTRRLARAYPEAECVGVDVSAQA-VARAAAKAAGTA---RFVALDFLNDDPGGIFDLVICAEVLYYVGRGERLRLiFERFRTFMAPGGVLVLVHEWPEARR-----------------------------------------------------------------------------
+>UniRef100_A0A1J5AV34 77 0.300 1.222E-10 194 303 384 34 144 256
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GGPLLELGCGTGRLAIPLAQA--GYQVTGVDLSPaMVTIARDKAARAGVTQRVTLIQGDYTDTPLGGPYRLafVVMNTFLHLLSQADQLAALRHWAAHLTAGGLLLIDVMYPD--------------------------------------------------------------------------------
+>UniRef100_A0A0F9AVU4 76 0.250 1.628E-10 240 365 384 0 122 133
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GKGVKLNYKGGDFTCDSIGNSYDLILVSQIYHAYSEAASLELTKKCHDALVPGGRIAVQEFAISKDRTSPPG----GALFSVNMLVgTEGGNTYHTSHISDWLKEAGFKQVKVKTLSETVLVTARKP------------------
+>UniRef100_X1UPS3 76 0.291 2.168E-10 232 365 384 0 132 138
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------AEKRLTEAGLSDRVTLVAGDFYEDDLPTGPDFTFLGAIAHQNSREQNRALFVKVHAALAEDGLIVIRDVVMDPSHTSP----QAGALFAINMLVaTPAGGTYTFDEYAEDLTNAGFTDITLVhrDEFMNSLIRAKKK------------------
+>UniRef100_A0A2P9H162 76 0.291 2.886E-10 178 296 384 28 142 265
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RTVAEKLAASLPAGEKFDSILEIGCGTGSLTELLRRRFPRALIYAVDVArPMIDLARERI---GECSRIHWHVADARQFRPGRDFALIISSSALHWMTP--VSETVKRLAGMLEPGGSLV---------------------------------------------------------------------------------------
+>UniRef100_UPI00035C64DE 75 0.293 3.842E-10 198 308 384 49 163 262
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDLACGPGTISARLLERFPKARCVAVDIDP-LLLAIGQGALSTMDGRLHWVDADITTDSWlqaigDEQFDVVLSATALHWLTPAQLITTYRDISAALRPGGLLLNADRLEFDERSP---------------------------------------------------------------------------
+>UniRef100_A0A960X0C7 74 0.300 9.061E-10 244 360 384 0 114 118
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVSFTGGNFFTDDLP-QADVLLFGHILHDWDLETKLMLLRKAYAALPPGGAVVVYDSIIDDER----KKNAFGLLMSLNMLIeTPGGFDYTGADCMGWMRQVGFQEccVEHLVGPDSMVI-----------------------
+>UniRef100_A0A5C7FQJ3 74 0.292 9.061E-10 196 300 384 53 155 213
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RILEVGCGTGHNLVSLAERFPAAEITGIDLSqDMLSIARKKLRRFG--GRVSIVHGAFGSDSFREQFDVVLFSYCLTMVNP-GWDTLIEVATASLRDNGVLVAVDF-----------------------------------------------------------------------------------
+>UniRef100_A0A938BIS1 74 0.289 9.061E-10 193 296 384 40 142 255
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPLNVLDLGCGDGRHSLEMARR--GYTVTGLDLSeELLTRARERADDAGLT--LVFRQGDMREIPYMQAFDLVvnFFTSFGYFATDTENARVLHAIARSLRPGGRFL---------------------------------------------------------------------------------------
+>UniRef100_A0A7W7G205 73 0.317 1.604E-09 249 353 384 133 235 254
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PGNFLTDtHLPDGHDVVLFSMILHDWDQATNRELLAKAYEALLPGGLVVVSELLLNAERTGPAP----AALMGLNMLVeTEGGRNYSDAEYGQRLTGAGFTEVRTVP------------------------------
+>UniRef100_UPI001CC7CCF7 73 0.245 1.604E-09 197 299 384 203 308 605
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ILDIGCSVGVSTRFLADKFPEAKATGLDLSPyFLSVAQYKERERALRKNpIKWFHANGENTGLPsKSFDLVSLAFVIHECPRRAIIGLVEEAFRLLRPGGIIVLTD------------------------------------------------------------------------------------
+>UniRef100_A0A2V6APU7 73 0.281 2.135E-09 258 365 384 0 105 106
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GEGYDIAILGHILHSEGEDRSRKLLKKTANALKPGGTIAIGEWLVNDERTEPLN----GLMFAVNMLVnTERGDTFSFNEIKRWLEEAGFKNARTLEAPgPSPLVLATKP------------------
+>UniRef100_A0A7Y5VEE5 72 0.301 2.840E-09 196 300 384 52 156 212
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RILEVGCGTGHNLRNLALHFPNARITGLDASaDMLAIARSRTRQFPERIQLVEKPYALGEEGFREQYDLVLFSYSLTMINP-QWEELLQQACKDLKPGGFIAVVDF-----------------------------------------------------------------------------------
+>UniRef100_A0A7Y4TBV2 72 0.333 3.777E-09 190 300 384 40 151 218
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DPRPSERILDVGCGTGTLAVLLKRRMPSCEIVGLDPDPqILELAREKARSAGVA--IDFRQGFARDANtlGGTGYDKVVSSLVFHQTPMAEKSVGLRSMAAAAKATGELHVADY-----------------------------------------------------------------------------------
+>UniRef100_UPI0021F385A6 72 0.321 5.024E-09 26 109 384 2 81 90
+--------------------------PGVFDVIDMMTGYQPAAALTAAARLGVFDVLADAPLPADAVAARLGTEPRATRALLDALTGLGLL----GTDDGGYTAAPVARR----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A3C0ICS3 72 0.299 5.024E-09 196 300 384 53 156 215
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SILDVGCGTGVNSAKMARLFPNAHITALDVSEdMLAQAAKRLKPFG--DKVSLVHQPYEKNPAhSERYDLIHFSYALTMINP-QWQDLLEQAQADLKPGGVIVVADF-----------------------------------------------------------------------------------
+>UniRef100_A0A2V5J6W1 71 0.245 6.681E-09 247 362 384 1 114 116
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------FIEGDLLKADFGEGHDIATLGHILHSEGEERSRKLLKKTANAVKSGGTIAIGEWLVNDERTEPLN----GLMFAVNMLVnTENGDTFSFNQIKRWLAEAGFKNARTLEAPgPSPLVLA---------------------
+>UniRef100_A0A2T5C4D6 70 0.271 1.181E-08 245 351 384 17 120 136
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------IKTYTGDYTKDDLPEGFDLVFLSAIIHSNPLETNQQLIKKCFKALNNKGQIIIQDWIMNDERTEPTTGAIFAI---NMLVGTDGGDCFTEQEVSDMLTTAGFKQIQR--------------------------------
+>UniRef100_UPI001F4F924D 70 0.305 1.181E-08 70 154 384 4 85 140
+----------------------------------------------------------------------LGLKPPALHDFLDALVALGLLER---RDGGRYANTAESDFYLVPGKRYYMGHYLTFVDNFMRPTWDGLAEMLRTGKPPAPQARRP-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7C6B0U6 70 0.312 1.570E-08 31 142 384 8 114 116
+-------------------------------VLELTDGFRVAAVVGAAAELGVFEAIPEQGITADELASRLACSIRGIQVLCDALAGLSLLEK----RDGTYFLPPKLRPVLRESGAETVIPMLQHRMNMM-RGWANLPWTVK-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7L5YJU2 70 0.288 1.570E-08 244 366 384 1 121 123
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVTVQEGDFQREGFGRGYHVALVFGVLNGEPPEGRPALIRKVYDCLEPGGKVVLRDFALDDDRAGQPE----AAIFAPDAAGDGVRGLDTRGDWTNWLTAAGFAPPQTlaLPDGVGTLTIAHKPT-----------------
+>UniRef100_A0A0F9D6E6 70 0.288 1.570E-08 205 364 384 3 151 155
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GTYSKELLKK--EFEVTLLDTQGLTEMAKDHLKDTS----VKILAGDFNERLPNEKFDVILLSNITHIYKPEKNEALLSRVEKHLSPGGLIAIVDLIRSKSKG--------AAMFGVNMLvHTAGGGTWTLPQYEKWLHHAGLRLISVKDLKDADqkLLLAER-------------------
+>UniRef100_A0A928TDB3 70 0.287 1.570E-08 194 300 384 45 143 201
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PHRILEVGCGTGHNLQLLRRQFPDADITGIDLSaDMLRVAREKVPGVSLIQRAY--------DAPAGSFDLIVCSYALSMFNP-GWDRAIATAAQDLVPGGIIAVVDF-----------------------------------------------------------------------------------
+>UniRef100_A0A535YYI5 70 0.297 2.088E-08 257 364 384 2 109 113
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LGGPYDLVVMSHVLHHFDEGRCVELLRRAAAATRDDGRIVIQDFVATGDEHG---RDVAAGLFSVIMLVwTRQGEAHPLARLERMLAAAGYGPPEVHPLPqlPTTVLVAGR-------------------
+>UniRef100_A0A552ZP34 70 0.283 2.088E-08 174 297 384 19 148 218
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PWVIDEPQPAVIALAEAGHISGRVLDVGCGTGEHTIYLTRA--GYDVLGVDgAPTAVDIARRNAAQRGVAAR--FAVGDAFELDAfeldaieggAQGYDTVLDSALFHVFDDADRVRYVRSLGRVTRPGGVVLV--------------------------------------------------------------------------------------
+>UniRef100_UPI00068E7623 70 0.274 2.088E-08 196 299 384 89 201 246
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------KVLDVGCGRGLLTILAARKVPLGDVTGVDIwsqeelsENSKEAAVENARLEQVSERIQFEDGDVRALGFrSHSFDKIVSSLCLHaIASRNDRNQAIANLIKLLKPGGEIAILD------------------------------------------------------------------------------------
+>UniRef100_A0A518C5C2 69 0.213 2.775E-08 188 297 384 115 250 293
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GRSKDESVRILDMGTGSGIIAVTIAKQAPQANVLATDVSEkAIVVAKQNAEKHGVSERVEFAAGDLFQAvPSGSSFDVIVSNppyiaqserplmdaHVIEHEPhgalfaDEEgtsvLRRILEEAASFLKPGGWLLL--------------------------------------------------------------------------------------
+>UniRef100_A0A124DZA2 69 0.323 3.688E-08 197 297 384 42 139 209
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VLDIGCGAGEHTILLARL--GYDVLGMDFaPRAIEQARANAAARGVDAR--FEVGDALRLAGTSTYRTVIDSALFHIFDDADRAAYVRSLHGVCRPGGLVYV--------------------------------------------------------------------------------------
+>UniRef100_UPI00232F94DE 68 0.283 4.901E-08 198 298 384 8 105 188
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDVGCGDGTLARLLASR--SGRVTGIDLSaEMVESARDQSKE---VENVRFLEADFLEAGgrelPPGHYDLITMVAVAHHLGTE---RALARSAELLAPGGRLAVI-------------------------------------------------------------------------------------
+>UniRef100_UPI000513702D 68 0.317 4.901E-08 195 297 384 40 139 209
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GRVLDIGCGTGEHTILLTRL--GYDVLGVDgAPTAVEQARRNAAAHGVDARFEVR--DALDLGTTPTFDTVVDSALFHVFDADDRARYVRSLRGVTRPGALVAV--------------------------------------------------------------------------------------
+>UniRef100_A0A7C4TB43 68 0.315 6.512E-08 199 271 384 1 73 77
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DLGGSHGLHSIRFCRRYPNLSATVFDLPQALEVARETIAAEEMGDRVAVQGGDFLADDVGTSYDVAFLFNILH----------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI001CEC1545 68 0.273 8.652E-08 198 298 384 27 124 207
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDVGCGDGTLARLLAAR--AWRVTGIDLSaEMIESAREQSEE---VENARFMEADFFEASggelPLGHYDLITMVAVAHHLGTE---RALARSAELLAPGGRLVVI-------------------------------------------------------------------------------------
+>UniRef100_UPI000C7EEDCD 68 0.310 8.652E-08 196 297 384 41 139 209
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------KVLDVGCGAGEHTIMLTRL--GYDVLGIDFaPHAVAQARENAAAKGVDAR--FEVADALRLGTEPGYQTVVDSALFHIFDDADRARYVRSLHTACRPGGVVHV--------------------------------------------------------------------------------------
+>UniRef100_UPI001FDABEAA 68 0.321 8.652E-08 191 295 384 76 184 224
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PADARRVVDLGCGTGVLAVTAARTLPEASVLALDVSRAAVLsATATAAANGVGDRVEVRRGHLLAGVPDADVDLVLCNPPFHRGNSRDSAvafEMLADAARALRPGGEL----------------------------------------------------------------------------------------
+>UniRef100_A0A7W1S605 68 0.279 8.652E-08 197 299 384 114 224 250
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VLDVGCGSGVLLFACLKESPTAKGTGIDIydpysfGGTAGVFWKNADVEGLKERVALQQVDARTMPFaGQRFDVIVSSLAMHHVgNAAEQEKATREMVRTLKPSGKIAICD------------------------------------------------------------------------------------
+>UniRef100_A9HJ62 68 0.323 8.652E-08 194 297 384 47 148 474
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PARILELGCGTGFLSAHLRRLFPDAILTVTDLaPEMVERARARLTPLGGDVRYAVVDAE-DPASVGTGFDLICSSLSMQWFTDPAA--TLDRLAARLAPGGMMAL--------------------------------------------------------------------------------------
+>UniRef100_UPI000B2ED3AD 67 0.317 1.149E-07 195 297 384 40 138 208
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GTVLDVGCGAGEHTILLTRL--GYDVLGVDYaPSAVEQARRNAEATGVDAR--FDVADAMDLG-EAGYDTIVDSALFHIFDETDRPRYVRSLHAACRPGGLVHV--------------------------------------------------------------------------------------
+>UniRef100_UPI001E2F3222 67 0.336 1.527E-07 195 297 384 40 139 209
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GKVLDVGCGTGEHTMLLTRL--GYDVLGIDFSPhAVAQATDNAARRGIDAR--FAVADAMQLGNGPRYDTILDSALFHIFDDADRPRYVASLHAACAPGGTVHV--------------------------------------------------------------------------------------
+>UniRef100_UPI0021B59A1B 67 0.305 1.527E-07 195 297 384 40 143 213
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GRVLDVGCGAGEHTILLTRL--GYDVLGIDFaPHAIEQARENATSKGVDAR--FDVADALALGSSElaepGYETIVDSALFHIFDDADRPRYVRSLHAACRPGGLVHV--------------------------------------------------------------------------------------
+>UniRef100_UPI00083A0E7F 66 0.307 2.028E-07 190 305 384 40 156 220
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DPQPGERILDVGCGTGSLAILLKSREPRCEVVGLDPDaEALVIARSKVIRLGLEIGFAQGFAREARDVCGTGFDKVVSSLLFHQVLPVEKRAGIKAMAAAARAAGEIHIADYAEQPD------------------------------------------------------------------------------
+>UniRef100_A0A951AB46 66 0.283 2.028E-07 174 297 384 97 242 283
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RPETEELVELLKAETGKWKPG-SILDVGTGSGVIALSLAKEFPEAKVFAVDVSEdSLVLARANGARLGLNERVQFQQGDLLE-GLGERFDLVvanlpyismsdrhlLSREVLHdpevslfagDHGDELIRKLIEQTPARLEPDGLLAL--------------------------------------------------------------------------------------
+>UniRef100_A0A9C9N7Y1 66 0.292 2.694E-07 195 303 384 35 145 256
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GRILDVACGTGMHAIELARR--GYEVTGSDLSaGMIERARVNAAQAGVKARFEAVSfGELAAALDGATFDALLClgNSLPHVLTNADLAAALFDFAACLRPGGLLLIQNRNFD--------------------------------------------------------------------------------
+>UniRef100_A0A957GH47 66 0.252 3.577E-07 258 365 384 0 107 111
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GAGYDLALVFGVLNGEPPEGRPALIHKVFAALNPGGQIVLRDAVLDSDRAGPSEAALFAL---QMLLATESGGLDTRADWAKWLGKAGFLPPKEIELPgpvGSTLTIARKP------------------
+>UniRef100_A0A7V6DNW2 66 0.269 3.577E-07 193 301 384 47 161 266
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GPAKVLDVGCGDGYATLNVARKLPEFTFAGIDYSaNMIRLAKERLAGLPsLTRRLTFKVGDVLDLGAACGetiFDAVISDRCLiNLADKADQEHAIKEIARHVAPGGYYLAVENF----------------------------------------------------------------------------------
+>UniRef100_A0A651H3F6 66 0.244 3.577E-07 193 297 384 112 242 284
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPVRILDLGTGGGALALALARRFPEAEVTGLDTSaEALDLAEENAVRNGLADRVRWIRSDWFAgLGQTAGFNLVVANppylteeewataepEVKDHDPRvalvaaddgcAELLRILQEAPARLAPGGRLYL--------------------------------------------------------------------------------------
+>UniRef100_A0A510TTZ0 65 0.285 6.308E-07 195 296 384 39 139 252
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GTLLDLGCGSGDFAVQMAQA--GWTVTGLDLsPEMLTLAEARAEQAGVD--VTWVQGDMRRLTGLGTFDAVTSFDdsLCYLPDLTAVQETLLAAAGVLVPGGYFF---------------------------------------------------------------------------------------
+>UniRef100_A0A937NK55 65 0.330 6.308E-07 196 298 384 36 139 260
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVLDAACGTGMHAIALAQQ--GYVAVGADLSaGMIQRAQDNAMAAGVDARFEVAGLGKLSARVGTGFDAVLClgNSLPHLLTPADLAAALADFAACLRPGGLLLIQ-------------------------------------------------------------------------------------
+>UniRef100_A0A3C1M550 64 0.244 8.374E-07 265 350 384 1 84 118
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------WYSNVLHIYSAEENQALFRRLYSALSPGGRLIIQDVFLHD--REGLYPEEASLFAVSMLLVTPAGNTYSFSETAEWLRAAGFVRIR---------------------------------
+>UniRef100_A0A497AHF1 64 0.285 8.374E-07 196 303 384 36 145 262
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVLDAACGTGMHAIALAQQ--GYEVTGTDLSaGMIERARVNADAANVDVRFEAAGfGELARRFAPGSFDALLClgNSLPHLLTSADLAAALADFAACLRPGGLLLIQNRNFD--------------------------------------------------------------------------------
+>UniRef100_UPI000489D375 64 0.317 1.112E-06 196 296 384 38 138 246
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVLDLGCGAGDLLLAVRRYRPRASLTGIDISPlniQAAVARAKADPYGLES-VTFEAADYLRAQF-ETFGVILAESVLHLIADDH-DGLAAKLAADLAPGGLLI---------------------------------------------------------------------------------------
+>UniRef100_A0A931W3L4 64 0.269 1.476E-06 182 267 384 153 241 347
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RSYAFACLDGKPPLKILDLGTGSGAIAVSLAKELPQARVCAVDISaAAIEVARLNARRHGVEERMEFFCGDLFEPvaEEREGFDLIVAN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI000C7DCED2 63 0.320 1.959E-06 196 297 384 41 139 209
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------KVLDIGCGTGEHTILLTRL--GYDVVGIDFsSNAIEQARANAADNGVDAR--FQVADAMNLAPDATYQTILDSALFHIFDQADRVRYVHSLHGALRRDGLVHV--------------------------------------------------------------------------------------
+>UniRef100_A0A960RG96 63 0.257 1.959E-06 194 296 384 89 216 261
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PLRIVDVGTGTGCIALALALSFPDAVVTGIDASEaALSLARENGLRLGLHDRIHWRHGDGLTGLSPGTVEVVVSNPPYISSDDyralpahirdyepqmalesgpsglEMLVRLCREASALLSPGGMLY---------------------------------------------------------------------------------------
+>UniRef100_A0A518AML2 63 0.280 2.600E-06 192 264 384 127 201 300
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DAPLRVLDIGTGSGIVAICLAKHLPKSQVTAVDLSPqAIEVAKRNAAKHKVDDRVAFVKGDAYqALPADAKYDFI-----------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI0021BC03F5 63 0.313 3.450E-06 230 315 384 1 86 106
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ELSRHNAENAGLSDGYQTIAGSAFDVDWGTGYDLVLLPNFLHYFDLPTCAQLLSKIVASLAEDGRIVAVDFVPNEDGVSPPFPEAF--------------------------------------------------------------------
+>UniRef100_UPI002111CB12 63 0.298 3.450E-06 196 296 384 16 116 224
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RILDLGCGAGDLLLALRRDRPRAILTGVDISPlniQAAVTRAKADPNGHGD-LRFEASDYLQARFD-GFDVILAESVLHLI-VGDHDRLAAKLATDLAPGGVII---------------------------------------------------------------------------------------
+>UniRef100_UPI000314BD16 62 0.319 4.578E-06 192 285 384 42 134 207
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPFGTALDLGCGTGRHAIELARR--GWQVTGVDIvPKAIRLATRRARAAGVDAR--FLKGDITALPaeVGTGYRLILDFGAFHGLTDPERHTMGRQV--------------------------------------------------------------------------------------------------
+>UniRef100_A0A1F8RJ87 62 0.278 4.578E-06 193 303 384 33 145 256
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GAQRVLDAACGTGWHAIALAQR--GFDVAGGDLSaSMVARATANAREAEVTAEFRQAGfGDLASAFGRDSFDAVLClgNSLPHVLDPAHLTRTLEDFAACLRPGGMLIVQNRNFD--------------------------------------------------------------------------------
+>UniRef100_A0A8S9DR76 62 0.245 4.578E-06 194 303 384 39 150 257
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PKKVLDSACGTGMHALALAKL--GFEVVGADFSgEMIAKARSNSVEIGLKARFEVIGfGSLAKNLGAGQFDAVLClgNSLPHLHTQNEVDETLKDFASCLRPGGLLLIQNRNFD--------------------------------------------------------------------------------
+>UniRef100_A0A7Y9JCX3 62 0.308 6.074E-06 192 296 384 51 149 210
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------EPPASVLDAGCGTGRIAVRLAER--GFDVVGLDVDaAMLEVARDEAPD------LDWRHADLASFDLGRRFDVVlLAGNIVPLLEPGTLPAVAERLAAHVAPGGRVV---------------------------------------------------------------------------------------
+>UniRef100_A0A357HZI5 62 0.259 6.074E-06 189 267 384 107 187 279
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RIETAPQRILDLGTGSGALALALATQYPEAQVVAVDQStAALELARENASALELNERIQFLAGSWWTPVMSESpFDLIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A2D7ZK05 61 0.250 1.069E-05 196 297 384 116 242 289
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SILDLGVGSGAILLAILAERPNAKGLGVDVSEVaLAVARDNAAHLGLAGRCALLRGDWADGLSDAGFDIVtanppyIASEVIETLEPEvrvheprlaldggadgldAYRRLAPEILRVLKPGGRFAV--------------------------------------------------------------------------------------
+>UniRef100_A0A1G6ZTE5 61 0.339 1.418E-05 192 296 384 46 144 207
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DPPARVLDAGCGTGRIAVRLTEL--GYDVVGVDVD-ASMLAVARAEAPGLD----WREADLATLDLGETFDLVlLAGNIVPLLEPGTLAAVAERLAAHTAPGGRVV---------------------------------------------------------------------------------------
+>UniRef100_A0A944PRW7 60 0.292 2.495E-05 198 298 384 25 122 216
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDVGCGEGALVRTLARQVEA--VTGVDCStEMVRLARER--SLGV-PNVTFAEADFLDGSHGlltqGGYDFISAVAVIHHV---RFAEAIRAMVRLLAPGGRLVIV-------------------------------------------------------------------------------------
+>UniRef100_UPI00094622D5 60 0.286 2.495E-05 193 303 384 36 148 255
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GGFRVLDSACGTGMHTIELARR--GYAAAGADLSaKMIERARENAVSAGVAARFETAGfGQLQAAFGSEAFDVLLClgNSLPHVLSAAELAAALEDFAACLRPGGLLLVQNRNFD--------------------------------------------------------------------------------
+>UniRef100_A0A7W1SHZ7 59 0.329 3.310E-05 47 123 384 37 113 175
+-----------------------------------------------GDELGYYRALAEhGPTTPPELAERTGTDEHYAREWLNAQAAGSYV--TYDADSGRYTLPPEQAIALTDEtSPAFVVGLF--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A8I0C7G1 59 0.250 3.310E-05 196 297 384 116 242 289
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SMLDLGVGSGAIILAILAERPNAKGLGIDVSeEALAVARDNAAHLGLGGRVALLRGDWTAGLSDESFDLVtanppyIATDVIETLEPEvrvheprlaldggpdgldAYRRLAPEILRVLKPGGLFFV--------------------------------------------------------------------------------------
+>UniRef100_A0A3M1L038 59 0.298 3.310E-05 193 267 384 118 194 290
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GARRILDVGTGCGAVAVALAVELPGAEIVATDVSEaVLEVAPANAERHGVSDRIEFRCGSLLEPlAAGERFDLIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A413QXR5 58 0.250 7.717E-05 192 276 384 114 205 286
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DSQAHILDIGTGSGAILLSVLHERPQCRGLGVDISqQALDVARKNGERLGLSDRVSWKISDLLASVPPAAYDWVVSNppyltaDDMHHLQPE-----------------------------------------------------------------------------------------------------------
+>UniRef100_UPI001F5CAA19 58 0.252 7.717E-05 197 297 384 117 242 289
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VLDLGVGSGAILLAILAERPAAKGLGVDVSeEALAVARENAANLGLGGRVALLRGDWTAGLSDDSFDLVVSNppyiatDVIETLEPEvgvheprlaldggldglDAYRILaPEILRVLKPGGTFAV--------------------------------------------------------------------------------------
+>UniRef100_UPI0020C9A714 58 0.244 7.717E-05 196 297 384 116 242 289
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SILDLGVGSGAIILSILAERPAAKGLGVDVSaEALAVARENAANLGMASRLALLRGDWTSGLGDASFDVVVSNppyiatDVLETLEPEvkdheprvaldggpdgldHYRRLAPEILRVLKPGGMFAV--------------------------------------------------------------------------------------
+>UniRef100_A0A6I3C5F3 57 0.275 1.357E-04 194 296 384 51 156 212
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PGRVLDAGCGTGRVARELARR--GVAVVGVDVDAvMLATARAKAPELQWIEH-DLASLDLARDptvnPPSSSFDVaVLAGNVMIFVTPGTEAAVLTRLASHVSPGGYVI---------------------------------------------------------------------------------------
+>UniRef100_A0A1V1PS24 57 0.333 1.357E-04 191 267 384 107 183 279
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PDDAARILDLGVGSGAILLAALRERPNAVGVGVDLSEaALEIAQANAEALGLRERVRLVQGD-WGAGLAEAFDVVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_B4RGB7 57 0.248 1.357E-04 194 297 384 114 242 287
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PWSVLDLGVGSGAILLAILAERPAAKGLGIDASeEALAVARDNAAALGLAGRTALLRGDWTAGLGDSAFDLVVSnppyiaSDVLETLEPEvkdyeprlaleggadglDAYRILaPEIVRVLKPGGRFAV--------------------------------------------------------------------------------------
+>UniRef100_A0A1F9AMF8 57 0.325 1.357E-04 176 251 384 108 187 305
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LVEECLRLLRELSARQGPSAgrlRVLDLGTGCGTIALALAHAFPEAHYLATDLSaEALTLARENAERLGLSRRVTFRQGD------------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A2P2E901 57 0.333 1.357E-04 194 267 384 146 219 315
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PNRILDLGVGSGAILLALLAERPSWTGVGVDQSeEALELARENAALHGLSARLDLRQGD-WHHGIDERFDIVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI001653A8F6 57 0.250 1.798E-04 196 297 384 115 241 291
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------TMLDLGVGSGAILLAVLAERPAAKGLGVDVSeEALAVARENAANLGLADRAAFLRGDWTAGLGDESFDLVVSnppyirSAEIETLDPEvrdheprlaldggpdglDAYRLLaPEIMRVLKPGGVFAV--------------------------------------------------------------------------------------
+>UniRef100_A0A2V9FTA7 57 0.313 2.384E-04 265 364 384 1 99 100
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LLTNIFHRFDMATSEKLMRRVHAALKAGGKAITLEFVPNEDRITPPMAAAFSLTM---LAGTDSGDAYTFSQYEKMFRNAGFARTTEHAVPesPQQLLLLEK-------------------
+>UniRef100_UPI0022A7AF9B 57 0.307 2.384E-04 185 285 384 6 105 173
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------REQARAQAPWGAALDLGCGTGVHAVELARR--GWQVTGVDIvRKAIRRATKRARAAGVD--VRFLEGDITALPaqVGTGYRLILDFGAFHGLTDPERHALGRQV--------------------------------------------------------------------------------------------------
+>UniRef100_A0A348NQJ2 57 0.198 2.384E-04 194 297 384 112 242 282
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PKMILDLGTGSGALALAFANKYPEASVDAVDVSaEALSLAQENALALGLDNRVTFHEGSWWCPLGLGKqhYDLIVSNppyltneemttaepEVVDHEPHSalvsgadglgDMRLIFKDAASHMKPGGLLAL--------------------------------------------------------------------------------------
+>UniRef100_A0A1F4Q3A2 57 0.308 2.384E-04 192 267 384 113 193 291
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PQPPVILDLCTGTGAVAVALARELPAARIIATDISrRALRMARTNAERHGVADRVTFLRGDLWRAldghAPANGVDLVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A2T4Z4G2 57 0.338 2.384E-04 192 255 384 118 182 303
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DRPLQVADLGCGSGAVAVTLAAERPHWQVIAVDLSPhALALARRNAEIHGVAERIQFRRGDWLQP--------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI0021F9D48E 56 0.267 3.160E-04 199 267 384 117 187 284
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DLGTGSGILAVTLCVLFPGATGVAVDISPaALEVAKSNAQRHGVSGRIEFQHSDFTEQKFdPESFELVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7X7AW76 56 0.298 4.188E-04 196 297 384 69 168 283
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVLDLACGPGLYTSRLARLGH--TCVGIDYSPaSIAHAEAEAEREDLACRYRLE--DLRSADYGSGFGLAmLLFGEFNAFRPVDARRILNTAHAALSEGGILLL--------------------------------------------------------------------------------------
+>UniRef100_A0A9E0MQ12 56 0.317 4.188E-04 184 267 384 107 191 289
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDAVGGDRAAARRGLDLCTGSGVLAITLAHELPGLTMIATDVSaPAAAIARANAQRNRVEDRVEVRVGDRFAPVAGERFDVIVAN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7X9FKT3 55 0.297 5.550E-04 180 252 384 103 175 286
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LIDMVREFFPE-DGRERFADLGTGSGCLAVTLAVRFPGWSGVAVDASPaALAVARENAARHGVSERIEFVPGDF-----------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI001CD91C63 55 0.322 5.550E-04 176 267 384 104 194 288
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LVDTLLPALREAV-SQKGSARILDLGTGTGAICLALLKECPDATGIGSDISaGALETAAKNASRNGLETRFEIRQSDWFE-KISGSFDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_L0NLB7 55 0.315 7.356E-04 193 267 384 116 190 289
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GEARILDLGTGTGAIVLALLKECPQATGVGTDLSEaALQTARENAARLGLAGRFETIRSNWLE-EVTGRFDIVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A2T0LHZ5 55 0.316 7.356E-04 192 267 384 123 201 302
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DRPLSVCDVGTGSGALAVTLAAERPRWSVWATDISPaALEVARDNARRNGVEGRIRFVRGEWLNPlrHRGVRVDVVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A4P6H758 55 0.287 9.748E-04 184 267 384 107 192 290
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LPLLAEIGERKGRchVLDLGTGTGAIALALLAATPQARAVGVDISEdALTTAARNARDLGLSERFSAVRSDWFE-AISGRFDVIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A371XFG7 55 0.315 9.748E-04 196 267 384 121 192 294
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RILDLGTGTGAIALALLHEAPKAEAVGVDISdDALETANENARRLGLGNRFSTVKSSWFE-KIEGRFDVIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI001FE0A581 54 0.293 1.292E-03 192 295 384 84 190 230
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DRPAEVLDLGCGNGVIAASVARRFGEAvRVAATDVSWlASDSARLTAAASGVE--VAVSQADGLESVADASLDLILTNPPFHRGTARDsapTLRMLAEAARVLRPGGQL----------------------------------------------------------------------------------------
+>UniRef100_A0A929D867 54 0.290 1.292E-03 193 303 384 33 161 269
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------EAHRVLDAACGTGMHSVALAQQ--GYAMTGTDLNaGMVERARANATAAGpVLSRAEGVDVQFEVAGFGElartltpalslpgrgsSFDAVLClgNSLPHLLTPASLAAALADFAACLRPGGLLLIQNRNFD--------------------------------------------------------------------------------
+>UniRef100_A0A546XUT9 54 0.329 1.292E-03 190 267 384 117 194 288
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SQKGSARVLDLGTGTGAICLALLKECPGATGIGSDISaDALETAAKNASRNGLETRFEIRQSDWFE-KISGRFDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A546XBX9 54 0.303 1.711E-03 190 267 384 117 194 288
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SHKGSARILDLGTGTGAICLALLKECPDATGIGSDIStGALETAAKNASRNGLETRFEIMQSDWFE-KISGRFDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7X0SNC9 53 0.315 2.267E-03 184 255 384 114 186 304
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LLEAERWQGAELTALDVGTGSGALAVTLAAERPAWRVVASDLSPdALEVARGNARANGVEPRIAFVQGDLLEP--------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI0004786EF2 53 0.304 2.267E-03 188 255 384 119 187 306
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GWGGGGALTALDVGTGSGALAVTLAAERPAWRVVASDLSPdALEVARGNARANGVAERVTFVRGDLLEP--------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A6N7IC32 53 0.301 3.004E-03 193 264 384 32 101 124
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PPGQALDVGCGEGADALWLARR--GWQVTAVDISRVA-LQRAATTGTSLAGRVAWTCADLTATPPPaGAFDLV-----------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A9D7HE84 53 0.324 3.004E-03 192 267 384 113 188 283
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------DAPHRILDLGTGSGAIMLALLKERPNATGVAIDISeEALAVVRANAEQLGVAERLQAGQGN-WAEHIDERFDLVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7V8DQQ7 53 0.303 3.004E-03 190 267 384 108 185 283
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SPESAQTVLDLGTGTGAILLALLAERPNWTGLGIDIsSEALDLARENAKMHSLSERAHFQIGN-WAENITEKFNIVTSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A5B8B2L1 53 0.295 3.004E-03 174 267 384 95 191 286
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RPDTETLVEAVLPFVRRavqGKGACSILDLGTGTGAIALALLSAAPQAVATGVDISaDALATAARNAADLGLDGRFRTLQSDWFE-KISGRYDAIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI000F83F048 53 0.313 3.004E-03 188 267 384 118 200 298
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LWPDTAALDVLDIGTGSGAIALTLAAERPRWRVTTVDLSPtALAIARENAQRLQVEYRVRFLEGDLAQPllAAGEQVDLLVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A0D8KML8 53 0.328 3.980E-03 193 267 384 121 195 289
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GSASVLDLGTGTGAICLALLKECPEATGIGSDISaDALETAAKNAARNGLASRFETVRSDWFK-KISGSFDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7C2EBL1 53 0.302 3.980E-03 197 267 384 145 220 320
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ILDLCTGTGAIAIALARELPAARLIATDISrRALRIARTNAEAHGVADRVRFLRGDLWRAlygvMPGRQADLIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A0B1TB99 52 0.310 5.272E-03 198 280 384 3 89 118
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------LDVGCGKGVHSALLANKFPKSNFTGIDvvMDAIQLANQQRKENGDSYENLKFEQMNgaKLDDNWSDKYDLVTIFFAAHDQTRPDLVR-------------------------------------------------------------------------------------------------------
+>UniRef100_UPI00227CA78F 52 0.369 5.272E-03 196 267 384 130 201 299
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RVIDLGTGTGAIGLTLLCELPQAEGTGTDISqDALATARRNAQRLGVSDRFRTICGNWF-DAVEGEYDLVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI00230011E8 52 0.293 5.272E-03 196 298 384 272 380 417
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SIVDLGCGNGTISSYVPLKFKEFVGTMIatDSSRDAVAAtAETAKRNGVDSRVDVIRDDAMSTFAPASQDLILLNPPFHVgntVDPQIAPKLFRASARVLTQGGELWCV-------------------------------------------------------------------------------------
+>UniRef100_A0A257JKW7 52 0.320 6.983E-03 194 267 384 117 190 286
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PNRVLDLGVGSGTILLALLAERKSWTGVGIDLSqDALALATENAAHVDLTDRVEFRLGD-WHQGLDERFDIVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A257HLL7 52 0.320 6.983E-03 194 267 384 117 190 286
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PNRVLDLGTGSGAILLALLAERKSWSGVGIDRSeEALALAADNAALHGLSDRVDLRLGD-WHQGLDEQFDIVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A0N1A6W0 52 0.342 6.983E-03 196 267 384 122 193 290
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SILDLGTGTGAIALALLHECGQAQAVGVDISEdALSTAARNAERLGLASRFETRAGPWF-VRVPERFDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A8D5UIE1 52 0.285 6.983E-03 194 255 384 118 180 297
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PIRVVDIGTGSGAIAVTLACERPHWEVWAIDLsPEALATAQTNAEIHGVRNRIVWRQGDLLEP--------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_H0UAL5 52 0.292 6.983E-03 189 267 384 120 201 299
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------WSAEQPLSVVDFGTGSGAITLTLAAEKPNWQLTTVDISlDAIAIAKQNAGRLDVEKRVRFIQGDLVEPilETGERVDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A4Q4CP92 51 0.307 9.249E-03 277 353 384 2 75 88
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------QKVALLRRAHAALPEGGALIVYDGMIDDDR----RENAFGLLMSLNMLIeTPGGFDYTGADCRGWMLQAGFREARVQP------------------------------
+>UniRef100_UPI001AE806D2 51 0.302 9.249E-03 193 267 384 116 190 287
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GEARLLDLGTGTGAIILALLKESPETQGIGSDISEdALKTAAENAARLGLSERFEAIRSDWFE-NISGRFDIIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI001FF66CFB 51 0.263 9.249E-03 193 276 384 116 205 287
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GSARILDLGTGTGAIVLALLKESPQAQGIGSDISEdALQTASRNAARLGMSERFQAIRSDWF-DAISGRFDIIvsnppyICSGVIPALDPE-----------------------------------------------------------------------------------------------------------
+>UniRef100_A0A1I2LK27 51 0.320 9.249E-03 193 267 384 120 197 297
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RPLSVCDVGTGSGALAVTLAAERPRWVVFATDIsSAALAVARENARRNGVEERIRFLRGKWLEPlrQGGDRVDVVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI00115A3FED 51 0.315 1.225E-02 196 267 384 119 190 289
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RILDLGTGTGAILLALLKECPEATGLGADISaDALQTAQANAAALGLQDRFEAVRSDWF-QNIGQRFDMIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A502III3 51 0.292 1.225E-02 189 267 384 120 201 299
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------WDTEQTLSVVDFGTGSGAITLTLAAEKPNWQLTTVDISlDAIAIATKNAERLGVRDRVRFLQGDLVEPMLiaGERVDILISN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A938SZY6 51 0.316 1.225E-02 197 255 384 128 187 299
+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------IVDIGTGSGCIAVALARALPTAVVYATDRSaGALQMARANAARQGVEDRIRFFAGDLFEP--------------------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7C4LCS8 51 0.329 1.225E-02 180 267 384 119 209 302
+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------VDEALKFLASTERAEAKVLDVGTGSGCIAVTLAVRRPRAAVTALDIaEDALDVARLNAERHGVAGRVAFFRSDLLEGlrLLRPGFDLVCAN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A920LUD8 51 0.315 1.622E-02 196 267 384 46 117 197
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------SCLDLGTGSGCLLLSLLSALPKTSGIGVDLaPLAVSQARANAAQLGLADRAQFICSDWFE-GVEGSFDLVLAN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A949NMS4 51 0.301 1.622E-02 196 267 384 109 180 278
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------RIVDLGIGTGAIGLALLAECPEAQCLGVDVSaEAVAIALENARSLGLSARYSAVTGDWLS-GIEARFDLIVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_UPI000A19732A 51 0.333 1.622E-02 194 267 384 117 190 286
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PNRVLDLGTGSGAILLALLSERKSWTGVGIDRSeEALALAAENAALHGLSERVELRLGN-WHQGVDEEFDIVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A7U4XWP3 51 0.333 1.622E-02 194 267 384 117 190 286
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------PNRVLDLGVGSGTILLALLAERKSWTGVGIDRSeEALSLAGENASLHGLTDRVDLRLGD-WHQGLDEQFDIVVSN--------------------------------------------------------------------------------------------------------------------
+>UniRef100_A0A068SLK7 51 0.302 1.622E-02 193 267 384 116 190 287
+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------GSARILDLGTGTGAIVLALLKESPQAQGIGSDISEdALQTALRNAARLGMSERFQAIRSDWF-DAISGRFDIIVSN--------------------------------------------------------------------------------------------------------------------
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..8a176ea
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,76 @@
+[build-system]
+requires = ["setuptools>=64", "setuptools-scm>=8"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "boltz"
+requires-python = ">=3.9"
+dynamic = ["version"]
+dependencies = [
+ "torch>=2.2",
+ "numpy==1.26.3",
+ "hydra-core==1.3.2",
+ "pytorch-lightning==2.4.0",
+ "rdkit==2024.3.6",
+ "dm-tree==0.1.8",
+ "requests==2.32.3",
+ "pandas==2.2.3",
+ "types-requests",
+ "einops==0.8.0",
+ "einx==0.3.0",
+ "fairscale==0.4.13",
+ "mashumaro==3.14",
+ "modelcif==1.2",
+ "wandb==0.18.7",
+ "click==8.1.7",
+ "pyyaml==6.0.2",
+ "biopython==1.84",
+ "scipy==1.13.1",
+]
+
+[project.scripts]
+boltz = "boltz.main:cli"
+
+[project.optional-dependencies]
+lint = ["ruff"]
+
+[tool.ruff]
+src = ["src"]
+extend-exclude = ["conf.py"]
+target-version = "py39"
+lint.select = ["ALL"]
+lint.ignore = [
+ "COM812", # Conflicts with the formatter
+ "ISC001", # Conflicts with the formatter
+ "ANN101", # "missing-type-self"
+ "RET504", # Unnecessary assignment to `x` before `return` statementRuff
+ "S101", # Use of `assert` detected
+ "D100", # Missing docstring in public module
+ "D104", # Missing docstring in public package
+ "PT001", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
+ "PT004", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
+ "PT005", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
+ "PT023", # https://github.com/astral-sh/ruff/issues/8796#issuecomment-1825907715
+ "FBT001",
+ "FBT002",
+ "PLR0913", # Too many arguments to init (> 5)
+]
+
+[tool.ruff.lint.per-file-ignores]
+"**/__init__.py" = [
+ "F401", # Imported but unused
+ "F403", # Wildcard imports
+]
+"docs/**" = [
+ "INP001", # Requires __init__.py but folder is not a package.
+]
+"scripts/**" = [
+ "INP001", # Requires __init__.py but folder is not a package.
+]
+
+[tool.ruff.lint.pyupgrade]
+# Preserve types, even if a file imports `from __future__ import annotations`(https://github.com/astral-sh/ruff/issues/5434)
+keep-runtime-typing = true
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
diff --git a/scripts/train/assets/casp15_ids.txt b/scripts/train/assets/casp15_ids.txt
new file mode 100644
index 0000000..b3464dd
--- /dev/null
+++ b/scripts/train/assets/casp15_ids.txt
@@ -0,0 +1,72 @@
+T1112
+T1118v1
+T1154
+T1137s1
+T1188
+T1157s1
+T1137s6
+R1117
+H1106
+T1106s2
+R1149
+T1158
+T1137s2
+T1145
+T1121
+T1123
+T1113
+R1156
+T1114s1
+T1183
+R1107
+T1137s7
+T1124
+T1178
+T1147
+R1128
+T1161
+R1108
+T1194
+T1185s2
+T1176
+T1158v3
+T1137s4
+T1160
+T1120
+H1185
+T1134s1
+T1119
+H1151
+T1137s8
+T1133
+T1187
+H1157
+T1122
+T1104
+T1158v2
+T1137s5
+T1129s2
+T1174
+T1157s2
+T1155
+T1158v4
+T1152
+T1137s9
+T1134s2
+T1125
+R1116
+H1134
+R1136
+T1159
+T1137s3
+T1185s1
+T1179
+T1106s1
+T1132
+T1185s4
+T1114s3
+T1114s2
+T1151s2
+T1158v1
+R1117v2
+T1173
diff --git a/scripts/train/assets/test_ids.txt b/scripts/train/assets/test_ids.txt
new file mode 100644
index 0000000..2bb230d
--- /dev/null
+++ b/scripts/train/assets/test_ids.txt
@@ -0,0 +1,550 @@
+8BZ4
+8URN
+7U71
+7Z64
+7Y3Z
+8SOT
+8GH8
+8IIB
+7U08
+8EB5
+8G49
+8K7Y
+7QQD
+8EIL
+8JQE
+8V1K
+7ZRZ
+7YN2
+8D40
+8RXO
+8SXS
+7UDL
+8ADD
+7Z3I
+7YUK
+7XWY
+8F9Y
+8WO7
+8C27
+8I3J
+8HVC
+8SXU
+8K1I
+8FTV
+8ERC
+8DVQ
+8DTQ
+8J12
+8D0P
+8POG
+8HN0
+7QPK
+8AGR
+8GXR
+8K7X
+8BL6
+8HAW
+8SRO
+8HHM
+8C26
+7SPQ
+8SME
+7XGV
+8GTY
+8Q42
+8BRY
+8HDV
+8B3Z
+7XNJ
+8EEL
+8IOI
+8Q70
+8Y4U
+8ANT
+8IUB
+8D49
+8CPQ
+8BAT
+8E2B
+8IWP
+8IJT
+7Y01
+8CJG
+8HML
+8WU2
+8VRM
+8J1J
+8DAJ
+8SUT
+8PTJ
+8IVZ
+8SDZ
+7YDQ
+8JU7
+8K34
+8B6Q
+8F7N
+8IBZ
+7WOI
+8R7D
+8T65
+8IQC
+8SIU
+8QK8
+8HIG
+7Y43
+8IN8
+8IBW
+8GOY
+7ZAO
+8J9G
+7ZCA
+8HIO
+8EFZ
+8IQ8
+8OQ0
+8HHL
+7XMW
+8GI1
+8AYR
+7ZCB
+8BRD
+8IN6
+8I3F
+8HIU
+8ER5
+8WIL
+7YPR
+8UA2
+8BW6
+8IL8
+8J3R
+8K1F
+8OHI
+8WCT
+8AN0
+8BDQ
+7FCT
+8J69
+8HTX
+8PE3
+8K5U
+8AXT
+8PSO
+8JHR
+8GY0
+8QCW
+8K3D
+8P6J
+8J0Q
+7XS3
+8DHJ
+8EIN
+7WKP
+8GAQ
+7WRN
+8AHD
+7SC4
+8B3E
+8AAS
+8UZ8
+8Q1K
+8K5K
+8B45
+8PT7
+7ZPN
+8UQ9
+8TJG
+8TN8
+8B2E
+7XFZ
+8FW7
+8B3W
+7T4W
+8SVA
+7YL4
+8GLD
+8OEI
+8GMX
+8OWF
+8FNR
+8IRQ
+8JDG
+7UXA
+8TKA
+7YH1
+8HUZ
+8TA2
+8E5D
+7YUN
+7UOI
+7WMY
+8AA9
+8ISZ
+8EXA
+8E7F
+8B2S
+8TP8
+8GSY
+7XRX
+8SY3
+8CIL
+8WBR
+7XF1
+7YPO
+8AXF
+7QNL
+8OYY
+7R1N
+8H5S
+8B6U
+8IBX
+8Q43
+8OW8
+7XSG
+8U0M
+8IOO
+8HR5
+8BVK
+8P0C
+7TL6
+8J48
+8S0U
+8K8A
+8G53
+7XYO
+8POF
+8U1K
+8HF2
+8K4L
+8JAH
+8KGZ
+8BNB
+7UG2
+8A0A
+8Q3Z
+8XBI
+8JNM
+8GPS
+8K1R
+8Q66
+7YLQ
+7YNX
+8IMD
+7Y8H
+8OXU
+8BVE
+8B4E
+8V14
+7R5I
+8IR2
+8UK7
+8EBB
+7XCC
+8AEP
+7YDW
+8XX9
+7VS6
+8K3F
+8CQM
+7XH4
+8BH9
+7VXT
+8SM9
+8HGU
+8PSQ
+8SSU
+8VXA
+8GSX
+8GHZ
+8BJ3
+8C9V
+8T66
+7XPC
+8RH3
+8CMQ
+8AGG
+8ERM
+8P6M
+8BUX
+7S2J
+8G32
+8AXJ
+8CID
+8CPK
+8P5Q
+8HP8
+7YUJ
+8PT2
+7YK3
+7YYG
+8ABV
+7XL7
+7YLZ
+8JWS
+8IW5
+8SM6
+8BBZ
+8EOV
+8PXC
+7UWV
+8A9N
+7YH5
+8DEO
+7X2X
+8W7P
+8B5W
+8CIH
+8RB4
+8HLG
+8J8H
+8UA5
+7YKM
+8S9W
+7YPD
+8GA6
+7YPQ
+8X7X
+8HI8
+8H7A
+8C4D
+8XAT
+8W8S
+8HM4
+8H3Z
+7W91
+8GPP
+8TNM
+7YSI
+8OML
+8BBR
+7YOJ
+8JZX
+8I3X
+8AU6
+8ITO
+7SFY
+8B6P
+7Y8S
+8ESL
+8DSP
+8CLZ
+8F72
+8QLD
+8K86
+8G8E
+8QDO
+8ANU
+8PT6
+8F5D
+8DQ6
+8IFK
+8OJN
+8SSC
+7QRR
+8E55
+7TPU
+7UQU
+8HFP
+7XGT
+8A39
+8CB2
+8ACR
+8G5S
+7TZL
+8T4R
+8H18
+7UI4
+8Q41
+8K76
+7WUY
+8VXC
+8GYG
+8IMS
+8IKS
+8X51
+7Y7O
+8PX4
+8BF8
+7XMJ
+8GDW
+7YTU
+8CH4
+7XHZ
+7YH4
+8PSN
+8A16
+8FBJ
+7Y9G
+8JI2
+7YR9
+8SW0
+8A90
+8X6V
+8H8P
+7WJU
+8PSS
+8HL8
+8FJD
+8PM4
+7UK8
+8DX0
+8PHB
+8FBN
+8FXF
+8GKH
+8ENR
+8PTH
+8CBV
+8GKV
+8CQO
+8OK3
+8GSR
+8TPK
+8H1J
+8QFL
+8CHW
+7V34
+8HE2
+7ZIE
+8A50
+7Z8E
+8ILL
+7WWC
+7XVI
+8Q2A
+8HNO
+8PR6
+7XCA
+7XGS
+8H55
+8FJE
+7UNH
+8AY2
+8ARD
+8HBR
+8EWG
+8D4A
+8FIT
+8E5E
+8PMU
+8F5G
+8AMU
+8CPN
+7QPL
+8EHN
+8SQU
+8F70
+8FX9
+7UR2
+8T1M
+7ZDS
+7YH2
+8B6A
+8CHX
+8G0N
+8GY4
+7YKG
+8BH8
+8BVI
+7XF2
+8BFY
+8IA3
+8JW3
+8OQJ
+8TFS
+7Y1S
+8HBB
+8AF9
+8IP1
+7XZ3
+8T0P
+7Y16
+8BRP
+8JNX
+8JP0
+8EC3
+8PZH
+7URP
+8B4D
+8JFR
+8GYR
+7XFS
+8SMQ
+7WNH
+8H0L
+8OWI
+8HFC
+7X6G
+8FKL
+8PAG
+8UPI
+8D4B
+8BCK
+8JFU
+8FUQ
+8IF8
+8PAQ
+8HDU
+8W9O
+8ACA
+7YIA
+7ZFR
+7Y9A
+8TTO
+7YFX
+8B2H
+8PSU
+8ACC
+8JMR
+8IHA
+7UYX
+8DWJ
+8BY5
+8EZW
+8A82
+8TVL
+8R79
+8R8A
+8AHZ
+8AYV
+8JHU
+8Q44
+8ARE
+8OLJ
+7Y95
+7XP0
+8EX9
+8BID
+8Q40
+7QSJ
+7UBA
+7XFU
+8OU1
+8G2V
+8YA7
+8GMZ
+8T8L
+8CK0
+7Y4H
+8IOM
+7ZLQ
+8BZ2
+8B4C
+8DZJ
+8CEG
+8IBY
+8T3J
+8IVI
+8ITN
+8CR7
+8TGH
+8OKH
+7UI8
+8EHT
+8ADC
+8T4C
+7XBJ
+8CLU
+7QA1
diff --git a/scripts/train/assets/validation_ids.txt b/scripts/train/assets/validation_ids.txt
new file mode 100644
index 0000000..13f5bb9
--- /dev/null
+++ b/scripts/train/assets/validation_ids.txt
@@ -0,0 +1,552 @@
+7UTN
+7F9H
+7TZV
+7ZHH
+7SOV
+7EOF
+7R8H
+8AW3
+7F2F
+8BAO
+7BCB
+7D8T
+7D3T
+7BHY
+7YZ7
+8DC2
+7SOW
+8CTL
+7SOS
+7V6W
+7Z55
+7NQF
+7VTN
+7KSP
+7BJQ
+7YZC
+7Y3L
+7TDX
+7R8I
+7OYK
+7TZ1
+7KIJ
+7T8K
+7KII
+7YZA
+7VP4
+7KIK
+7M5W
+7Q94
+7BCA
+7YZB
+7OG0
+7VTI
+7SOP
+7S03
+7YZG
+7TXC
+7VP5
+7Y3I
+7TDW
+8B0R
+7R8G
+7FEF
+7VP1
+7VP3
+7RGU
+7DV2
+7YZD
+7OFZ
+7Y3K
+7TEC
+7WQ5
+7VP2
+7EDB
+7VP7
+7PDV
+7XHT
+7R6R
+8CSH
+8CSZ
+7V9O
+7Q1C
+8EDC
+7PWI
+7FI1
+7ESI
+7F0Y
+7EYR
+7ZVA
+7WEG
+7E4N
+7U5Q
+7FAV
+7LJ2
+7S6F
+7B3N
+7V4P
+7AJO
+7WH1
+8DQP
+7STT
+7VQ7
+7E4J
+7RIS
+7FH8
+7BMW
+7RD0
+7V54
+7LKC
+7OU1
+7QOD
+7PX1
+7EBY
+7U1V
+7PLP
+7T8N
+7SJK
+7RGB
+7TEM
+7UG9
+7B7A
+7TM2
+7Z74
+7PCM
+7V8G
+7EUU
+7VTL
+7ZEI
+7ZC0
+7DZ9
+8B2M
+7NE9
+7ALV
+7M96
+7O6T
+7SKO
+7Z2V
+7OWX
+7SHW
+7TNI
+7ZQY
+7MDF
+7EXR
+7W6B
+7EQF
+7WWO
+7FBW
+8EHE
+7CLE
+7T80
+7WMV
+7SMG
+7WSJ
+7DBU
+7VHY
+7W5F
+7SHG
+7VU3
+7ATH
+7FGZ
+7ADS
+7REO
+7T7H
+7X0N
+7TCU
+7SKH
+7EF6
+7TBV
+7B29
+7VO5
+7TM1
+7QLD
+7BB9
+7SZ8
+7RLM
+7WWP
+7NBV
+7PLD
+7DNM
+7SFZ
+7EAW
+7QNQ
+7SZX
+7U2S
+7WZX
+7TYG
+7QCE
+7DCN
+7WJL
+7VV6
+7TJ4
+7VI8
+8AKP
+7WAO
+7N7V
+7EYO
+7VTD
+7VEG
+7QY5
+7ELV
+7P0J
+7YX8
+7U4H
+7TBD
+7WME
+7RI3
+7TOH
+7ZVM
+7PUL
+7VBO
+7DM0
+7XN9
+7ALY
+7LTB
+8A28
+7UBZ
+8DTE
+7TA2
+7QST
+7AN1
+7FIB
+8BAL
+7TMJ
+7REV
+7PZJ
+7T9X
+7SUU
+7KJQ
+7V6P
+7QA3
+7ULC
+7Y3X
+7TMU
+7OA7
+7PO9
+7Q20
+8H2C
+7VW1
+7VLJ
+8EP4
+7P57
+7QUL
+7ZQE
+7UJU
+7WG1
+7DMK
+7Y8X
+7EHG
+7W13
+7NL4
+7R4J
+7AOV
+7RFT
+7VUF
+7F72
+8DSR
+7MK3
+7MQQ
+7R55
+7T85
+7NCY
+7ZHL
+7E1N
+7W8F
+7PGK
+8GUN
+7P8D
+7PUK
+7N9D
+7XWN
+7ZHA
+7TVP
+7VI6
+7PW6
+7YM0
+7RWK
+8DKR
+7WGU
+7LJI
+7THW
+7OB6
+7N3Z
+7T3S
+7PAB
+7F9F
+7PPP
+7AD5
+7VGM
+7WBO
+7RWM
+7QFI
+7T91
+7ANU
+7UX0
+7USR
+7RDN
+7VW5
+7Q4T
+7W3R
+8DKQ
+7RCX
+7UOF
+7OKR
+7NX1
+6ZBS
+7VEV
+8E8U
+7WJ6
+7MP4
+7RPY
+7R5Z
+7VLM
+7SNE
+7WDW
+8E19
+7PP2
+7Z5H
+7P7I
+7LJJ
+7QPC
+7VJS
+7QOE
+7KZH
+7F6N
+7TMI
+7POH
+8DKS
+7YMO
+6S5I
+7N6O
+7LYU
+7POK
+7BLK
+7TCY
+7W19
+8B55
+7SMU
+7QFK
+7T5T
+7EPQ
+7DCK
+7S69
+6ZSV
+7ZGT
+7TJ1
+7V09
+7ZHD
+7ALL
+7P1Y
+7T71
+7MNK
+7W5Q
+7PZ2
+7QSQ
+7QI3
+7NZZ
+7Q47
+8D08
+7QH5
+7RXQ
+7F45
+8D07
+8EHC
+7PZT
+7K3C
+7ZGI
+7MC4
+7NPQ
+7VD7
+7XAN
+7FDP
+8A0K
+7TXO
+7ZB1
+7V5V
+7WWS
+7PBK
+8EBG
+7N0J
+7UMA
+7T1S
+8EHB
+7DWC
+7K6W
+7WEJ
+7LRH
+7ZCV
+7RKC
+7X8C
+7PV1
+7UGK
+7ULN
+7A66
+7R7M
+7M0Q
+7BGS
+7UPP
+7O62
+7VKK
+7L6Y
+7VG4
+7V2V
+7ETN
+7ZTB
+7AOO
+7OH2
+7E0M
+7PEG
+8CUK
+7ZP0
+7T6A
+7BTM
+7DOV
+7VVV
+7P22
+7RUO
+7E40
+7O5Y
+7XPK
+7R0K
+8D04
+7TYD
+7LSV
+7XSI
+7RTZ
+7UXR
+7QH3
+8END
+8CYK
+7MRJ
+7DJL
+7S5B
+7XUX
+7EV8
+7R6S
+7UH4
+7R9X
+7F7P
+7ACW
+7SPN
+7W70
+7Q5G
+7DXN
+7DK9
+8DT0
+7FDN
+7DGX
+7UJB
+7X4O
+7F4O
+7T9W
+8AID
+7ERQ
+7EQB
+7YDG
+7ETR
+8D27
+7OUU
+7R5Y
+7T8I
+7UZT
+7X8V
+7QLH
+7SAF
+7EN6
+8D4Y
+7ESJ
+7VWO
+7SBE
+7VYU
+7RVJ
+7FCL
+7WUO
+7WWF
+7VMT
+7SHJ
+7SKP
+7KOU
+6ZSU
+7VGW
+7X45
+8GYZ
+8BFE
+8DGL
+7Z3H
+8BD1
+8A0J
+7JRK
+7QII
+7X39
+7Y6B
+7OIY
+7SBI
+8A3I
+7NLI
+7F4U
+7TVY
+7X0O
+7VMH
+7EPN
+7WBK
+8BFJ
+7XFP
+7LXQ
+7TIL
+7O61
+8B8B
+7W2Q
+8APR
+7WZE
+7NYQ
+7RMX
+7PGE
+8F43
+7N2K
+7UXG
+7SXN
+7T5U
+7R22
+7E3T
+7PTB
+7OA8
+7X5T
+7PL7
+7SQ5
+7VBS
+8D03
+7TAE
+7T69
+7WF6
+7LBU
+8A06
+8DA2
+7QFL
+7KUW
+7X9R
+7XT3
+7RB4
+7PT5
+7RPS
+7RXU
+7TDY
+7W89
+7N9I
+7T1M
+7OBM
+7K3X
+7ZJC
+8BDP
+7V8W
+7DJK
+7W1K
+7QFG
+7DGY
+7ZTQ
+7F8A
+7NEK
+7CG9
+7KOB
+7TN7
+8DYS
+7WVR
diff --git a/scripts/train/configs/confidence.yaml b/scripts/train/configs/confidence.yaml
new file mode 100644
index 0000000..c60cd88
--- /dev/null
+++ b/scripts/train/configs/confidence.yaml
@@ -0,0 +1,192 @@
+trainer:
+ accelerator: gpu
+ devices: 1
+ precision: 32
+ gradient_clip_val: 10.0
+ max_epochs: -1
+
+# Optional set wandb here
+# wandb:
+# name: boltz
+# project: boltz
+# entity: boltz
+
+
+output: SET_PATH_HERE
+pretrained: PATH_TO_STRUCTURE_CHECKPOINT_FILE
+resume: null
+disable_checkpoint: false
+matmul_precision: null
+save_top_k: -1
+load_confidence_from_trunk: true
+
+data:
+ datasets:
+ - _target_: boltz.data.module.training.DatasetConfig
+ target_dir: PATH_TO_TARGETS_DIR
+ msa_dir: PATH_TO_MSA_DIR
+ prob: 1.0
+ sampler:
+ _target_: boltz.data.sample.cluster.ClusterSampler
+ cropper:
+ _target_: boltz.data.crop.boltz.BoltzCropper
+ min_neighborhood: 0
+ max_neighborhood: 40
+ split: ./scripts/train/assets/validation_ids.txt
+
+ filters:
+ - _target_: boltz.data.filter.dynamic.size.SizeFilter
+ min_chains: 1
+ max_chains: 300
+ - _target_: boltz.data.filter.dynamic.date.DateFilter
+ date: "2021-09-30"
+ ref: released
+ - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
+ resolution: 4.0
+
+ tokenizer:
+ _target_: boltz.data.tokenize.boltz.BoltzTokenizer
+ featurizer:
+ _target_: boltz.data.feature.featurizer.BoltzFeaturizer
+
+ symmetries: PATH_TO_SYMMETRY_FILE
+ max_tokens: 512
+ max_atoms: 4608
+ max_seqs: 2048
+ pad_to_max_tokens: true
+ pad_to_max_atoms: true
+ pad_to_max_seqs: true
+ samples_per_epoch: 100000
+ batch_size: 1
+ num_workers: 4
+ random_seed: 42
+ pin_memory: true
+ overfit: null
+ crop_validation: true
+ return_train_symmetries: true
+ return_val_symmetries: true
+ train_binder_pocket_conditioned_prop: 0.3
+ val_binder_pocket_conditioned_prop: 0.3
+ binder_pocket_cutoff: 6.0
+ binder_pocket_sampling_geometric_p: 0.3
+ min_dist: 2.0
+ max_dist: 22.0
+ num_bins: 64
+ atoms_per_window_queries: 32
+
+model:
+ _target_: boltz.model.models.boltz.BoltzPreview
+ atom_s: 128
+ atom_z: 16
+ token_s: 384
+ token_z: 128
+ num_bins: 64
+ atom_feature_dim: 389
+ atoms_per_window_queries: 32
+ atoms_per_window_keys: 128
+ compile_pairformer: false
+ nucleotide_rmsd_weight: 5.0
+ ligand_rmsd_weight: 10.0
+ ema: true
+ ema_decay: 0.999
+
+ embedder_args:
+ atom_encoder_depth: 3
+ atom_encoder_heads: 4
+
+ msa_args:
+ msa_s: 64
+ msa_blocks: 4
+ msa_dropout: 0.15
+ z_dropout: 0.25
+ pairwise_head_width: 32
+ pairwise_num_heads: 4
+ activation_checkpointing: true
+ offload_to_cpu: false
+
+ pairformer_args:
+ num_blocks: 48
+ num_heads: 16
+ dropout: 0.25
+ activation_checkpointing: true
+ offload_to_cpu: false
+
+ score_model_args:
+ sigma_data: 16
+ dim_fourier: 256
+ atom_encoder_depth: 3
+ atom_encoder_heads: 4
+ token_transformer_depth: 24
+ token_transformer_heads: 16
+ atom_decoder_depth: 3
+ atom_decoder_heads: 4
+ conditioning_transition_layers: 2
+ activation_checkpointing: true
+ offload_to_cpu: false
+
+ structure_prediction_training: false
+ run_trunk_and_structure: true
+ confidence_prediction: true
+ alpha_pae: 1
+ confidence_imitate_trunk: true
+ confidence_model_args:
+ use_gaussian: false
+ num_dist_bins: 64
+ max_dist: 22
+ add_s_to_z_prod: true
+ add_s_input_to_s: true
+ use_s_diffusion: true
+ add_z_input_to_z: true
+
+ confidence_args:
+ num_plddt_bins: 50
+ num_pde_bins: 64
+ num_pae_bins: 64
+ relative_confidence: none
+
+ training_args:
+ recycling_steps: 3
+ sampling_steps: 200
+ diffusion_multiplicity: 16
+ diffusion_samples: 1
+ confidence_loss_weight: 3e-3
+ diffusion_loss_weight: 4.0
+ distogram_loss_weight: 3e-2
+ adam_beta_1: 0.9
+ adam_beta_2: 0.95
+ adam_eps: 0.00000001
+ lr_scheduler: af3
+ base_lr: 0.0
+ max_lr: 0.0018
+ lr_warmup_no_steps: 1000
+ lr_start_decay_after_n_steps: 50000
+ lr_decay_every_n_steps: 50000
+ lr_decay_factor: 0.95
+ symmetry_correction: true
+
+ validation_args:
+ recycling_steps: 3
+ sampling_steps: 200
+ diffusion_samples: 5
+ symmetry_correction: true
+
+ diffusion_process_args:
+ sigma_min: 0.0004
+ sigma_max: 160.0
+ sigma_data: 16.0
+ rho: 7
+ P_mean: -1.2
+ P_std: 1.5
+ gamma_0: 0.8
+ gamma_min: 1.0
+ noise_scale: 1.0
+ step_scale: 1.0
+ coordinate_augmentation: true
+ alignment_reverse_diff: true
+ synchronize_sigmas: true
+ use_inference_model_cache: true
+
+ diffusion_loss_args:
+ add_smooth_lddt_loss: true
+ nucleotide_loss_weight: 5.0
+ ligand_loss_weight: 10.0
diff --git a/scripts/train/configs/structure.yaml b/scripts/train/configs/structure.yaml
new file mode 100644
index 0000000..d5ce7ba
--- /dev/null
+++ b/scripts/train/configs/structure.yaml
@@ -0,0 +1,184 @@
+trainer:
+ accelerator: gpu
+ devices: 1
+ precision: 32
+ gradient_clip_val: 10.0
+ max_epochs: -1
+
+# Optional set wandb here
+# wandb:
+# name: boltz
+# project: boltz
+# entity: boltz
+
+output: SET_PATH_HERE
+resume: PATH_TO_CHECKPOINT_FILE
+disable_checkpoint: false
+matmul_precision: null
+save_top_k: -1
+
+data:
+ datasets:
+ - _target_: boltz.data.module.training.DatasetConfig
+ target_dir: PATH_TO_TARGETS_DIR
+ msa_dir: PATH_TO_MSA_DIR
+ prob: 1.0
+ sampler:
+ _target_: boltz.data.sample.cluster.ClusterSampler
+ cropper:
+ _target_: boltz.data.crop.boltz.BoltzCropper
+ min_neighborhood: 0
+ max_neighborhood: 40
+ split: ./scripts/train/assets/validation_ids.txt
+
+ filters:
+ - _target_: boltz.data.filter.dynamic.size.SizeFilter
+ min_chains: 1
+ max_chains: 300
+ - _target_: boltz.data.filter.dynamic.date.DateFilter
+ date: "2021-09-30"
+ ref: released
+ - _target_: boltz.data.filter.dynamic.resolution.ResolutionFilter
+ resolution: 9.0
+
+ tokenizer:
+ _target_: boltz.data.tokenize.boltz.BoltzTokenizer
+ featurizer:
+ _target_: boltz.data.feature.featurizer.BoltzFeaturizer
+
+ symmetries: PATH_TO_SYMMETRY_FILE
+ max_tokens: 512
+ max_atoms: 4608
+ max_seqs: 2048
+ pad_to_max_tokens: true
+ pad_to_max_atoms: true
+ pad_to_max_seqs: true
+ samples_per_epoch: 100000
+ batch_size: 1
+ num_workers: 4
+ random_seed: 42
+ pin_memory: true
+ overfit: null
+ crop_validation: false
+ return_train_symmetries: false
+ return_val_symmetries: true
+ train_binder_pocket_conditioned_prop: 0.3
+ val_binder_pocket_conditioned_prop: 0.3
+ binder_pocket_cutoff: 6.0
+ binder_pocket_sampling_geometric_p: 0.3
+ min_dist: 2.0
+ max_dist: 22.0
+ num_bins: 64
+ atoms_per_window_queries: 32
+
+model:
+ _target_: boltz.model.model.Boltz1
+ atom_s: 128
+ atom_z: 16
+ token_s: 384
+ token_z: 128
+ num_bins: 64
+ atom_feature_dim: 389
+ atoms_per_window_queries: 32
+ atoms_per_window_keys: 128
+ compile_pairformer: false
+ nucleotide_rmsd_weight: 5.0
+ ligand_rmsd_weight: 10.0
+ ema: true
+ ema_decay: 0.999
+
+ embedder_args:
+ atom_encoder_depth: 3
+ atom_encoder_heads: 4
+
+ msa_args:
+ msa_s: 64
+ msa_blocks: 4
+ msa_dropout: 0.15
+ z_dropout: 0.25
+ pairwise_head_width: 32
+ pairwise_num_heads: 4
+ activation_checkpointing: true
+ offload_to_cpu: false
+
+ pairformer_args:
+ num_blocks: 48
+ num_heads: 16
+ dropout: 0.25
+ activation_checkpointing: true
+ offload_to_cpu: false
+
+ score_model_args:
+ sigma_data: 16
+ dim_fourier: 256
+ atom_encoder_depth: 3
+ atom_encoder_heads: 4
+ token_transformer_depth: 24
+ token_transformer_heads: 16
+ atom_decoder_depth: 3
+ atom_decoder_heads: 4
+ conditioning_transition_layers: 2
+ activation_checkpointing: true
+ offload_to_cpu: false
+
+ confidence_prediction: false
+ confidence_model_args:
+ use_gaussian: false
+ num_dist_bins: 64
+ max_dist: 22
+ add_s_to_z_prod: true
+ add_s_input_to_s: true
+ use_s_diffusion: true
+ add_z_input_to_z: true
+
+ confidence_args:
+ num_plddt_bins: 50
+ num_pde_bins: 64
+ num_pae_bins: 64
+ relative_confidence: none
+
+ training_args:
+ recycling_steps: 3
+ sampling_steps: 20
+ diffusion_multiplicity: 16
+ diffusion_samples: 2
+ confidence_loss_weight: 1e-4
+ diffusion_loss_weight: 4.0
+ distogram_loss_weight: 3e-2
+ adam_beta_1: 0.9
+ adam_beta_2: 0.95
+ adam_eps: 0.00000001
+ lr_scheduler: af3
+ base_lr: 0.0
+ max_lr: 0.0018
+ lr_warmup_no_steps: 1000
+ lr_start_decay_after_n_steps: 50000
+ lr_decay_every_n_steps: 50000
+ lr_decay_factor: 0.95
+
+ validation_args:
+ recycling_steps: 3
+ sampling_steps: 200
+ diffusion_samples: 5
+ symmetry_correction: true
+
+ diffusion_process_args:
+ sigma_min: 0.0004
+ sigma_max: 160.0
+ sigma_data: 16.0
+ rho: 7
+ P_mean: -1.2
+ P_std: 1.5
+ gamma_0: 0.8
+ gamma_min: 1.0
+ noise_scale: 1.0
+ step_scale: 1.0
+ coordinate_augmentation: true
+ alignment_reverse_diff: true
+ synchronize_sigmas: true
+ use_inference_model_cache: true
+
+ diffusion_loss_args:
+ add_smooth_lddt_loss: true
+ nucleotide_loss_weight: 5.0
+ ligand_loss_weight: 10.0
diff --git a/scripts/train/train.py b/scripts/train/train.py
new file mode 100644
index 0000000..d5b62d2
--- /dev/null
+++ b/scripts/train/train.py
@@ -0,0 +1,228 @@
+import os
+import sys
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import hydra
+import omegaconf
+import pytorch_lightning as pl
+import torch
+import torch.multiprocessing
+from omegaconf import OmegaConf, listconfig
+from pytorch_lightning import LightningModule
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.strategies import DDPStrategy
+from pytorch_lightning.utilities import rank_zero_only
+
+from boltz.data.module.training import BoltzTrainingDataModule, DataConfig
+
+
+@dataclass
+class TrainConfig:
+ """Train configuration.
+
+ Attributes
+ ----------
+ data : DataConfig
+ The data configuration.
+ model : ModelConfig
+ The model configuration.
+ output : str
+ The output directory.
+ trainer : Optional[dict]
+ The trainer configuration.
+ resume : Optional[str]
+ The resume checkpoint.
+ pretrained : Optional[str]
+ The pretrained model.
+ wandb : Optional[dict]
+ The wandb configuration.
+ disable_checkpoint : bool
+ Disable checkpoint.
+ matmul_precision : Optional[str]
+ The matmul precision.
+ find_unused_parameters : Optional[bool]
+ Find unused parameters.
+ save_top_k : Optional[int]
+ Save top k checkpoints.
+ validation_only : bool
+ Run validation only.
+ debug : bool
+ Debug mode.
+ strict_loading : bool
+ Fail on mismatched checkpoint weights.
+ load_confidence_from_trunk: Optional[bool]
+ Load pre-trained confidence weights from trunk.
+
+ """
+
+ data: DataConfig
+ model: LightningModule
+ output: str
+ trainer: Optional[dict] = None
+ resume: Optional[str] = None
+ pretrained: Optional[str] = None
+ wandb: Optional[dict] = None
+ disable_checkpoint: bool = False
+ matmul_precision: Optional[str] = None
+ find_unused_parameters: Optional[bool] = False
+ save_top_k: Optional[int] = 1
+ validation_only: bool = False
+ debug: bool = False
+ strict_loading: bool = True
+ load_confidence_from_trunk: Optional[bool] = False
+
+
+def train(raw_config: str, args: list[str]) -> None: # noqa: C901, PLR0912, PLR0915
+ """Run training.
+
+ Parameters
+ ----------
+ raw_config : str
+ The input yaml configuration.
+ args : list[str]
+ Any command line overrides.
+
+ """
+ # Load the configuration
+ raw_config = omegaconf.OmegaConf.load(raw_config)
+
+ # Apply input arguments
+ args = omegaconf.OmegaConf.from_dotlist(args)
+ raw_config = omegaconf.OmegaConf.merge(raw_config, args)
+
+ # Instantiate the task
+ cfg = hydra.utils.instantiate(raw_config)
+ cfg = TrainConfig(**cfg)
+
+ # Set matmul precision
+ if cfg.matmul_precision is not None:
+ torch.set_float32_matmul_precision(cfg.matmul_precision)
+
+ # Create trainer dict
+ trainer = cfg.trainer
+ if trainer is None:
+ trainer = {}
+
+ # Flip some arguments in debug mode
+ devices = trainer.get("devices", 1)
+
+ wandb = cfg.wandb
+ if cfg.debug:
+ if isinstance(devices, int):
+ devices = 1
+ elif isinstance(devices, (list, listconfig.ListConfig)):
+ devices = [devices[0]]
+ trainer["devices"] = devices
+ cfg.data.num_workers = 0
+ if wandb:
+ wandb = None
+
+ # Create objects
+ data_config = DataConfig(**cfg.data)
+ data_module = BoltzTrainingDataModule(data_config)
+ model_module = cfg.model
+
+ if cfg.pretrained and not cfg.resume:
+ # Load the pretrained weights into the confidence module
+ if cfg.load_confidence_from_trunk:
+ checkpoint = torch.load(cfg.pretrained, map_location="cpu")
+
+ # Modify parameter names in the state_dict
+ new_state_dict = {}
+ for key, value in checkpoint["state_dict"].items():
+ if not key.startswith("structure_module") and not key.startswith(
+ "distogram_module"
+ ):
+ new_key = "confidence_module." + key
+ new_state_dict[new_key] = value
+ new_state_dict.update(checkpoint["state_dict"])
+
+ # Update the checkpoint with the new state_dict
+ checkpoint["state_dict"] = new_state_dict
+ else:
+ file_path = cfg.pretrained
+
+ print(f"Loading model from {file_path}")
+ model_module = type(model_module).load_from_checkpoint(
+ file_path, strict=False, **(model_module.hparams)
+ )
+
+ if cfg.load_confidence_from_trunk:
+ os.remove(file_path)
+
+ # Create checkpoint callback
+ callbacks = []
+ dirpath = cfg.output
+ if not cfg.disable_checkpoint:
+ mc = ModelCheckpoint(
+ monitor="val/lddt",
+ save_top_k=cfg.save_top_k,
+ save_last=True,
+ mode="max",
+ every_n_epochs=1,
+ )
+ callbacks = [mc]
+
+ # Create wandb logger
+ loggers = []
+ if wandb:
+ wdb_logger = WandbLogger(
+ group=wandb["name"],
+ save_dir=cfg.output,
+ project=wandb["project"],
+ entity=wandb["entity"],
+ log_model=False,
+ )
+ loggers.append(wdb_logger)
+ # Save the config to wandb
+
+ @rank_zero_only
+ def save_config_to_wandb() -> None:
+ config_out = Path(wdb_logger.experiment.dir) / "run.yaml"
+ with Path.open(config_out, "w") as f:
+ OmegaConf.save(raw_config, f)
+ wdb_logger.experiment.save(str(config_out))
+
+ save_config_to_wandb()
+
+ # Set up trainer
+ strategy = "auto"
+ if (isinstance(devices, int) and devices > 1) or (
+ isinstance(devices, (list, listconfig.ListConfig)) and len(devices) > 1
+ ):
+ strategy = DDPStrategy(find_unused_parameters=cfg.find_unused_parameters)
+
+ trainer = pl.Trainer(
+ default_root_dir=str(dirpath),
+ strategy=strategy,
+ callbacks=callbacks,
+ logger=loggers,
+ enable_checkpointing=not cfg.disable_checkpoint,
+ reload_dataloaders_every_n_epochs=1,
+ **trainer,
+ )
+
+ if not cfg.strict_loading:
+ model_module.strict_loading = False
+
+ if cfg.validation_only:
+ trainer.validate(
+ model_module,
+ datamodule=data_module,
+ ckpt_path=cfg.resume,
+ )
+ else:
+ trainer.fit(
+ model_module,
+ datamodule=data_module,
+ ckpt_path=cfg.resume,
+ )
+
+
+if __name__ == "__main__":
+ arg1 = sys.argv[1]
+ arg2 = sys.argv[2:]
+ train(arg1, arg2)
diff --git a/src/boltz/__init__.py b/src/boltz/__init__.py
new file mode 100644
index 0000000..ce79ee0
--- /dev/null
+++ b/src/boltz/__init__.py
@@ -0,0 +1,7 @@
+from importlib.metadata import PackageNotFoundError, version
+
+try: # noqa: SIM105
+ __version__ = version("boltz")
+except PackageNotFoundError:
+ # package is not installed
+ pass
diff --git a/src/boltz/data/__init__.py b/src/boltz/data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/const.py b/src/boltz/data/const.py
new file mode 100644
index 0000000..39b727e
--- /dev/null
+++ b/src/boltz/data/const.py
@@ -0,0 +1,353 @@
+####################################################################################################
+# CHAINS
+####################################################################################################
+
+chain_types = [
+ "PROTEIN",
+ "DNA",
+ "RNA",
+ "NONPOLYMER",
+]
+chain_type_ids = {chain: i for i, chain in enumerate(chain_types)}
+
+out_types = [
+ "dna_protein",
+ "rna_protein",
+ "ligand_protein",
+ "dna_ligand",
+ "rna_ligand",
+ "intra_ligand",
+ "intra_dna",
+ "intra_rna",
+ "intra_protein",
+ "protein_protein",
+]
+
+out_types_weights_af3 = {
+ "dna_protein": 10.0,
+ "rna_protein": 10.0,
+ "ligand_protein": 10.0,
+ "dna_ligand": 5.0,
+ "rna_ligand": 5.0,
+ "intra_ligand": 20.0,
+ "intra_dna": 4.0,
+ "intra_rna": 16.0,
+ "intra_protein": 20.0,
+ "protein_protein": 20.0,
+}
+
+out_types_weights = {
+ "dna_protein": 5.0,
+ "rna_protein": 5.0,
+ "ligand_protein": 20.0,
+ "dna_ligand": 2.0,
+ "rna_ligand": 2.0,
+ "intra_ligand": 20.0,
+ "intra_dna": 2.0,
+ "intra_rna": 8.0,
+ "intra_protein": 20.0,
+ "protein_protein": 20.0,
+}
+
+
+out_single_types = ["protein", "ligand", "dna", "rna"]
+
+####################################################################################################
+# RESIDUES & TOKENS
+####################################################################################################
+
+tokens = [
+ "",
+ "-",
+ "ALA",
+ "ARG",
+ "ASN",
+ "ASP",
+ "CYS",
+ "GLN",
+ "GLU",
+ "GLY",
+ "HIS",
+ "ILE",
+ "LEU",
+ "LYS",
+ "MET",
+ "PHE",
+ "PRO",
+ "SER",
+ "THR",
+ "TRP",
+ "TYR",
+ "VAL",
+ "UNK", # unknown protein token
+ "A",
+ "G",
+ "C",
+ "U",
+ "N", # unknown rna token
+ "DA",
+ "DG",
+ "DC",
+ "DT",
+ "DN", # unknown dna token
+]
+
+token_ids = {token: i for i, token in enumerate(tokens)}
+num_tokens = len(tokens)
+unk_token = {"PROTEIN": "UNK", "DNA": "DN", "RNA": "N"}
+unk_token_ids = {m: token_ids[t] for m, t in unk_token.items()}
+
+prot_letter_to_token = {
+ "A": "ALA",
+ "R": "ARG",
+ "N": "ASN",
+ "D": "ASP",
+ "C": "CYS",
+ "E": "GLU",
+ "Q": "GLN",
+ "G": "GLY",
+ "H": "HIS",
+ "I": "ILE",
+ "L": "LEU",
+ "K": "LYS",
+ "M": "MET",
+ "F": "PHE",
+ "P": "PRO",
+ "S": "SER",
+ "T": "THR",
+ "W": "TRP",
+ "Y": "TYR",
+ "V": "VAL",
+ "X": "UNK",
+ "J": "UNK",
+ "B": "UNK",
+ "Z": "UNK",
+ "O": "UNK",
+ "U": "UNK",
+ "-": "-",
+}
+
+prot_token_to_letter = {v: k for k, v in prot_letter_to_token.items()}
+prot_token_to_letter["UNK"] = "X"
+
+rna_letter_to_token = {
+ "A": "A",
+ "G": "G",
+ "C": "C",
+ "U": "U",
+ "N": "N",
+}
+rna_token_to_letter = {v: k for k, v in rna_letter_to_token.items()}
+
+dna_letter_to_token = {
+ "A": "DA",
+ "G": "DG",
+ "C": "DC",
+ "T": "DT",
+ "N": "DN",
+}
+dna_token_to_letter = {v: k for k, v in dna_letter_to_token.items()}
+
+####################################################################################################
+# ATOMS
+####################################################################################################
+
+num_elements = 128
+
+chirality_types = [
+ "CHI_UNSPECIFIED",
+ "CHI_TETRAHEDRAL_CW",
+ "CHI_TETRAHEDRAL_CCW",
+ "CHI_OTHER",
+]
+chirality_type_ids = {chirality: i for i, chirality in enumerate(chirality_types)}
+unk_chirality_type = "CHI_UNSPECIFIED"
+
+# fmt: off
+ref_atoms = {
+ "PAD": [],
+ "UNK": ["N", "CA", "C", "O", "CB"],
+ "-": [],
+ "ALA": ["N", "CA", "C", "O", "CB"],
+ "ARG": ["N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"],
+ "ASN": ["N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"],
+ "ASP": ["N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"],
+ "CYS": ["N", "CA", "C", "O", "CB", "SG"],
+ "GLN": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"],
+ "GLU": ["N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"],
+ "GLY": ["N", "CA", "C", "O"],
+ "HIS": ["N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"],
+ "ILE": ["N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"],
+ "LEU": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"],
+ "LYS": ["N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"],
+ "MET": ["N", "CA", "C", "O", "CB", "CG", "SD", "CE"],
+ "PHE": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"],
+ "PRO": ["N", "CA", "C", "O", "CB", "CG", "CD"],
+ "SER": ["N", "CA", "C", "O", "CB", "OG"],
+ "THR": ["N", "CA", "C", "O", "CB", "OG1", "CG2"],
+ "TRP": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "NE1", "CE2", "CE3", "CZ2", "CZ3", "CH2"], # noqa: E501
+ "TYR": ["N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"],
+ "VAL": ["N", "CA", "C", "O", "CB", "CG1", "CG2"],
+ "A": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501
+ "G": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501
+ "C": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501
+ "U": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C6"], # noqa: E501
+ "N": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'"], # noqa: E501
+ "DA": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "N6", "N1", "C2", "N3", "C4"], # noqa: E501
+ "DG": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N9", "C8", "N7", "C5", "C6", "O6", "N1", "C2", "N2", "N3", "C4"], # noqa: E501
+ "DC": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "N4", "C5", "C6"], # noqa: E501
+ "DT": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'", "N1", "C2", "O2", "N3", "C4", "O4", "C5", "C7", "C6"], # noqa: E501
+ "DN": ["P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "C1'"]
+}
+
+ref_symmetries = {
+ "PAD": [],
+ "ALA": [],
+ "ARG": [],
+ "ASN": [],
+ "ASP": [[(6, 7), (7, 6)]],
+ "CYS": [],
+ "GLN": [],
+ "GLU": [[(7, 8), (8, 7)]],
+ "GLY": [],
+ "HIS": [],
+ "ILE": [],
+ "LEU": [],
+ "LYS": [],
+ "MET": [],
+ "PHE": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
+ "PRO": [],
+ "SER": [],
+ "THR": [],
+ "TRP": [],
+ "TYR": [[(6, 7), (7, 6), (8, 9), (9, 8)]],
+ "VAL": [],
+ "A": [[(1, 2), (2, 1)]],
+ "G": [[(1, 2), (2, 1)]],
+ "C": [[(1, 2), (2, 1)]],
+ "U": [[(1, 2), (2, 1)]],
+ "N": [[(1, 2), (2, 1)]],
+ "DA": [[(1, 2), (2, 1)]],
+ "DG": [[(1, 2), (2, 1)]],
+ "DC": [[(1, 2), (2, 1)]],
+ "DT": [[(1, 2), (2, 1)]],
+ "DN": [[(1, 2), (2, 1)]]
+}
+
+
+res_to_center_atom = {
+ "UNK": "CA",
+ "ALA": "CA",
+ "ARG": "CA",
+ "ASN": "CA",
+ "ASP": "CA",
+ "CYS": "CA",
+ "GLN": "CA",
+ "GLU": "CA",
+ "GLY": "CA",
+ "HIS": "CA",
+ "ILE": "CA",
+ "LEU": "CA",
+ "LYS": "CA",
+ "MET": "CA",
+ "PHE": "CA",
+ "PRO": "CA",
+ "SER": "CA",
+ "THR": "CA",
+ "TRP": "CA",
+ "TYR": "CA",
+ "VAL": "CA",
+ "A": "C1'",
+ "G": "C1'",
+ "C": "C1'",
+ "U": "C1'",
+ "N": "C1'",
+ "DA": "C1'",
+ "DG": "C1'",
+ "DC": "C1'",
+ "DT": "C1'",
+ "DN": "C1'"
+}
+
+res_to_disto_atom = {
+ "UNK": "CB",
+ "ALA": "CB",
+ "ARG": "CB",
+ "ASN": "CB",
+ "ASP": "CB",
+ "CYS": "CB",
+ "GLN": "CB",
+ "GLU": "CB",
+ "GLY": "CA",
+ "HIS": "CB",
+ "ILE": "CB",
+ "LEU": "CB",
+ "LYS": "CB",
+ "MET": "CB",
+ "PHE": "CB",
+ "PRO": "CB",
+ "SER": "CB",
+ "THR": "CB",
+ "TRP": "CB",
+ "TYR": "CB",
+ "VAL": "CB",
+ "A": "C4",
+ "G": "C4",
+ "C": "C2",
+ "U": "C2",
+ "N": "C1'",
+ "DA": "C4",
+ "DG": "C4",
+ "DC": "C2",
+ "DT": "C2",
+ "DN": "C1'"
+}
+
+res_to_center_atom_id = {
+ res: ref_atoms[res].index(atom)
+ for res, atom in res_to_center_atom.items()
+}
+
+res_to_disto_atom_id = {
+ res: ref_atoms[res].index(atom)
+ for res, atom in res_to_disto_atom.items()
+}
+
+# fmt: on
+
+####################################################################################################
+# BONDS
+####################################################################################################
+
+atom_interface_cutoff = 5.0
+interface_cutoff = 15.0
+
+bond_types = [
+ "OTHER",
+ "SINGLE",
+ "DOUBLE",
+ "TRIPLE",
+ "AROMATIC",
+]
+bond_type_ids = {bond: i for i, bond in enumerate(bond_types)}
+unk_bond_type = "OTHER"
+
+
+####################################################################################################
+# Contacts
+####################################################################################################
+
+
+pocket_contact_info = {
+ "UNSPECIFIED": 0,
+ "UNSELECTED": 1,
+ "POCKET": 2,
+ "BINDER": 3,
+}
+
+
+####################################################################################################
+# MSA
+####################################################################################################
+
+max_msa_seqs = 16384
diff --git a/src/boltz/data/crop/__init__.py b/src/boltz/data/crop/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/crop/boltz.py b/src/boltz/data/crop/boltz.py
new file mode 100644
index 0000000..6cded72
--- /dev/null
+++ b/src/boltz/data/crop/boltz.py
@@ -0,0 +1,296 @@
+from dataclasses import replace
+from typing import Optional
+
+import numpy as np
+from scipy.spatial.distance import cdist
+
+from boltz.data import const
+from boltz.data.crop.cropper import Cropper
+from boltz.data.types import Tokenized
+
+
+def pick_random_token(
+ tokens: np.ndarray,
+ random: np.random.RandomState,
+) -> np.ndarray:
+ """Pick a random token from the data.
+
+ Parameters
+ ----------
+ tokens : np.ndarray
+ The token data.
+ random : np.ndarray
+ The random state for reproducibility.
+
+ Returns
+ -------
+ np.ndarray
+ The selected token.
+
+ """
+ return tokens[random.randint(len(tokens))]
+
+
+def pick_chain_token(
+ tokens: np.ndarray,
+ chain_id: int,
+ random: np.random.RandomState,
+) -> np.ndarray:
+ """Pick a random token from a chain.
+
+ Parameters
+ ----------
+ tokens : np.ndarray
+ The token data.
+ chain_id : int
+ The chain ID.
+ random : np.ndarray
+ The random state for reproducibility.
+
+ Returns
+ -------
+ np.ndarray
+ The selected token.
+
+ """
+ # Filter to chain
+ chain_tokens = tokens[tokens["asym_id"] == chain_id]
+
+ # Pick from chain, fallback to all tokens
+ if chain_tokens.size:
+ query = pick_random_token(chain_tokens, random)
+ else:
+ query = pick_random_token(tokens, random)
+
+ return query
+
+
+def pick_interface_token(
+ tokens: np.ndarray,
+ interface: np.ndarray,
+ random: np.random.RandomState,
+) -> np.ndarray:
+ """Pick a random token from an interface.
+
+ Parameters
+ ----------
+ tokens : np.ndarray
+ The token data.
+ interface : int
+ The interface ID.
+ random : np.ndarray
+ The random state for reproducibility.
+
+ Returns
+ -------
+ np.ndarray
+ The selected token.
+
+ """
+ # Sample random interface
+ chain_1 = int(interface["chain_1"])
+ chain_2 = int(interface["chain_2"])
+
+ tokens_1 = tokens[tokens["asym_id"] == chain_1]
+ tokens_2 = tokens[tokens["asym_id"] == chain_2]
+
+ # If no interface, pick from the chains
+ if tokens_1.size and (not tokens_2.size):
+ query = pick_random_token(tokens_1, random)
+ elif tokens_2.size and (not tokens_1.size):
+ query = pick_random_token(tokens_2, random)
+ elif (not tokens_1.size) and (not tokens_2.size):
+ query = pick_random_token(tokens, random)
+ else:
+ # If we have tokens, compute distances
+ tokens_1_coords = tokens_1["center_coords"]
+ tokens_2_coords = tokens_2["center_coords"]
+
+ dists = cdist(tokens_1_coords, tokens_2_coords)
+ cuttoff = dists < const.interface_cutoff
+
+ # In rare cases, the interface cuttoff is slightly
+ # too small, then we slightly expand it if it happens
+ if not np.any(cuttoff):
+ cuttoff = dists < (const.interface_cutoff + 5.0)
+
+ tokens_1 = tokens_1[np.any(cuttoff, axis=1)]
+ tokens_2 = tokens_2[np.any(cuttoff, axis=0)]
+
+ # Select random token
+ candidates = np.concatenate([tokens_1, tokens_2])
+ query = pick_random_token(candidates, random)
+
+ return query
+
+
+class BoltzCropper(Cropper):
+ """Interpolate between contiguous and spatial crops."""
+
+ def __init__(self, min_neighborhood: int = 0, max_neighborhood: int = 40) -> None:
+ """Initialize the cropper.
+
+ Modulates the type of cropping to be performed.
+ Smaller neighborhoods result in more spatial
+ cropping. Larger neighborhoods result in more
+ continuous cropping. A mix can be achieved by
+ providing a range over which to sample.
+
+ Parameters
+ ----------
+ min_neighborhood : int
+ The minimum neighborhood size, by default 0.
+ max_neighborhood : int
+ The maximum neighborhood size, by default 40.
+
+ """
+ sizes = list(range(min_neighborhood, max_neighborhood + 1, 2))
+ self.neighborhood_sizes = sizes
+
+ def crop( # noqa: PLR0915
+ self,
+ data: Tokenized,
+ max_tokens: int,
+ random: np.random.RandomState,
+ max_atoms: Optional[int] = None,
+ chain_id: Optional[int] = None,
+ interface_id: Optional[int] = None,
+ ) -> Tokenized:
+ """Crop the data to a maximum number of tokens.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ max_tokens : int
+ The maximum number of tokens to crop.
+ random : np.random.RandomState
+ The random state for reproducibility.
+ max_atoms : int, optional
+ The maximum number of atoms to consider.
+ chain_id : int, optional
+ The chain ID to crop.
+ interface_id : int, optional
+ The interface ID to crop.
+
+ Returns
+ -------
+ Tokenized
+ The cropped data.
+
+ """
+ # Check inputs
+ if chain_id is not None and interface_id is not None:
+ msg = "Only one of chain_id or interface_id can be provided."
+ raise ValueError(msg)
+
+ # Randomly select a neighborhood size
+ neighborhood_size = random.choice(self.neighborhood_sizes)
+
+ # Get token data
+ token_data = data.tokens
+ token_bonds = data.bonds
+ mask = data.structure.mask
+ chains = data.structure.chains
+ interfaces = data.structure.interfaces
+
+ # Filter to valid chains
+ valid_chains = chains[mask]
+
+ # Filter to valid interfaces
+ valid_interfaces = interfaces
+ valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_1"]]]
+ valid_interfaces = valid_interfaces[mask[valid_interfaces["chain_2"]]]
+
+ # Filter to resolved tokens
+ valid_tokens = token_data[token_data["resolved_mask"]]
+
+ # Check if we have any valid tokens
+ if not valid_tokens.size:
+ msg = "No valid tokens in structure"
+ raise ValueError(msg)
+
+ # Pick a random token, chain, or interface
+ if chain_id is not None:
+ query = pick_chain_token(valid_tokens, chain_id, random)
+ elif interface_id is not None:
+ interface = interfaces[interface_id]
+ query = pick_interface_token(valid_tokens, interface, random)
+ elif valid_interfaces.size:
+ idx = random.randint(len(valid_interfaces))
+ interface = valid_interfaces[idx]
+ query = pick_interface_token(valid_tokens, interface, random)
+ else:
+ idx = random.randint(len(valid_chains))
+ chain_id = valid_chains[idx]["asym_id"]
+ query = pick_chain_token(valid_tokens, chain_id, random)
+
+ # Sort all tokens by distance to query_coords
+ dists = valid_tokens["center_coords"] - query["center_coords"]
+ indices = np.argsort(np.linalg.norm(dists, axis=1))
+
+ # Select cropped indices
+ cropped: set[int] = set()
+ total_atoms = 0
+ for idx in indices:
+ # Get the token
+ token = valid_tokens[idx]
+
+ # Get all tokens from this chain
+ chain_tokens = token_data[token_data["asym_id"] == token["asym_id"]]
+
+ # Pick the whole chain if possible, otherwise select
+ # a contiguous subset centered at the query token
+ if len(chain_tokens) <= neighborhood_size:
+ new_tokens = chain_tokens
+ else:
+ # First limit to the maximum set of tokens, with the
+ # neighboorhood on both sides to handle edges. This
+ # is mostly for efficiency with the while loop below.
+ min_idx = token["res_idx"] - neighborhood_size
+ max_idx = token["res_idx"] + neighborhood_size
+
+ max_token_set = chain_tokens
+ max_token_set = max_token_set[max_token_set["res_idx"] >= min_idx]
+ max_token_set = max_token_set[max_token_set["res_idx"] <= max_idx]
+
+ # Start by adding just the query token
+ new_tokens = max_token_set[max_token_set["res_idx"] == token["res_idx"]]
+
+ # Expand the neighborhood until we have enough tokens, one
+ # by one to handle some edge cases with non-standard chains.
+ # We switch to the res_idx instead of the token_idx to always
+ # include all tokens from modified residues or from ligands.
+ min_idx = max_idx = token["res_idx"]
+ while new_tokens.size < neighborhood_size:
+ min_idx = min_idx - 1
+ max_idx = max_idx + 1
+ new_tokens = max_token_set
+ new_tokens = new_tokens[new_tokens["res_idx"] >= min_idx]
+ new_tokens = new_tokens[new_tokens["res_idx"] <= max_idx]
+
+ # Compute new tokens and new atoms
+ new_indices = set(new_tokens["token_idx"]) - cropped
+ new_tokens = token_data[list(new_indices)]
+ new_atoms = np.sum(new_tokens["atom_num"])
+
+ # Stop if we exceed the max number of tokens or atoms
+ if (len(new_indices) > (max_tokens - len(cropped))) or (
+ (max_atoms is not None) and ((total_atoms + new_atoms) > max_atoms)
+ ):
+ break
+
+ # Add new indices
+ cropped.update(new_indices)
+ total_atoms += new_atoms
+
+ # Get the cropped tokens sorted by index
+ token_data = token_data[sorted(cropped)]
+
+ # Only keep bonds within the cropped tokens
+ indices = token_data["token_idx"]
+ token_bonds = token_bonds[np.isin(token_bonds["token_1"], indices)]
+ token_bonds = token_bonds[np.isin(token_bonds["token_2"], indices)]
+
+ # Return the cropped tokens
+ return replace(data, tokens=token_data, bonds=token_bonds)
diff --git a/src/boltz/data/crop/cropper.py b/src/boltz/data/crop/cropper.py
new file mode 100644
index 0000000..4eb1dbf
--- /dev/null
+++ b/src/boltz/data/crop/cropper.py
@@ -0,0 +1,45 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+
+from boltz.data.types import Tokenized
+
+
+class Cropper(ABC):
+ """Abstract base class for cropper."""
+
+ @abstractmethod
+ def crop(
+ self,
+ data: Tokenized,
+ max_tokens: int,
+ random: np.random.RandomState,
+ max_atoms: Optional[int] = None,
+ chain_id: Optional[int] = None,
+ interface_id: Optional[int] = None,
+ ) -> Tokenized:
+ """Crop the data to a maximum number of tokens.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ max_tokens : int
+ The maximum number of tokens to crop.
+ random : np.random.RandomState
+ The random state for reproducibility.
+ max_atoms : Optional[int]
+ The maximum number of atoms to consider.
+ chain_id : Optional[int]
+ The chain ID to crop.
+ interface_id : Optional[int]
+ The interface ID to crop.
+
+ Returns
+ -------
+ Tokenized
+ The cropped data.
+
+ """
+ raise NotImplementedError
diff --git a/src/boltz/data/feature/__init__.py b/src/boltz/data/feature/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/feature/featurizer.py b/src/boltz/data/feature/featurizer.py
new file mode 100644
index 0000000..7704b0d
--- /dev/null
+++ b/src/boltz/data/feature/featurizer.py
@@ -0,0 +1,977 @@
+import math
+import random
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import Tensor, from_numpy
+from torch.nn.functional import one_hot
+
+from boltz.data import const
+from boltz.data.feature.pad import pad_dim
+from boltz.data.feature.symmetry import (
+ get_amino_acids_symmetries,
+ get_chain_symmetries,
+ get_ligand_symmetries,
+)
+from boltz.data.types import (
+ MSA,
+ MSADeletion,
+ MSAResidue,
+ MSASequence,
+ Tokenized,
+)
+from boltz.model.modules.utils import center_random_augmentation
+
+####################################################################################################
+# HELPERS
+####################################################################################################
+
+
+def compute_frames_nonpolymer(
+ data: Tokenized,
+ coords,
+ resolved_mask,
+ atom_to_token,
+ frame_data: list,
+ resolved_frame_data: list,
+) -> tuple[list, list]:
+ """Get the frames for non-polymer tokens.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ frame_data : list
+ The frame data.
+ resolved_frame_data : list
+ The resolved frame data.
+
+ Returns
+ -------
+ tuple[list, list]
+ The frame data and resolved frame data.
+
+ """
+ frame_data = np.array(frame_data)
+ resolved_frame_data = np.array(resolved_frame_data)
+ asym_id_token = data.tokens["asym_id"]
+ asym_id_atom = data.tokens["asym_id"][atom_to_token]
+ token_idx = 0
+ atom_idx = 0
+ for id in np.unique(data.tokens["asym_id"]):
+ mask_chain_token = asym_id_token == id
+ mask_chain_atom = asym_id_atom == id
+ num_tokens = mask_chain_token.sum()
+ num_atoms = mask_chain_atom.sum()
+ if (
+ data.tokens[token_idx]["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+ or num_atoms < 3
+ ):
+ token_idx += num_tokens
+ atom_idx += num_atoms
+ continue
+ dist_mat = (
+ (
+ coords.reshape(-1, 3)[mask_chain_atom][:, None, :]
+ - coords.reshape(-1, 3)[mask_chain_atom][None, :, :]
+ )
+ ** 2
+ ).sum(-1) ** 0.5
+ resolved_pair = 1 - (
+ resolved_mask[mask_chain_atom][None, :]
+ * resolved_mask[mask_chain_atom][:, None]
+ ).astype(np.float32)
+ resolved_pair[resolved_pair == 1] = math.inf
+ indices = np.argsort(dist_mat + resolved_pair, axis=1)
+ frames = (
+ np.concatenate(
+ [
+ indices[:, 1:2],
+ indices[:, 0:1],
+ indices[:, 2:3],
+ ],
+ axis=1,
+ )
+ + atom_idx
+ )
+ frame_data[token_idx : token_idx + num_atoms, :] = frames
+ resolved_frame_data[token_idx : token_idx + num_atoms] = resolved_mask[
+ frames
+ ].all(axis=1)
+ token_idx += num_tokens
+ atom_idx += num_atoms
+ frames_expanded = coords.reshape(-1, 3)[frame_data]
+
+ mask_collinear = compute_collinear_mask(
+ frames_expanded[:, 1] - frames_expanded[:, 0],
+ frames_expanded[:, 1] - frames_expanded[:, 2],
+ )
+ return frame_data, resolved_frame_data & mask_collinear
+
+
+def compute_collinear_mask(v1, v2):
+ norm1 = np.linalg.norm(v1, axis=1, keepdims=True)
+ norm2 = np.linalg.norm(v2, axis=1, keepdims=True)
+ v1 = v1 / (norm1 + 1e-6)
+ v2 = v2 / (norm2 + 1e-6)
+ mask_angle = np.abs(np.sum(v1 * v2, axis=1)) < 0.9063
+ mask_overlap1 = norm1.reshape(-1) > 1e-2
+ mask_overlap2 = norm2.reshape(-1) > 1e-2
+ return mask_angle & mask_overlap1 & mask_overlap2
+
+
+def dummy_msa(residues: np.ndarray) -> MSA:
+ """Create a dummy MSA for a chain.
+
+ Parameters
+ ----------
+ residues : np.ndarray
+ The residues for the chain.
+
+ Returns
+ -------
+ MSA
+ The dummy MSA.
+
+ """
+ residues = [res["res_type"] for res in residues]
+ deletions = []
+ sequences = [(0, -1, 0, len(residues), 0, 0)]
+ return MSA(
+ residues=np.array(residues, dtype=MSAResidue),
+ deletions=np.array(deletions, dtype=MSADeletion),
+ sequences=np.array(sequences, dtype=MSASequence),
+ )
+
+
+def construct_paired_msa( # noqa: C901, PLR0915, PLR0912
+ data: Tokenized,
+ max_seqs: int,
+ max_pairs: int = 8192,
+ max_total: int = 16384,
+ random_subset: bool = False,
+) -> tuple[Tensor, Tensor, Tensor]:
+ """Pair the MSA data.
+
+ Parameters
+ ----------
+ data : Input
+ The input data.
+
+ Returns
+ -------
+ Tensor
+ The MSA data.
+ Tensor
+ The deletion data.
+ Tensor
+ Mask indicating paired sequences.
+
+ """
+ # Get unique chains (ensuring monotonicity in the order)
+ assert np.all(np.diff(data.tokens["asym_id"], n=1) >= 0)
+ chain_ids = np.unique(data.tokens["asym_id"])
+
+ # Get relevant MSA, and create a dummy for chains without
+ msa = {k: data.msa[k] for k in chain_ids if k in data.msa}
+ for chain_id in chain_ids:
+ if chain_id not in msa:
+ chain = data.structure.chains[chain_id]
+ res_start = chain["res_idx"]
+ res_end = res_start + chain["res_num"]
+ residues = data.structure.residues[res_start:res_end]
+ msa[chain_id] = dummy_msa(residues)
+
+ # Map taxonomies to (chain_id, seq_idx)
+ taxonomy_map: dict[str, list] = {}
+ for chain_id, chain_msa in msa.items():
+ sequences = chain_msa.sequences
+ sequences = sequences[sequences["taxonomy"] != -1]
+ for sequence in sequences:
+ seq_idx = sequence["seq_idx"]
+ taxon = sequence["taxonomy"]
+ taxonomy_map.setdefault(taxon, []).append((chain_id, seq_idx))
+
+ # Remove taxonomies with only one sequence and sort by the
+ # number of chain_id present in each of the taxonomies
+ taxonomy_map = {k: v for k, v in taxonomy_map.items() if len(v) > 1}
+ taxonomy_map = sorted(
+ taxonomy_map.items(),
+ key=lambda x: len({c for c, _ in x[1]}),
+ reverse=True,
+ )
+
+ # Keep track of the sequences available per chain, keeping the original
+ # order of the sequences in the MSA to favor the best matching sequences
+ visited = {(c, s) for c, items in taxonomy_map for s in items}
+ available = {}
+ for c in chain_ids:
+ available[c] = [
+ i for i in range(1, len(msa[c].sequences)) if (c, i) not in visited
+ ]
+
+ # Create sequence pairs
+ is_paired = []
+ pairing = []
+
+ # Start with the first sequence for each chain
+ is_paired.append({c: 1 for c in chain_ids})
+ pairing.append({c: 0 for c in chain_ids})
+
+ # Then add up to 8191 paired rows
+ for _, pairs in taxonomy_map:
+ # Group occurences by chain_id in case we have multiple
+ # sequences from the same chain and same taxonomy
+ chain_occurences = {}
+ for chain_id, seq_idx in pairs:
+ chain_occurences.setdefault(chain_id, []).append(seq_idx)
+
+ # We create as many pairings as the maximum number of occurences
+ max_occurences = max(len(v) for v in chain_occurences.values())
+ for i in range(max_occurences):
+ row_pairing = {}
+ row_is_paired = {}
+
+ # Add the chains present in the taxonomy
+ for chain_id, seq_idxs in chain_occurences.items():
+ # Roll over the sequence index to maximize diversity
+ idx = i % len(seq_idxs)
+ seq_idx = seq_idxs[idx]
+
+ # Add the sequence to the pairing
+ row_pairing[chain_id] = seq_idx
+ row_is_paired[chain_id] = 1
+
+ # Add any missing chains
+ for chain_id in chain_ids:
+ if chain_id not in row_pairing:
+ row_is_paired[chain_id] = 0
+ if available[chain_id]:
+ # Add the next available sequence
+ seq_idx = available[chain_id].pop(0)
+ row_pairing[chain_id] = seq_idx
+ else:
+ # No more sequences available, we place a gap
+ row_pairing[chain_id] = -1
+
+ pairing.append(row_pairing)
+ is_paired.append(row_is_paired)
+
+ # Break if we have enough pairs
+ if len(pairing) >= max_pairs:
+ break
+
+ # Break if we have enough pairs
+ if len(pairing) >= max_pairs:
+ break
+
+ # Now add up to 16384 unpaired rows total
+ max_left = max(len(v) for v in available.values())
+ for _ in range(min(max_total - len(pairing), max_left)):
+ row_pairing = {}
+ row_is_paired = {}
+ for chain_id in chain_ids:
+ row_is_paired[chain_id] = 0
+ if available[chain_id]:
+ # Add the next available sequence
+ seq_idx = available[chain_id].pop(0)
+ row_pairing[chain_id] = seq_idx
+ else:
+ # No more sequences available, we place a gap
+ row_pairing[chain_id] = -1
+
+ pairing.append(row_pairing)
+ is_paired.append(row_is_paired)
+
+ # Break if we have enough sequences
+ if len(pairing) >= max_total:
+ break
+
+ # Randomly sample a subset of the pairs
+ # ensuring the first row is always present
+ if random_subset:
+ num_seqs = len(pairing)
+ if num_seqs > max_seqs:
+ indices = np.random.choice(list(range(1, num_seqs)), replace=False) # noqa: NPY002
+ pairing = [pairing[0]] + [pairing[i] for i in indices]
+ is_paired = [is_paired[0]] + [is_paired[i] for i in indices]
+ else:
+ # Deterministic downsample to max_seqs
+ pairing = pairing[:max_seqs]
+ is_paired = is_paired[:max_seqs]
+
+ # Create MSA data
+ msa_data = []
+ del_data = []
+ paired_data = []
+ gap_token_id = const.token_ids["-"]
+
+ # Map (chain_id, seq_idx, res_idx) to deletion
+ deletions = {}
+ for chain_id, chain_msa in msa.items():
+ chain_deletions = chain_msa.deletions
+ for sequence in chain_msa.sequences:
+ del_start = sequence["del_start"]
+ del_end = sequence["del_end"]
+ chain_deletions = chain_msa.deletions[del_start:del_end]
+ for deletion_data in chain_deletions:
+ seq_idx = sequence["seq_idx"]
+ res_idx = deletion_data["res_idx"]
+ deletion = deletion_data["deletion"]
+ deletions[(chain_id, seq_idx, res_idx)] = deletion
+
+ # Add all the token MSA data
+ for token in data.tokens:
+ token_res_types = []
+ token_deletions = []
+ token_is_paired = []
+ for row_pairing, row_is_paired in zip(pairing, is_paired):
+ res_idx = int(token["res_idx"])
+ chain_id = int(token["asym_id"])
+ seq_idx = row_pairing[chain_id]
+ token_is_paired.append(row_is_paired[chain_id])
+
+ # Add residue type
+ if seq_idx == -1:
+ token_res_types.append(gap_token_id)
+ token_deletions.append(0)
+ else:
+ sequence = msa[chain_id].sequences[seq_idx]
+ res_start = sequence["res_start"]
+ res_type = msa[chain_id].residues[res_start + res_idx][0]
+ deletion = deletions.get((chain_id, seq_idx, res_idx), 0)
+ token_res_types.append(res_type)
+ token_deletions.append(deletion)
+
+ msa_data.append(token_res_types)
+ del_data.append(token_deletions)
+ paired_data.append(token_is_paired)
+
+ msa_data = torch.tensor(msa_data, dtype=torch.long)
+ del_data = torch.tensor(del_data, dtype=torch.float)
+ paired_data = torch.tensor(paired_data, dtype=torch.float)
+
+ return msa_data, del_data, paired_data
+
+
+####################################################################################################
+# FEATURES
+####################################################################################################
+
+
+def select_subset_from_mask(mask, p):
+ num_true = np.sum(mask)
+ v = np.random.geometric(p) + 1
+ k = min(v, num_true)
+
+ true_indices = np.where(mask)[0]
+
+ # Randomly select k indices from the true_indices
+ selected_indices = np.random.choice(true_indices, size=k, replace=False)
+
+ new_mask = np.zeros_like(mask)
+ new_mask[selected_indices] = 1
+
+ return new_mask
+
+
+def process_token_features(
+ data: Tokenized,
+ max_tokens: Optional[int] = None,
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
+ binder_pocket_cutoff: Optional[float] = 6.0,
+ binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+ only_ligand_binder_pocket: Optional[bool] = False,
+ inference_binder: Optional[int] = None,
+ inference_pocket: Optional[list[tuple[int, int]]] = None,
+) -> dict[str, Tensor]:
+ """Get the token features.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ max_tokens : int
+ The maximum number of tokens.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The token features.
+
+ """
+ # Token data
+ token_data = data.tokens
+ token_bonds = data.bonds
+
+ # Token core features
+ token_index = torch.arange(len(token_data), dtype=torch.long)
+ residue_index = from_numpy(token_data["res_idx"]).long()
+ asym_id = from_numpy(token_data["asym_id"]).long()
+ entity_id = from_numpy(token_data["entity_id"]).long()
+ sym_id = from_numpy(token_data["sym_id"]).long()
+ mol_type = from_numpy(token_data["mol_type"]).long()
+ res_type = from_numpy(token_data["res_type"]).long()
+ res_type = one_hot(res_type, num_classes=const.num_tokens)
+ disto_center = from_numpy(token_data["disto_coords"])
+
+ # Token mask features
+ pad_mask = torch.ones(len(token_data), dtype=torch.float)
+ resolved_mask = from_numpy(token_data["resolved_mask"]).float()
+ disto_mask = from_numpy(token_data["disto_mask"]).float()
+
+ # Token bond features
+ if max_tokens is not None:
+ pad_len = max_tokens - len(token_data)
+ num_tokens = max_tokens if pad_len > 0 else len(token_data)
+ else:
+ num_tokens = len(token_data)
+
+ tok_to_idx = {tok["token_idx"]: idx for idx, tok in enumerate(token_data)}
+ bonds = torch.zeros(num_tokens, num_tokens, dtype=torch.float)
+ for token_bond in token_bonds:
+ token_1 = tok_to_idx[token_bond["token_1"]]
+ token_2 = tok_to_idx[token_bond["token_2"]]
+ bonds[token_1, token_2] = 1
+ bonds[token_2, token_1] = 1
+
+ bonds = bonds.unsqueeze(-1)
+
+ # Pocket conditioned feature
+ pocket_feature = (
+ np.zeros(len(token_data)) + const.pocket_contact_info["UNSPECIFIED"]
+ )
+ if inference_binder is not None:
+ assert inference_pocket is not None
+ pocket_residues = set(inference_pocket)
+ for idx, token in enumerate(token_data):
+ if token["asym_id"] == inference_binder:
+ pocket_feature[idx] = const.pocket_contact_info["BINDER"]
+ elif (token["asym_id"], token["res_idx"]) in pocket_residues:
+ pocket_feature[idx] = const.pocket_contact_info["POCKET"]
+ elif (
+ binder_pocket_conditioned_prop > 0.0
+ and random.random() < binder_pocket_conditioned_prop
+ ):
+ # choose as binder a random ligand in the crop, if there are no ligands select a protein chain
+ binder_asym_ids = np.unique(
+ token_data["asym_id"][
+ token_data["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+ ]
+ )
+
+ if len(binder_asym_ids) == 0:
+ if not only_ligand_binder_pocket:
+ binder_asym_ids = np.unique(token_data["asym_id"])
+
+ if len(binder_asym_ids) > 0:
+ pocket_asym_id = random.choice(binder_asym_ids)
+ binder_mask = token_data["asym_id"] == pocket_asym_id
+
+ binder_coords = []
+ for token in token_data:
+ if token["asym_id"] == pocket_asym_id:
+ binder_coords.append(
+ data.structure.atoms["coords"][
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+ ]
+ )
+ binder_coords = np.concatenate(binder_coords, axis=0)
+
+ # find the tokens in the pocket
+ token_dist = np.zeros(len(token_data)) + 1000
+ for i, token in enumerate(token_data):
+ if (
+ token["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+ and token["asym_id"] != pocket_asym_id
+ and token["resolved_mask"] == 1
+ ):
+ token_coords = data.structure.atoms["coords"][
+ token["atom_idx"] : token["atom_idx"] + token["atom_num"]
+ ]
+
+ # find chain and apply chain transformation
+ for chain in data.structure.chains:
+ if chain["asym_id"] == token["asym_id"]:
+ break
+
+ token_dist[i] = np.min(
+ np.linalg.norm(
+ token_coords[:, None, :] - binder_coords[None, :, :],
+ axis=-1,
+ )
+ )
+
+ pocket_mask = token_dist < binder_pocket_cutoff
+
+ if np.sum(pocket_mask) > 0:
+ pocket_feature = (
+ np.zeros(len(token_data)) + const.pocket_contact_info["UNSELECTED"]
+ )
+ pocket_feature[binder_mask] = const.pocket_contact_info["BINDER"]
+
+ if binder_pocket_sampling_geometric_p > 0.0:
+ # select a subset of the pocket, according
+ # to a geometric distribution with one as minimum
+ pocket_mask = select_subset_from_mask(
+ pocket_mask, binder_pocket_sampling_geometric_p
+ )
+
+ pocket_feature[pocket_mask] = const.pocket_contact_info["POCKET"]
+ pocket_feature = from_numpy(pocket_feature).long()
+ pocket_feature = one_hot(pocket_feature, num_classes=len(const.pocket_contact_info))
+
+ # Pad to max tokens if given
+ if max_tokens is not None:
+ pad_len = max_tokens - len(token_data)
+ if pad_len > 0:
+ token_index = pad_dim(token_index, 0, pad_len)
+ residue_index = pad_dim(residue_index, 0, pad_len)
+ asym_id = pad_dim(asym_id, 0, pad_len)
+ entity_id = pad_dim(entity_id, 0, pad_len)
+ sym_id = pad_dim(sym_id, 0, pad_len)
+ mol_type = pad_dim(mol_type, 0, pad_len)
+ res_type = pad_dim(res_type, 0, pad_len)
+ disto_center = pad_dim(disto_center, 0, pad_len)
+ pad_mask = pad_dim(pad_mask, 0, pad_len)
+ resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+ disto_mask = pad_dim(disto_mask, 0, pad_len)
+ pocket_feature = pad_dim(pocket_feature, 0, pad_len)
+
+ token_features = {
+ "token_index": token_index,
+ "residue_index": residue_index,
+ "asym_id": asym_id,
+ "entity_id": entity_id,
+ "sym_id": sym_id,
+ "mol_type": mol_type,
+ "res_type": res_type,
+ "disto_center": disto_center,
+ "token_bonds": bonds,
+ "token_pad_mask": pad_mask,
+ "token_resolved_mask": resolved_mask,
+ "token_disto_mask": disto_mask,
+ "pocket_feature": pocket_feature,
+ }
+ return token_features
+
+
+def process_atom_features(
+ data: Tokenized,
+ atoms_per_window_queries: int = 32,
+ min_dist: float = 2.0,
+ max_dist: float = 22.0,
+ num_bins: int = 64,
+ max_atoms: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+) -> dict[str, Tensor]:
+ """Get the atom features.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ max_atoms : int, optional
+ The maximum number of atoms.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The atom features.
+
+ """
+ # Filter to tokens' atoms
+ atom_data = []
+ ref_space_uid = []
+ coord_data = []
+ frame_data = []
+ resolved_frame_data = []
+ atom_to_token = []
+ token_to_rep_atom = [] # index on cropped atom table
+ r_set_to_rep_atom = []
+ disto_coords = []
+ atom_idx = 0
+
+ chain_res_ids = {}
+ for token_id, token in enumerate(data.tokens):
+ # Get the chain residue ids
+ chain_idx, res_id = token["asym_id"], token["res_idx"]
+ chain = data.structure.chains[chain_idx]
+
+ if (chain_idx, res_id) not in chain_res_ids:
+ new_idx = len(chain_res_ids)
+ chain_res_ids[(chain_idx, res_id)] = new_idx
+ else:
+ new_idx = chain_res_ids[(chain_idx, res_id)]
+
+ # Map atoms to token indices
+ ref_space_uid.extend([new_idx] * token["atom_num"])
+ atom_to_token.extend([token_id] * token["atom_num"])
+
+ # Add atom data
+ start = token["atom_idx"]
+ end = token["atom_idx"] + token["atom_num"]
+ token_atoms = data.structure.atoms[start:end]
+
+ # Map token to representative atom
+ token_to_rep_atom.append(atom_idx + token["disto_idx"] - start)
+ if (chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]) and token[
+ "resolved_mask"
+ ]:
+ r_set_to_rep_atom.append(atom_idx + token["center_idx"] - start)
+
+ # Get token coordinates
+ token_coords = np.array([token_atoms["coords"]])
+ coord_data.append(token_coords)
+
+ # Get frame data
+ res_type = const.tokens[token["res_type"]]
+
+ if token["atom_num"] < 3 or res_type in ["PAD", "UNK", "-"]:
+ idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+ mask_frame = False
+ elif (token["mol_type"] == const.chain_type_ids["PROTEIN"]) and (
+ res_type in const.ref_atoms
+ ):
+ idx_frame_a, idx_frame_b, idx_frame_c = (
+ const.ref_atoms[res_type].index("N"),
+ const.ref_atoms[res_type].index("CA"),
+ const.ref_atoms[res_type].index("C"),
+ )
+ mask_frame = (
+ token_atoms["is_present"][idx_frame_a]
+ and token_atoms["is_present"][idx_frame_b]
+ and token_atoms["is_present"][idx_frame_c]
+ )
+ elif (
+ token["mol_type"] == const.chain_type_ids["DNA"]
+ or token["mol_type"] == const.chain_type_ids["RNA"]
+ ) and (res_type in const.ref_atoms):
+ idx_frame_a, idx_frame_b, idx_frame_c = (
+ const.ref_atoms[res_type].index("C1'"),
+ const.ref_atoms[res_type].index("C3'"),
+ const.ref_atoms[res_type].index("C4'"),
+ )
+ mask_frame = (
+ token_atoms["is_present"][idx_frame_a]
+ and token_atoms["is_present"][idx_frame_b]
+ and token_atoms["is_present"][idx_frame_c]
+ )
+ else:
+ idx_frame_a, idx_frame_b, idx_frame_c = 0, 0, 0
+ mask_frame = False
+ frame_data.append(
+ [idx_frame_a + atom_idx, idx_frame_b + atom_idx, idx_frame_c + atom_idx]
+ )
+ resolved_frame_data.append(mask_frame)
+
+ # Get distogram coordinates
+ disto_coords_tok = data.structure.atoms[token["disto_idx"]]["coords"]
+ disto_coords.append(disto_coords_tok)
+
+ # Update atom data. This is technically never used again (we rely on coord_data),
+ # but we update for consistency and to make sure the Atom object has valid, transformed coordinates.
+ token_atoms = token_atoms.copy()
+ token_atoms["coords"] = token_coords[0] # atom has a copy of first coords
+ atom_data.append(token_atoms)
+ atom_idx += len(token_atoms)
+
+ disto_coords = np.array(disto_coords)
+
+ # Compute distogram
+ t_center = torch.Tensor(disto_coords)
+ t_dists = torch.cdist(t_center, t_center)
+ boundaries = torch.linspace(min_dist, max_dist, num_bins - 1)
+ distogram = (t_dists.unsqueeze(-1) > boundaries).sum(dim=-1).long()
+ disto_target = one_hot(distogram, num_classes=num_bins)
+
+ atom_data = np.concatenate(atom_data)
+ coord_data = np.concatenate(coord_data, axis=1)
+ ref_space_uid = np.array(ref_space_uid)
+
+ # Compute features
+ ref_atom_name_chars = from_numpy(atom_data["name"]).long()
+ ref_element = from_numpy(atom_data["element"]).long()
+ ref_charge = from_numpy(atom_data["charge"])
+ ref_pos = from_numpy(
+ atom_data["conformer"].copy()
+ ) # not sure why I need to copy here..
+ ref_space_uid = from_numpy(ref_space_uid)
+ coords = from_numpy(coord_data.copy())
+ resolved_mask = from_numpy(atom_data["is_present"])
+ pad_mask = torch.ones(len(atom_data), dtype=torch.float)
+ atom_to_token = torch.tensor(atom_to_token, dtype=torch.long)
+ token_to_rep_atom = torch.tensor(token_to_rep_atom, dtype=torch.long)
+ r_set_to_rep_atom = torch.tensor(r_set_to_rep_atom, dtype=torch.long)
+ frame_data, resolved_frame_data = compute_frames_nonpolymer(
+ data,
+ coord_data,
+ atom_data["is_present"],
+ atom_to_token,
+ frame_data,
+ resolved_frame_data,
+ ) # Compute frames for NONPOLYMER tokens
+ frames = from_numpy(frame_data.copy())
+ frame_resolved_mask = from_numpy(resolved_frame_data.copy())
+ # Convert to one-hot
+ ref_atom_name_chars = one_hot(
+ ref_atom_name_chars % num_bins, num_classes=num_bins
+ ) # added for lower case letters
+ ref_element = one_hot(ref_element, num_classes=const.num_elements)
+ atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
+ token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
+ r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
+
+ # Center the ground truth coordinates
+ center = (coords * resolved_mask[None, :, None]).sum(dim=1)
+ center = center / resolved_mask.sum().clamp(min=1)
+ coords = coords - center[:, None]
+
+ # Apply random roto-translation to the input atoms
+ ref_pos = center_random_augmentation(
+ ref_pos[None], resolved_mask[None], centering=False
+ )[0]
+
+ # Compute padding and apply
+ if max_atoms is not None:
+ assert max_atoms % atoms_per_window_queries == 0
+ pad_len = max_atoms - len(atom_data)
+ else:
+ pad_len = (
+ (len(atom_data) - 1) // atoms_per_window_queries + 1
+ ) * atoms_per_window_queries - len(atom_data)
+
+ if pad_len > 0:
+ pad_mask = pad_dim(pad_mask, 0, pad_len)
+ ref_pos = pad_dim(ref_pos, 0, pad_len)
+ resolved_mask = pad_dim(resolved_mask, 0, pad_len)
+ ref_element = pad_dim(ref_element, 0, pad_len)
+ ref_charge = pad_dim(ref_charge, 0, pad_len)
+ ref_atom_name_chars = pad_dim(ref_atom_name_chars, 0, pad_len)
+ ref_space_uid = pad_dim(ref_space_uid, 0, pad_len)
+ coords = pad_dim(coords, 1, pad_len)
+ atom_to_token = pad_dim(atom_to_token, 0, pad_len)
+ token_to_rep_atom = pad_dim(token_to_rep_atom, 1, pad_len)
+ r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 1, pad_len)
+
+ if max_tokens is not None:
+ pad_len = max_tokens - token_to_rep_atom.shape[0]
+ if pad_len > 0:
+ atom_to_token = pad_dim(atom_to_token, 1, pad_len)
+ token_to_rep_atom = pad_dim(token_to_rep_atom, 0, pad_len)
+ r_set_to_rep_atom = pad_dim(r_set_to_rep_atom, 0, pad_len)
+ disto_target = pad_dim(pad_dim(disto_target, 0, pad_len), 1, pad_len)
+ frames = pad_dim(frames, 0, pad_len)
+ frame_resolved_mask = pad_dim(frame_resolved_mask, 0, pad_len)
+
+ return {
+ "ref_pos": ref_pos,
+ "atom_resolved_mask": resolved_mask,
+ "ref_element": ref_element,
+ "ref_charge": ref_charge,
+ "ref_atom_name_chars": ref_atom_name_chars,
+ "ref_space_uid": ref_space_uid,
+ "coords": coords,
+ "atom_pad_mask": pad_mask,
+ "atom_to_token": atom_to_token,
+ "token_to_rep_atom": token_to_rep_atom,
+ "r_set_to_rep_atom": r_set_to_rep_atom,
+ "disto_target": disto_target,
+ "frames_idx": frames,
+ "frame_resolved_mask": frame_resolved_mask,
+ }
+
+
+def process_msa_features(
+ data: Tokenized,
+ max_seqs_batch: int,
+ max_seqs: int,
+ max_tokens: Optional[int] = None,
+ pad_to_max_seqs: bool = False,
+) -> dict[str, Tensor]:
+ """Get the MSA features.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ max_seqs : int
+ The maximum number of MSA sequences.
+ max_tokens : int
+ The maximum number of tokens.
+ pad_to_max_seqs : bool
+ Whether to pad to the maximum number of sequences.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The MSA features.
+
+ """
+ # Created paired MSA
+ msa, deletion, paired = construct_paired_msa(data, max_seqs_batch)
+ msa, deletion, paired = (
+ msa.transpose(1, 0),
+ deletion.transpose(1, 0),
+ paired.transpose(1, 0),
+ ) # (N_MSA, N_RES, N_AA)
+
+ # Prepare features
+ msa = torch.nn.functional.one_hot(msa, num_classes=const.num_tokens)
+ msa_mask = torch.ones_like(msa[:, :, 0])
+ profile = msa.float().mean(dim=0)
+ has_deletion = deletion > 0
+ deletion = np.pi / 2 * np.arctan(deletion / 3)
+ deletion_mean = deletion.mean(axis=0)
+
+ # Pad in the MSA dimension (dim=0)
+ if pad_to_max_seqs:
+ pad_len = max_seqs - msa.shape[0]
+ if pad_len > 0:
+ msa = pad_dim(msa, 0, pad_len, const.token_ids["-"])
+ paired = pad_dim(paired, 0, pad_len)
+ msa_mask = pad_dim(msa_mask, 0, pad_len)
+ has_deletion = pad_dim(has_deletion, 0, pad_len)
+ deletion = pad_dim(deletion, 0, pad_len)
+
+ # Pad in the token dimension (dim=1)
+ if max_tokens is not None:
+ pad_len = max_tokens - msa.shape[1]
+ if pad_len > 0:
+ msa = pad_dim(msa, 1, pad_len, const.token_ids["-"])
+ paired = pad_dim(paired, 1, pad_len)
+ msa_mask = pad_dim(msa_mask, 1, pad_len)
+ has_deletion = pad_dim(has_deletion, 1, pad_len)
+ deletion = pad_dim(deletion, 1, pad_len)
+ profile = pad_dim(profile, 0, pad_len)
+ deletion_mean = pad_dim(deletion_mean, 0, pad_len)
+
+ return {
+ "msa": msa,
+ "msa_paired": paired,
+ "deletion_value": deletion,
+ "has_deletion": has_deletion,
+ "deletion_mean": deletion_mean,
+ "profile": profile,
+ "msa_mask": msa_mask,
+ }
+
+
+def process_symmetry_features(
+ cropped: Tokenized, symmetries: dict
+) -> dict[str, Tensor]:
+ """Get the symmetry features.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The symmetry features.
+
+ """
+ features = get_chain_symmetries(cropped)
+ features.update(get_amino_acids_symmetries(cropped))
+ features.update(get_ligand_symmetries(cropped, symmetries))
+
+ return features
+
+
+class BoltzFeaturizer:
+ """Boltz featurizer."""
+
+ def process(
+ self,
+ data: Tokenized,
+ training: bool,
+ max_seqs: int = 4096,
+ atoms_per_window_queries: int = 32,
+ min_dist: float = 2.0,
+ max_dist: float = 22.0,
+ num_bins: int = 64,
+ max_tokens: Optional[int] = None,
+ max_atoms: Optional[int] = None,
+ pad_to_max_seqs: bool = False,
+ compute_symmetries: bool = False,
+ symmetries: Optional[dict] = None,
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
+ binder_pocket_cutoff: Optional[float] = 6.0,
+ binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+ only_ligand_binder_pocket: Optional[bool] = False,
+ inference_binder: Optional[int] = None,
+ inference_pocket: Optional[list[tuple[int, int]]] = None,
+ ) -> dict[str, Tensor]:
+ """Compute features.
+
+ Parameters
+ ----------
+ data : Tokenized
+ The tokenized data.
+ training : bool
+ Whether the model is in training mode.
+ max_tokens : int, optional
+ The maximum number of tokens.
+ max_atoms : int, optional
+ The maximum number of atoms
+ max_seqs : int, optional
+ The maximum number of sequences.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The features for model training.
+
+ """
+ # Compute random number of sequences
+ if training and max_seqs is not None:
+ max_seqs_batch = np.random.randint(1, max_seqs + 1) # noqa: NPY002
+ else:
+ max_seqs_batch = max_seqs
+
+ # Compute token features
+ token_features = process_token_features(
+ data,
+ max_tokens,
+ binder_pocket_conditioned_prop,
+ binder_pocket_cutoff,
+ binder_pocket_sampling_geometric_p,
+ only_ligand_binder_pocket,
+ inference_binder=inference_binder,
+ inference_pocket=inference_pocket,
+ )
+
+ # Compute atom features
+ atom_features = process_atom_features(
+ data,
+ atoms_per_window_queries,
+ min_dist,
+ max_dist,
+ num_bins,
+ max_atoms,
+ max_tokens,
+ )
+
+ # Compute MSA features
+ msa_features = process_msa_features(
+ data,
+ max_seqs_batch,
+ max_seqs,
+ max_tokens,
+ pad_to_max_seqs,
+ )
+
+ # Compute symmetry features
+ symmetry_features = {}
+ if compute_symmetries:
+ symmetry_features = process_symmetry_features(data, symmetries)
+
+ return {
+ **token_features,
+ **atom_features,
+ **msa_features,
+ **symmetry_features,
+ }
diff --git a/src/boltz/data/feature/pad.py b/src/boltz/data/feature/pad.py
new file mode 100644
index 0000000..db6edbc
--- /dev/null
+++ b/src/boltz/data/feature/pad.py
@@ -0,0 +1,84 @@
+import torch
+from torch import Tensor
+from torch.nn.functional import pad
+
+
+def pad_dim(data: Tensor, dim: int, pad_len: float, value: float = 0) -> Tensor:
+ """Pad a tensor along a given dimension.
+
+ Parameters
+ ----------
+ data : Tensor
+ The input tensor.
+ dim : int
+ The dimension to pad.
+ pad_len : float
+ The padding length.
+ value : int, optional
+ The value to pad with.
+
+ Returns
+ -------
+ Tensor
+ The padded tensor.
+
+ """
+ if pad_len == 0:
+ return data
+
+ total_dims = len(data.shape)
+ padding = [0] * (2 * (total_dims - dim))
+ padding[2 * (total_dims - 1 - dim) + 1] = pad_len
+ return pad(data, tuple(padding), value=value)
+
+
+def pad_to_max(data: list[Tensor], value: float = 0) -> tuple[Tensor, Tensor]:
+ """Pad the data in all dimensions to the maximum found.
+
+ Parameters
+ ----------
+ data : List[Tensor]
+ List of tensors to pad.
+ value : float
+ The value to use for padding.
+
+ Returns
+ -------
+ Tensor
+ The padded tensor.
+ Tensor
+ The padding mask.
+
+ """
+ if isinstance(data[0], str):
+ return data, 0
+
+ # Check if all have the same shape
+ if all(d.shape == data[0].shape for d in data):
+ return torch.stack(data, dim=0), 0
+
+ # Get the maximum in each dimension
+ num_dims = len(data[0].shape)
+ max_dims = [max(d.shape[i] for d in data) for i in range(num_dims)]
+
+ # Get the padding lengths
+ pad_lengths = []
+ for d in data:
+ dims = []
+ for i in range(num_dims):
+ dims.append(0)
+ dims.append(max_dims[num_dims - i - 1] - d.shape[num_dims - i - 1])
+ pad_lengths.append(dims)
+
+ # Pad the data
+ padding = [
+ pad(torch.ones_like(d), pad_len, value=0)
+ for d, pad_len in zip(data, pad_lengths)
+ ]
+ data = [pad(d, pad_len, value=value) for d, pad_len in zip(data, pad_lengths)]
+
+ # Stack the data
+ padding = torch.stack(padding, dim=0)
+ data = torch.stack(data, dim=0)
+
+ return data, padding
diff --git a/src/boltz/data/feature/symmetry.py b/src/boltz/data/feature/symmetry.py
new file mode 100644
index 0000000..a257ee8
--- /dev/null
+++ b/src/boltz/data/feature/symmetry.py
@@ -0,0 +1,602 @@
+import itertools
+import pickle
+import random
+from pathlib import Path
+
+import numpy as np
+import torch
+
+from boltz.data import const
+from boltz.data.feature.pad import pad_dim
+from boltz.model.loss.confidence import lddt_dist
+from boltz.model.loss.validation import weighted_minimum_rmsd_single
+
+
+def convert_atom_name(name: str) -> tuple[int, int, int, int]:
+ """Convert an atom name to a standard format.
+
+ Parameters
+ ----------
+ name : str
+ The atom name.
+
+ Returns
+ -------
+ Tuple[int, int, int, int]
+ The converted atom name.
+
+ """
+ name = name.strip()
+ name = [ord(c) - 32 for c in name]
+ name = name + [0] * (4 - len(name))
+ return tuple(name)
+
+
+def get_symmetries(path: str) -> dict:
+ """Create a dictionary for the ligand symmetries.
+
+ Parameters
+ ----------
+ path : str
+ The path to the ligand symmetries.
+
+ Returns
+ -------
+ dict
+ The ligand symmetries.
+
+ """
+ with Path(path).open("rb") as f:
+ data: dict = pickle.load(f) # noqa: S301
+
+ symmetries = {}
+ for key, mol in data.items():
+ try:
+ serialized_sym = bytes.fromhex(mol.GetProp("symmetries"))
+ sym = pickle.loads(serialized_sym) # noqa: S301
+ atom_names = []
+ for atom in mol.GetAtoms():
+ # Get atom name
+ atom_name = convert_atom_name(atom.GetProp("name"))
+ atom_names.append(atom_name)
+
+ symmetries[key] = (sym, atom_names)
+ except Exception: # noqa: BLE001, PERF203, S110
+ pass
+
+ return symmetries
+
+
+def compute_symmetry_idx_dictionary(data):
+ # Compute the symmetry index dictionary
+ total_count = 0
+ all_coords = []
+ for i, chain in enumerate(data.chains):
+ chain.start_idx = total_count
+ for j, token in enumerate(chain.tokens):
+ token.start_idx = total_count - chain.start_idx
+ all_coords.extend(
+ [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
+ )
+ total_count += len(token.atoms)
+ return all_coords
+
+
+def get_current_idx_list(data):
+ idx = []
+ for chain in data.chains:
+ if chain.in_crop:
+ for token in chain.tokens:
+ if token.in_crop:
+ idx.extend(
+ [
+ chain.start_idx + token.start_idx + i
+ for i in range(len(token.atoms))
+ ]
+ )
+ return idx
+
+
+def all_different_after_swap(l):
+ final = [s[-1] for s in l]
+ return len(final) == len(set(final))
+
+
+def minimum_symmetry_coords(
+ coords: torch.Tensor,
+ feats: dict,
+ index_batch: int,
+ **args_rmsd,
+):
+ all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
+ all_resolved_mask = (
+ feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
+ )
+ crop_to_all_atom_map = (
+ feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
+ )
+ chain_symmetries = feats["chain_symmetries"][index_batch]
+ amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
+ ligand_symmetries = feats["ligand_symmetries"][index_batch]
+
+ # Check best symmetry on chain swap
+ best_true_coords = None
+ best_rmsd = float("inf")
+ best_align_weights = None
+ for c in chain_symmetries:
+ true_all_coords = all_coords.clone()
+ true_all_resolved_mask = all_resolved_mask.clone()
+ for start1, end1, start2, end2, chainidx1, chainidx2 in c:
+ true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
+ true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
+ true_coords = true_all_coords[:, crop_to_all_atom_map]
+ true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
+ true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
+ true_resolved_mask = pad_dim(
+ true_resolved_mask,
+ 0,
+ coords.shape[1] - true_resolved_mask.shape[0],
+ )
+ try:
+ rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
+ coords,
+ true_coords,
+ atom_mask=true_resolved_mask,
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
+ **args_rmsd,
+ )
+ except:
+ print("Warning: error in rmsd computation inside symmetry code")
+ continue
+ rmsd = rmsd.item()
+
+ if rmsd < best_rmsd:
+ best_rmsd = rmsd
+ best_true_coords = aligned_coords
+ best_align_weights = align_weights
+ best_true_resolved_mask = true_resolved_mask
+
+ # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
+ true_coords = best_true_coords.clone()
+ true_resolved_mask = best_true_resolved_mask.clone()
+ for symmetric_amino in amino_acids_symmetries:
+ for c in symmetric_amino:
+ # starting from greedy best, try to swap the atoms
+ new_true_coords = true_coords.clone()
+ new_true_resolved_mask = true_resolved_mask.clone()
+ for i, j in c:
+ new_true_coords[:, i] = true_coords[:, j]
+ new_true_resolved_mask[i] = true_resolved_mask[j]
+
+ # compute squared distance, for efficiency we do not recompute the alignment
+ best_mse_loss = torch.sum(
+ ((coords - best_true_coords) ** 2).sum(dim=-1)
+ * best_align_weights
+ * best_true_resolved_mask,
+ dim=-1,
+ ) / torch.sum(best_align_weights * best_true_resolved_mask, dim=-1)
+ new_mse_loss = torch.sum(
+ ((coords - new_true_coords) ** 2).sum(dim=-1)
+ * best_align_weights
+ * new_true_resolved_mask,
+ dim=-1,
+ ) / torch.sum(best_align_weights * new_true_resolved_mask, dim=-1)
+
+ if best_mse_loss > new_mse_loss:
+ best_true_coords = new_true_coords
+ best_true_resolved_mask = new_true_resolved_mask
+
+ # greedily update best coordinates after each amino acid
+ true_coords = best_true_coords.clone()
+ true_resolved_mask = best_true_resolved_mask.clone()
+
+ # Recomputing alignment
+ rmsd, true_coords, best_align_weights = weighted_minimum_rmsd_single(
+ coords,
+ true_coords,
+ atom_mask=true_resolved_mask,
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
+ **args_rmsd,
+ )
+ best_rmsd = rmsd.item()
+
+ # atom symmetries (ligand and non-standard), resolved greedily recomputing alignment
+ for symmetric_ligand in ligand_symmetries:
+ for c in symmetric_ligand:
+ new_true_coords = true_coords.clone()
+ new_true_resolved_mask = true_resolved_mask.clone()
+ for i, j in c:
+ new_true_coords[:, j] = true_coords[:, i]
+ new_true_resolved_mask[j] = true_resolved_mask[i]
+ try:
+ # TODO if this is too slow maybe we can get away with not recomputing alignment
+ rmsd, aligned_coords, align_weights = weighted_minimum_rmsd_single(
+ coords,
+ new_true_coords,
+ atom_mask=new_true_resolved_mask,
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
+ **args_rmsd,
+ )
+ except Exception as e:
+ raise e
+ print(e)
+ continue
+ rmsd = rmsd.item()
+ if rmsd < best_rmsd:
+ best_true_coords = aligned_coords
+ best_rmsd = rmsd
+ best_true_resolved_mask = new_true_resolved_mask
+
+ true_coords = best_true_coords.clone()
+ true_resolved_mask = best_true_resolved_mask.clone()
+
+ return best_true_coords, best_rmsd, best_true_resolved_mask.unsqueeze(0)
+
+
+def minimum_lddt_symmetry_coords(
+ coords: torch.Tensor,
+ feats: dict,
+ index_batch: int,
+ **args_rmsd,
+):
+ all_coords = feats["all_coords"][index_batch].unsqueeze(0).to(coords)
+ all_resolved_mask = (
+ feats["all_resolved_mask"][index_batch].to(coords).to(torch.bool)
+ )
+ crop_to_all_atom_map = (
+ feats["crop_to_all_atom_map"][index_batch].to(coords).to(torch.long)
+ )
+ chain_symmetries = feats["chain_symmetries"][index_batch]
+ amino_acids_symmetries = feats["amino_acids_symmetries"][index_batch]
+ ligand_symmetries = feats["ligand_symmetries"][index_batch]
+
+ dmat_predicted = torch.cdist(
+ coords[:, : len(crop_to_all_atom_map)], coords[:, : len(crop_to_all_atom_map)]
+ )
+
+ # Check best symmetry on chain swap
+ best_true_coords = None
+ best_lddt = 0
+ for c in chain_symmetries:
+ true_all_coords = all_coords.clone()
+ true_all_resolved_mask = all_resolved_mask.clone()
+ for start1, end1, start2, end2, chainidx1, chainidx2 in c:
+ true_all_coords[:, start1:end1] = all_coords[:, start2:end2]
+ true_all_resolved_mask[start1:end1] = all_resolved_mask[start2:end2]
+ true_coords = true_all_coords[:, crop_to_all_atom_map]
+ true_resolved_mask = true_all_resolved_mask[crop_to_all_atom_map]
+ dmat_true = torch.cdist(true_coords, true_coords)
+ pair_mask = (
+ true_resolved_mask[:, None]
+ * true_resolved_mask[None, :]
+ * (1 - torch.eye(len(true_resolved_mask))).to(true_resolved_mask)
+ )
+
+ lddt = lddt_dist(
+ dmat_predicted, dmat_true, pair_mask, cutoff=15.0, per_atom=False
+ )[0]
+ lddt = lddt.item()
+
+ if lddt > best_lddt:
+ best_lddt = lddt
+ best_true_coords = true_coords
+ best_true_resolved_mask = true_resolved_mask
+
+ # atom symmetries (nucleic acid and protein residues), resolved greedily without recomputing alignment
+ true_coords = best_true_coords.clone()
+ true_resolved_mask = best_true_resolved_mask.clone()
+ for symmetric_amino_or_lig in amino_acids_symmetries + ligand_symmetries:
+ for c in symmetric_amino_or_lig:
+ # starting from greedy best, try to swap the atoms
+ new_true_coords = true_coords.clone()
+ new_true_resolved_mask = true_resolved_mask.clone()
+ indices = []
+ for i, j in c:
+ new_true_coords[:, i] = true_coords[:, j]
+ new_true_resolved_mask[i] = true_resolved_mask[j]
+ indices.append(i)
+
+ indices = (
+ torch.from_numpy(np.asarray(indices)).to(new_true_coords.device).long()
+ )
+
+ pred_coords_subset = coords[:, : len(crop_to_all_atom_map)][:, indices]
+ true_coords_subset = true_coords[:, indices]
+ new_true_coords_subset = new_true_coords[:, indices]
+
+ sub_dmat_pred = torch.cdist(
+ coords[:, : len(crop_to_all_atom_map)], pred_coords_subset
+ )
+ sub_dmat_true = torch.cdist(true_coords, true_coords_subset)
+ sub_dmat_new_true = torch.cdist(new_true_coords, new_true_coords_subset)
+
+ sub_true_pair_lddt = (
+ true_resolved_mask[:, None] * true_resolved_mask[None, indices]
+ )
+ sub_true_pair_lddt[indices] = (
+ sub_true_pair_lddt[indices]
+ * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
+ )
+
+ sub_new_true_pair_lddt = (
+ new_true_resolved_mask[:, None] * new_true_resolved_mask[None, indices]
+ )
+ sub_new_true_pair_lddt[indices] = (
+ sub_new_true_pair_lddt[indices]
+ * (1 - torch.eye(len(indices))).to(sub_true_pair_lddt).bool()
+ )
+
+ lddt = lddt_dist(
+ sub_dmat_pred,
+ sub_dmat_true,
+ sub_true_pair_lddt,
+ cutoff=15.0,
+ per_atom=False,
+ )[0]
+ new_lddt = lddt_dist(
+ sub_dmat_pred,
+ sub_dmat_new_true,
+ sub_new_true_pair_lddt,
+ cutoff=15.0,
+ per_atom=False,
+ )[0]
+
+ if new_lddt > lddt:
+ best_true_coords = new_true_coords
+ best_true_resolved_mask = new_true_resolved_mask
+
+ # greedily update best coordinates after each amino acid
+ true_coords = best_true_coords.clone()
+ true_resolved_mask = best_true_resolved_mask.clone()
+
+ # Recomputing alignment
+ true_coords = pad_dim(true_coords, 1, coords.shape[1] - true_coords.shape[1])
+ true_resolved_mask = pad_dim(
+ true_resolved_mask,
+ 0,
+ coords.shape[1] - true_resolved_mask.shape[0],
+ )
+
+ try:
+ rmsd, true_coords, _ = weighted_minimum_rmsd_single(
+ coords,
+ true_coords,
+ atom_mask=true_resolved_mask,
+ atom_to_token=feats["atom_to_token"][index_batch : index_batch + 1],
+ mol_type=feats["mol_type"][index_batch : index_batch + 1],
+ **args_rmsd,
+ )
+ best_rmsd = rmsd.item()
+ except Exception as e:
+ print("Failed proper RMSD computation, returning inf. Error: ", e)
+ best_rmsd = 1000
+
+ return true_coords, best_rmsd, true_resolved_mask.unsqueeze(0)
+
+
+def compute_all_coords_mask(structure):
+ # Compute all coords, crop mask and add start_idx to structure
+ total_count = 0
+ all_coords = []
+ all_coords_crop_mask = []
+ all_resolved_mask = []
+ for i, chain in enumerate(structure.chains):
+ chain.start_idx = total_count
+ for j, token in enumerate(chain.tokens):
+ token.start_idx = total_count - chain.start_idx
+ all_coords.extend(
+ [[atom.coords.x, atom.coords.y, atom.coords.z] for atom in token.atoms]
+ )
+ all_coords_crop_mask.extend(
+ [token.in_crop for _ in range(len(token.atoms))]
+ )
+ all_resolved_mask.extend(
+ [token.is_present for _ in range(len(token.atoms))]
+ )
+ total_count += len(token.atoms)
+ if len(all_coords_crop_mask) != len(all_resolved_mask):
+ pass
+ return all_coords, all_coords_crop_mask, all_resolved_mask
+
+
+def get_chain_symmetries(cropped, max_n_symmetries=100):
+ # get all coordinates and resolved mask
+ structure = cropped.structure
+ all_coords = []
+ all_resolved_mask = []
+ original_atom_idx = []
+ chain_atom_idx = []
+ chain_atom_num = []
+ chain_in_crop = []
+ chain_asym_id = []
+ new_atom_idx = 0
+
+ for chain in structure.chains:
+ atom_idx, atom_num = (
+ chain["atom_idx"],
+ chain["atom_num"],
+ )
+
+ # compute coordinates and resolved mask
+ resolved_mask = structure.atoms["is_present"][atom_idx : atom_idx + atom_num]
+
+ # ensemble_atom_starts = [structure.ensemble[idx]["atom_coord_idx"] for idx in cropped.ensemble_ref_idxs]
+ # coords = np.array(
+ # [structure.coords[ensemble_atom_start + atom_idx: ensemble_atom_start + atom_idx + atom_num]["coords"] for
+ # ensemble_atom_start in ensemble_atom_starts])
+
+ coords = structure.atoms["coords"][atom_idx : atom_idx + atom_num]
+
+ in_crop = False
+ for token in cropped.tokens:
+ if token["asym_id"] == chain["asym_id"]:
+ in_crop = True
+ break
+
+ all_coords.append(coords)
+ all_resolved_mask.append(resolved_mask)
+ original_atom_idx.append(atom_idx)
+ chain_atom_idx.append(new_atom_idx)
+ chain_atom_num.append(atom_num)
+ chain_in_crop.append(in_crop)
+ chain_asym_id.append(chain["asym_id"])
+
+ new_atom_idx += atom_num
+
+ # Compute backmapping from token to all coords
+ crop_to_all_atom_map = []
+ for token in cropped.tokens:
+ chain_idx = chain_asym_id.index(token["asym_id"])
+ start = (
+ chain_atom_idx[chain_idx] - original_atom_idx[chain_idx] + token["atom_idx"]
+ )
+ crop_to_all_atom_map.append(np.arange(start, start + token["atom_num"]))
+
+ # Compute the symmetries between chains
+ swaps = []
+ for i, chain in enumerate(structure.chains):
+ start = chain_atom_idx[i]
+ end = start + chain_atom_num[i]
+ if chain_in_crop[i]:
+ possible_swaps = []
+ for j, chain2 in enumerate(structure.chains):
+ start2 = chain_atom_idx[j]
+ end2 = start2 + chain_atom_num[j]
+ if (
+ chain["entity_id"] == chain2["entity_id"]
+ and end - start == end2 - start2
+ ):
+ possible_swaps.append((start, end, start2, end2, i, j))
+ swaps.append(possible_swaps)
+ combinations = itertools.product(*swaps)
+ # to avoid combinatorial explosion, bound the number of combinations even considered
+ combinations = list(itertools.islice(combinations, max_n_symmetries * 10))
+ # filter for all chains getting a different assignment
+ combinations = [c for c in combinations if all_different_after_swap(c)]
+
+ if len(combinations) > max_n_symmetries:
+ combinations = random.sample(combinations, max_n_symmetries)
+
+ if len(combinations) == 0:
+ combinations.append([])
+
+ features = {}
+ features["all_coords"] = torch.Tensor(
+ np.concatenate(all_coords, axis=0)
+ ) # axis=1 with ensemble
+
+ features["all_resolved_mask"] = torch.Tensor(
+ np.concatenate(all_resolved_mask, axis=0)
+ )
+ features["crop_to_all_atom_map"] = torch.Tensor(
+ np.concatenate(crop_to_all_atom_map, axis=0)
+ )
+ features["chain_symmetries"] = combinations
+
+ return features
+
+
+def get_amino_acids_symmetries(cropped):
+ # Compute standard amino-acids symmetries
+ swaps = []
+ start_index_crop = 0
+ for token in cropped.tokens:
+ symmetries = const.ref_symmetries.get(const.tokens[token["res_type"]], [])
+ if len(symmetries) > 0:
+ residue_swaps = []
+ for sym in symmetries:
+ sym_new_idx = [
+ (i + start_index_crop, j + start_index_crop) for i, j in sym
+ ]
+ residue_swaps.append(sym_new_idx)
+ swaps.append(residue_swaps)
+ start_index_crop += token["atom_num"]
+
+ features = {"amino_acids_symmetries": swaps}
+ return features
+
+
+def get_ligand_symmetries(cropped, symmetries):
+ # Compute ligand and non-standard amino-acids symmetries
+ structure = cropped.structure
+
+ added_molecules = {}
+ index_mols = []
+ atom_count = 0
+ for token in cropped.tokens:
+ # check if molecule is already added by identifying it through asym_id and res_idx
+ atom_count += token["atom_num"]
+ mol_id = (token["asym_id"], token["res_idx"])
+ if mol_id in added_molecules.keys():
+ added_molecules[mol_id] += token["atom_num"]
+ continue
+ added_molecules[mol_id] = token["atom_num"]
+
+ # get the molecule type and indices
+ residue_idx = token["res_idx"] + structure.chains[token["asym_id"]]["res_idx"]
+ mol_name = structure.residues[residue_idx]["name"]
+ atom_idx = structure.residues[residue_idx]["atom_idx"]
+ mol_atom_names = structure.atoms[
+ atom_idx : atom_idx + structure.residues[residue_idx]["atom_num"]
+ ]["name"]
+ mol_atom_names = [tuple(m) for m in mol_atom_names]
+ if mol_name not in const.ref_symmetries.keys():
+ index_mols.append(
+ (mol_name, atom_count - token["atom_num"], mol_id, mol_atom_names)
+ )
+
+ # for each molecule, get the symmetries
+ molecule_symmetries = []
+ for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
+ if not mol_name in symmetries:
+ continue
+ else:
+ swaps = []
+ syms_ccd, mol_atom_names_ccd = symmetries[mol_name]
+ # Get indices of mol_atom_names_ccd that are in mol_atom_names
+ ccd_to_valid_ids = {
+ mol_atom_names_ccd.index(name): i
+ for i, name in enumerate(mol_atom_names)
+ }
+ ccd_valid_ids = set(ccd_to_valid_ids.keys())
+
+ syms = []
+ # Get syms
+ for sym_ccd in syms_ccd:
+ sym_dict = {}
+ bool_add = True
+ for i, j in enumerate(sym_ccd):
+ if i in ccd_valid_ids:
+ if j in ccd_valid_ids:
+ i_true = ccd_to_valid_ids[i]
+ j_true = ccd_to_valid_ids[j]
+ sym_dict[i_true] = j_true
+ else:
+ bool_add = False
+ break
+ if bool_add:
+ syms.append([sym_dict[i] for i in range(len(ccd_valid_ids))])
+
+ for sym in syms:
+ if len(sym) != added_molecules[mol_id]:
+ raise Exception(
+ f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
+ )
+ # assert (
+ # len(sym) == added_molecules[mol_id]
+ # ), f"Symmetry length mismatch {len(sym)} {added_molecules[mol_id]}"
+ sym_new_idx = []
+ for i, j in enumerate(sym):
+ if i != int(j):
+ sym_new_idx.append((i + start_mol, int(j) + start_mol))
+ if len(sym_new_idx) > 0:
+ swaps.append(sym_new_idx)
+ if len(swaps) > 0:
+ molecule_symmetries.append(swaps)
+
+ features = {"ligand_symmetries": molecule_symmetries}
+
+ return features
diff --git a/src/boltz/data/filter/__init__.py b/src/boltz/data/filter/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/filter/dynamic/__init__.py b/src/boltz/data/filter/dynamic/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/filter/dynamic/date.py b/src/boltz/data/filter/dynamic/date.py
new file mode 100644
index 0000000..d7a9f08
--- /dev/null
+++ b/src/boltz/data/filter/dynamic/date.py
@@ -0,0 +1,76 @@
+from datetime import datetime
+from typing import Literal
+
+from boltz.data.types import Record
+from boltz.data.filter.dynamic.filter import DynamicFilter
+
+
+class DateFilter(DynamicFilter):
+ """A filter that filters complexes based on their date.
+
+ The date can be the deposition, release, or revision date.
+ If the date is not available, the previous date is used.
+
+ If no date is available, the complex is rejected.
+
+ """
+
+ def __init__(
+ self,
+ date: str,
+ ref: Literal["deposited", "revised", "released"],
+ ) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ date : str, optional
+ The maximum date of PDB entries to filter
+ ref : Literal["deposited", "revised", "released"]
+ The reference date to use.
+
+ """
+ self.filter_date = datetime.fromisoformat(date)
+ self.ref = ref
+
+ if ref not in ["deposited", "revised", "released"]:
+ msg = (
+ "Invalid reference date. Must be ",
+ "deposited, revised, or released",
+ )
+ raise ValueError(msg)
+
+ def filter(self, record: Record) -> bool:
+ """Filter a record based on its date.
+
+ Parameters
+ ----------
+ record : Record
+ The record to filter.
+
+ Returns
+ -------
+ bool
+ Whether the record should be filtered.
+
+ """
+ structure = record.structure
+
+ if self.ref == "deposited":
+ date = structure.deposited
+ elif self.ref == "released":
+ date = structure.released
+ if not date:
+ date = structure.deposited
+ elif self.ref == "revised":
+ date = structure.revised
+ if not date and structure.released:
+ date = structure.released
+ elif not date:
+ date = structure.deposited
+
+ if date is None or date == "":
+ return False
+
+ date = datetime.fromisoformat(date)
+ return date <= self.filter_date
diff --git a/src/boltz/data/filter/dynamic/filter.py b/src/boltz/data/filter/dynamic/filter.py
new file mode 100644
index 0000000..0060922
--- /dev/null
+++ b/src/boltz/data/filter/dynamic/filter.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+
+from boltz.data.types import Record
+
+
+class DynamicFilter(ABC):
+ """Base class for data filters."""
+
+ @abstractmethod
+ def filter(self, record: Record) -> bool:
+ """Filter a data record.
+
+ Parameters
+ ----------
+ record : Record
+ The object to consider filtering in / out.
+
+ Returns
+ -------
+ bool
+ True if the data passes the filter, False otherwise.
+
+ """
+ raise NotImplementedError
diff --git a/src/boltz/data/filter/dynamic/max_residues.py b/src/boltz/data/filter/dynamic/max_residues.py
new file mode 100644
index 0000000..b8397d4
--- /dev/null
+++ b/src/boltz/data/filter/dynamic/max_residues.py
@@ -0,0 +1,37 @@
+from boltz.data.types import Record
+from boltz.data.filter.dynamic.filter import DynamicFilter
+
+
+class MaxResiduesFilter(DynamicFilter):
+ """A filter that filters structures based on their size."""
+
+ def __init__(self, min_residues: int = 1, max_residues: int = 500) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ min_chains : int
+ The minimum number of chains allowed.
+ max_chains : int
+ The maximum number of chains allowed.
+
+ """
+ self.min_residues = min_residues
+ self.max_residues = max_residues
+
+ def filter(self, record: Record) -> bool:
+ """Filter structures based on their resolution.
+
+ Parameters
+ ----------
+ record : Record
+ The record to filter.
+
+ Returns
+ -------
+ bool
+ Whether the record should be filtered.
+
+ """
+ num_residues = sum(chain.num_residues for chain in record.chains)
+ return num_residues <= self.max_residues and num_residues >= self.min_residues
diff --git a/src/boltz/data/filter/dynamic/resolution.py b/src/boltz/data/filter/dynamic/resolution.py
new file mode 100644
index 0000000..8096d6a
--- /dev/null
+++ b/src/boltz/data/filter/dynamic/resolution.py
@@ -0,0 +1,34 @@
+from boltz.data.types import Record
+from boltz.data.filter.dynamic.filter import DynamicFilter
+
+
+class ResolutionFilter(DynamicFilter):
+ """A filter that filters complexes based on their resolution."""
+
+ def __init__(self, resolution: float = 9.0) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ resolution : float, optional
+ The maximum allowed resolution.
+
+ """
+ self.resolution = resolution
+
+ def filter(self, record: Record) -> bool:
+ """Filter complexes based on their resolution.
+
+ Parameters
+ ----------
+ record : Record
+ The record to filter.
+
+ Returns
+ -------
+ bool
+ Whether the record should be filtered.
+
+ """
+ structure = record.structure
+ return structure.resolution <= self.resolution
diff --git a/src/boltz/data/filter/dynamic/size.py b/src/boltz/data/filter/dynamic/size.py
new file mode 100644
index 0000000..8d1094e
--- /dev/null
+++ b/src/boltz/data/filter/dynamic/size.py
@@ -0,0 +1,38 @@
+from boltz.data.types import Record
+from boltz.data.filter.dynamic.filter import DynamicFilter
+
+
+class SizeFilter(DynamicFilter):
+ """A filter that filters structures based on their size."""
+
+ def __init__(self, min_chains: int = 1, max_chains: int = 300) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ min_chains : int
+ The minimum number of chains allowed.
+ max_chains : int
+ The maximum number of chains allowed.
+
+ """
+ self.min_chains = min_chains
+ self.max_chains = max_chains
+
+ def filter(self, record: Record) -> bool:
+ """Filter structures based on their resolution.
+
+ Parameters
+ ----------
+ record : Record
+ The record to filter.
+
+ Returns
+ -------
+ bool
+ Whether the record should be filtered.
+
+ """
+ num_chains = record.structure.num_chains
+ num_valid = sum(1 for chain in record.chains if chain.valid)
+ return num_chains <= self.max_chains and num_valid >= self.min_chains
diff --git a/src/boltz/data/filter/dynamic/subset.py b/src/boltz/data/filter/dynamic/subset.py
new file mode 100644
index 0000000..53e1260
--- /dev/null
+++ b/src/boltz/data/filter/dynamic/subset.py
@@ -0,0 +1,42 @@
+from pathlib import Path
+
+from boltz.data.types import Record
+from boltz.data.filter.dynamic.filter import DynamicFilter
+
+
+class SubsetFilter(DynamicFilter):
+ """Filter a data record based on a subset of the data."""
+
+ def __init__(self, subset: str, reverse: bool = False) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ subset : str
+ The subset of data to consider, one per line.
+
+ """
+ with Path(subset).open("r") as f:
+ subset = f.read().splitlines()
+
+ self.subset = {s.lower() for s in subset}
+ self.reverse = reverse
+
+ def filter(self, record: Record) -> bool:
+ """Filter a data record.
+
+ Parameters
+ ----------
+ record : Record
+ The object to consider filtering in / out.
+
+ Returns
+ -------
+ bool
+ True if the data passes the filter, False otherwise.
+
+ """
+ if self.reverse:
+ return record.id.lower() not in self.subset
+ else: # noqa: RET505
+ return record.id.lower() in self.subset
diff --git a/src/boltz/data/filter/static/__init__.py b/src/boltz/data/filter/static/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/filter/static/filter.py b/src/boltz/data/filter/static/filter.py
new file mode 100644
index 0000000..2967e4e
--- /dev/null
+++ b/src/boltz/data/filter/static/filter.py
@@ -0,0 +1,26 @@
+from abc import ABC, abstractmethod
+
+import numpy as np
+
+from boltz.data.types import Structure
+
+
+class StaticFilter(ABC):
+ """Base class for structure filters."""
+
+ @abstractmethod
+ def filter(self, structure: Structure) -> np.ndarray:
+ """Filter chains in a structure.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to filter chains from.
+
+ Returns
+ -------
+ np.ndarray
+ The chains to keep, as a boolean mask.
+
+ """
+ raise NotImplementedError
diff --git a/src/boltz/data/filter/static/ligand.py b/src/boltz/data/filter/static/ligand.py
new file mode 100644
index 0000000..8a652a7
--- /dev/null
+++ b/src/boltz/data/filter/static/ligand.py
@@ -0,0 +1,171 @@
+import numpy as np
+
+from boltz.data import const
+from boltz.data.types import Structure
+from boltz.data.filter.static.filter import StaticFilter
+
+LIGAND_EXCLUSION = {
+ "144",
+ "15P",
+ "1PE",
+ "2F2",
+ "2JC",
+ "3HR",
+ "3SY",
+ "7N5",
+ "7PE",
+ "9JE",
+ "AAE",
+ "ABA",
+ "ACE",
+ "ACN",
+ "ACT",
+ "ACY",
+ "AZI",
+ "BAM",
+ "BCN",
+ "BCT",
+ "BDN",
+ "BEN",
+ "BME",
+ "BO3",
+ "BTB",
+ "BTC",
+ "BU1",
+ "C8E",
+ "CAD",
+ "CAQ",
+ "CBM",
+ "CCN",
+ "CIT",
+ "CL",
+ "CLR",
+ "CM",
+ "CMO",
+ "CO3",
+ "CPT",
+ "CXS",
+ "D10",
+ "DEP",
+ "DIO",
+ "DMS",
+ "DN",
+ "DOD",
+ "DOX",
+ "EDO",
+ "EEE",
+ "EGL",
+ "EOH",
+ "EOX",
+ "EPE",
+ "ETF",
+ "FCY",
+ "FJO",
+ "FLC",
+ "FMT",
+ "FW5",
+ "GOL",
+ "GSH",
+ "GTT",
+ "GYF",
+ "HED",
+ "IHP",
+ "IHS",
+ "IMD",
+ "IOD",
+ "IPA",
+ "IPH",
+ "LDA",
+ "MB3",
+ "MEG",
+ "MES",
+ "MLA",
+ "MLI",
+ "MOH",
+ "MPD",
+ "MRD",
+ "MSE",
+ "MYR",
+ "N",
+ "NA",
+ "NH2",
+ "NH4",
+ "NHE",
+ "NO3",
+ "O4B",
+ "OHE",
+ "OLA",
+ "OLC",
+ "OMB",
+ "OME",
+ "OXA",
+ "P6G",
+ "PE3",
+ "PE4",
+ "PEG",
+ "PEO",
+ "PEP",
+ "PG0",
+ "PG4",
+ "PGE",
+ "PGR",
+ "PLM",
+ "PO4",
+ "POL",
+ "POP",
+ "PVO",
+ "SAR",
+ "SCN",
+ "SEO",
+ "SEP",
+ "SIN",
+ "SO4",
+ "SPD",
+ "SPM",
+ "SR",
+ "STE",
+ "STO",
+ "STU",
+ "TAR",
+ "TBU",
+ "TME",
+ "TPO",
+ "TRS",
+ "UNK",
+ "UNL",
+ "UNX",
+ "UPL",
+ "URE",
+}
+
+
+class ExcludedLigands(StaticFilter):
+ """Filter excluded ligands."""
+
+ def filter(self, structure: Structure) -> np.ndarray:
+ """Filter excluded ligands.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to filter chains from.
+
+ Returns
+ -------
+ np.ndarray
+ The chains to keep, as a boolean mask.
+
+ """
+ valid = np.ones(len(structure.chains), dtype=bool)
+
+ for i, chain in enumerate(structure.chains):
+ if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]:
+ continue
+
+ res_start = chain["res_idx"]
+ res_end = res_start + chain["res_num"]
+ residues = structure.residues[res_start:res_end]
+ if any(res["name"] in LIGAND_EXCLUSION for res in residues):
+ valid[i] = 0
+
+ return valid
diff --git a/src/boltz/data/filter/static/polymer.py b/src/boltz/data/filter/static/polymer.py
new file mode 100644
index 0000000..0ec5f84
--- /dev/null
+++ b/src/boltz/data/filter/static/polymer.py
@@ -0,0 +1,294 @@
+from dataclasses import dataclass
+from typing import List
+
+import numpy as np
+from scipy.spatial.distance import cdist
+
+from boltz.data import const
+from boltz.data.types import Structure
+from boltz.data.filter.static.filter import StaticFilter
+
+
+class MinimumLengthFilter(StaticFilter):
+ """Filter polymers based on their length.
+
+ We use the number of resolved residues when considering
+ the minimum, and the sequence length for the maximum.
+
+ """
+
+ def __init__(self, min_len: int = 4, max_len: int = 5000) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ min_len : float, optional
+ The minimum allowed length.
+ max_len : float, optional
+ The maximum allowed length.
+
+ """
+ self._min = min_len
+ self._max = max_len
+
+ def filter(self, structure: Structure) -> np.ndarray:
+ """Filter a chains based on their length.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to filter chains from.
+
+ Returns
+ -------
+ np.ndarray
+ The chains to keep, as a boolean mask.
+
+ """
+ valid = np.ones(len(structure.chains), dtype=bool)
+
+ for i, chain in enumerate(structure.chains):
+ if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
+ continue
+
+ res_start = chain["res_idx"]
+ res_end = res_start + chain["res_num"]
+ residues = structure.residues[res_start:res_end]
+ resolved = residues["is_present"].sum()
+
+ if (resolved < self._min) or (resolved > self._max):
+ valid[i] = 0
+
+ return valid
+
+
+class UnknownFilter(StaticFilter):
+ """Filter proteins with all unknown residues."""
+
+ def filter(self, structure: Structure) -> np.ndarray:
+ """Filter proteins with all unknown residues.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to filter chains from.
+
+ Returns
+ -------
+ np.ndarray
+ The chains to keep, as a boolean mask.
+
+ """
+ valid = np.ones(len(structure.chains), dtype=bool)
+ unk_toks = {
+ const.chain_type_ids["PROTEIN"]: const.unk_token_ids["PROTEIN"],
+ const.chain_type_ids["DNA"]: const.unk_token_ids["DNA"],
+ const.chain_type_ids["RNA"]: const.unk_token_ids["RNA"],
+ }
+
+ for i, chain in enumerate(structure.chains):
+ if chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]:
+ continue
+
+ res_start = chain["res_idx"]
+ res_end = res_start + chain["res_num"]
+ residues = structure.residues[res_start:res_end]
+
+ unk_id = unk_toks[chain["mol_type"]]
+ if np.all(residues["res_type"] == unk_id):
+ valid[i] = 0
+
+ return valid
+
+
+class ConsecutiveCA(StaticFilter):
+ """Filter proteins with consecutive CA atoms above a threshold."""
+
+ def __init__(self, max_dist: int = 10.0) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ max_dist : float, optional
+ The maximum allowed distance.
+
+ """
+ self._max_dist = max_dist
+
+ def filter(self, structure: Structure) -> np.ndarray:
+ """Filter protein if consecutive CA atoms above a threshold.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to filter chains from.
+
+ Returns
+ -------
+ np.ndarray
+ The chains to keep, as a boolean mask.
+
+ """
+ valid = np.ones(len(structure.chains), dtype=bool)
+
+ # Remove chain if consecutive CA atoms are above threshold
+ for i, chain in enumerate(structure.chains):
+ # Skip non-protein chains
+ if chain["mol_type"] != const.chain_type_ids["PROTEIN"]:
+ continue
+
+ # Get residues
+ res_start = chain["res_idx"]
+ res_end = res_start + chain["res_num"]
+ residues = structure.residues[res_start:res_end]
+
+ # Get c-alphas
+ ca_ids = residues["atom_center"]
+ ca_atoms = structure.atoms[ca_ids]
+
+ res_valid = residues["is_present"]
+ ca_valid = ca_atoms["is_present"] & res_valid
+ ca_coords = ca_atoms["coords"]
+
+ # Compute distances between consecutive atoms
+ dist = np.linalg.norm(ca_coords[1:] - ca_coords[:-1], axis=1)
+ dist = dist > self._max_dist
+ dist = dist[ca_valid[1:] & ca_valid[:-1]]
+
+ # Remove the chain if any valid pair is above threshold
+ if np.any(dist):
+ valid[i] = 0
+
+ return valid
+
+
+@dataclass(frozen=True)
+class Clash:
+ """A clash between two chains."""
+
+ chain: int
+ other: int
+ num_atoms: int
+ num_clashes: int
+
+
+class ClashingChainsFilter(StaticFilter):
+ """A filter that filters clashing chains.
+
+ Clashing chains are defined as those with >30% of atoms
+ within 1.7 Å of an atom in another chain. If two chains
+ are clashing with each other, the chain with the greater
+ percentage of clashing atoms will be removed. If the same
+ fraction of atoms are clashing, the chain with fewer total
+ atoms is removed. If the chains have the same number of
+ atoms, then the chain with the larger chain id is removed.
+
+ """
+
+ def __init__(self, dist: float = 1.7, freq: float = 0.3) -> None:
+ """Initialize the filter.
+
+ Parameters
+ ----------
+ dist : float, optional
+ The maximum distance for a clash.
+ freq : float, optional
+ The maximum allowed frequency of clashes.
+
+ """
+ self._dist = dist
+ self._freq = freq
+
+ def filter(self, structure: Structure) -> np.ndarray: # noqa: PLR0912, C901
+ """Filter out clashing chains.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to filter chains from.
+
+ Returns
+ -------
+ np.ndarray
+ The chains to keep, as a boolean mask.
+
+ """
+ num_chains = len(structure.chains)
+ if num_chains < 2: # noqa: PLR2004
+ return np.ones(num_chains, dtype=bool)
+
+ # Get unique chain pairs
+ pairs = zip(range(num_chains), range(num_chains))
+ pairs = [(i, j) for i, j in pairs if i < j]
+
+ # Compute clashes
+ clashes: List[Clash] = []
+ for i, j in pairs:
+ # Get the chains
+ c1 = structure.chains[i]
+ c2 = structure.chains[j]
+
+ # Get the atoms from each chain
+ c1_start = c1["atom_idx"]
+ c2_start = c2["atom_idx"]
+ c1_end = c1_start + c1["atom_num"]
+ c2_end = c2_start + c2["atom_num"]
+
+ atoms1 = structure.atoms[c1_start:c1_end]
+ atoms2 = structure.atoms[c2_start:c2_end]
+ atoms1 = atoms1[atoms1["is_present"]]
+ atoms2 = atoms2[atoms2["is_present"]]
+
+ # Compute the number of clashes
+ dists = cdist(atoms1["coords"], atoms2["coords"])
+ clashes = dists < self._dist
+ c1_clashes = np.any(clashes, axis=1).sum().item()
+ c2_clashes = np.any(clashes, axis=0).sum().item()
+
+ # Save results
+ if (c1_clashes / len(atoms1)) > self._freq:
+ clashes.append(Clash(i, j, len(atoms1), c1_clashes))
+ if (c2_clashes / len(atoms2)) > self._freq:
+ clashes.append(Clash(j, i, len(atoms2), c2_clashes))
+
+ # Compute indices to clash map
+ removed = set()
+ ids_to_clash = {(c.chain, c.other): c for c in clashes}
+
+ # Filter out chains according to ruleset
+ for clash in clashes:
+ # If either is already removed, skip
+ if clash.chain in removed or clash.other in removed:
+ continue
+
+ # Check if the two chains clash with each other
+ other_clash = ids_to_clash.get((clash.other, clash.chain))
+ if other_clash is not None:
+ # Remove the chain with the most clashes
+ clash1_freq = clash.num_clashes / clash.num_atoms
+ clash2_freq = other_clash.num_clashes / other_clash.num_atoms
+ if clash1_freq > clash2_freq:
+ removed.add(clash.chain)
+ elif clash1_freq < clash2_freq:
+ removed.add(clash.other)
+
+ # If same, remove the chain with fewer atoms
+ elif clash.num_atoms < other_clash.num_atoms:
+ removed.add(clash.chain)
+ elif clash.num_atoms > other_clash.num_atoms:
+ removed.add(clash.other)
+
+ # If same, remove the chain with the larger chain id
+ else:
+ removed.add(max(clash.chain, clash.other))
+
+ # Otherwise, just remove the chain directly
+ else:
+ removed.add(clash.chain)
+
+ # Remove the chains
+ valid = np.ones(len(structure.chains), dtype=bool)
+ for i in removed:
+ valid[i] = 0
+
+ return valid
diff --git a/src/boltz/data/module/__init__.py b/src/boltz/data/module/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/module/inference.py b/src/boltz/data/module/inference.py
new file mode 100644
index 0000000..5c35c1f
--- /dev/null
+++ b/src/boltz/data/module/inference.py
@@ -0,0 +1,267 @@
+from pathlib import Path
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from boltz.data import const
+from boltz.data.feature.featurizer import BoltzFeaturizer
+from boltz.data.feature.pad import pad_to_max
+from boltz.data.tokenize.boltz import BoltzTokenizer
+from boltz.data.types import MSA, Input, Manifest, Record, Structure
+
+
+def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
+ """Load the given input data.
+
+ Parameters
+ ----------
+ record : Record
+ The record to load.
+ target_dir : Path
+ The path to the data directory.
+ msa_dir : Path
+ The path to msa directory.
+
+ Returns
+ -------
+ Input
+ The loaded input.
+
+ """
+ # Load the structure
+ structure = np.load(target_dir / f"{record.id}.npz")
+ structure = Structure(
+ atoms=structure["atoms"],
+ bonds=structure["bonds"],
+ residues=structure["residues"],
+ chains=structure["chains"],
+ connections=structure["connections"],
+ interfaces=structure["interfaces"],
+ mask=structure["mask"],
+ )
+
+ msas = {}
+ for chain in record.chains:
+ msa_id = chain.msa_id
+ # Load the MSA for this chain, if any
+ if msa_id != -1:
+ msa = np.load(msa_dir / f"{msa_id}.npz")
+ msas[chain.chain_id] = MSA(**msa)
+
+ return Input(structure, msas)
+
+
+def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
+ """Collate the data.
+
+ Parameters
+ ----------
+ data : List[Dict[str, Tensor]]
+ The data to collate.
+
+ Returns
+ -------
+ Dict[str, Tensor]
+ The collated data.
+
+ """
+ # Get the keys
+ keys = data[0].keys()
+
+ # Collate the data
+ collated = {}
+ for key in keys:
+ values = [d[key] for d in data]
+
+ if key not in [
+ "all_coords",
+ "all_resolved_mask",
+ "crop_to_all_atom_map",
+ "chain_symmetries",
+ "amino_acids_symmetries",
+ "ligand_symmetries",
+ "record",
+ ]:
+ # Check if all have the same shape
+ shape = values[0].shape
+ if not all(v.shape == shape for v in values):
+ values, _ = pad_to_max(values, 0)
+ else:
+ values = torch.stack(values, dim=0)
+
+ # Stack the values
+ collated[key] = values
+
+ return collated
+
+
+class PredictionDataset(torch.utils.data.Dataset):
+ """Base iterable dataset."""
+
+ def __init__(
+ self,
+ manifest: Manifest,
+ target_dir: Path,
+ msa_dir: Path,
+ ) -> None:
+ """Initialize the training dataset.
+
+ Parameters
+ ----------
+ manifest : Manifest
+ The manifest to load data from.
+ target_dir : Path
+ The path to the target directory.
+ msa_dir : Path
+ The path to the msa directory.
+
+ """
+ super().__init__()
+ self.manifest = manifest
+ self.target_dir = target_dir
+ self.msa_dir = msa_dir
+ self.tokenizer = BoltzTokenizer()
+ self.featurizer = BoltzFeaturizer()
+
+ def __getitem__(self, idx: int) -> dict:
+ """Get an item from the dataset.
+
+ Returns
+ -------
+ Dict[str, Tensor]
+ The sampled data features.
+
+ """
+ # Get a sample from the dataset
+ record = self.manifest.records[idx]
+
+ # Get the structure
+ try:
+ input_data = load_input(record, self.target_dir, self.msa_dir)
+ except Exception as e: # noqa: BLE001
+ print(f"Failed to load input for {record.id} with error {e}. Skipping.") # noqa: T201
+ return self.__getitem__(0)
+
+ # Tokenize structure
+ try:
+ tokenized = self.tokenizer.tokenize(input_data)
+ except Exception as e: # noqa: BLE001
+ print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
+ return self.__getitem__(0)
+
+ # Compute features
+ try:
+ features = self.featurizer.process(
+ tokenized,
+ training=False,
+ max_atoms=None,
+ max_tokens=None,
+ max_seqs=const.max_msa_seqs,
+ pad_to_max_seqs=False,
+ symmetries={},
+ compute_symmetries=False,
+ )
+ except Exception as e: # noqa: BLE001
+ print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
+ return self.__getitem__(0)
+
+ features["record"] = record
+ return features
+
+ def __len__(self) -> int:
+ """Get the length of the dataset.
+
+ Returns
+ -------
+ int
+ The length of the dataset.
+
+ """
+ return len(self.manifest.records)
+
+
+class BoltzInferenceDataModule(pl.LightningDataModule):
+ """DataModule for Boltz inference."""
+
+ def __init__(
+ self,
+ manifest: Manifest,
+ target_dir: Path,
+ msa_dir: Path,
+ num_workers: int,
+ ) -> None:
+ """Initialize the DataModule.
+
+ Parameters
+ ----------
+ config : DataConfig
+ The data configuration.
+
+ """
+ super().__init__()
+ self.num_workers = num_workers
+ self.manifest = manifest
+ self.target_dir = target_dir
+ self.msa_dir = msa_dir
+
+ def predict_dataloader(self) -> DataLoader:
+ """Get the training dataloader.
+
+ Returns
+ -------
+ DataLoader
+ The training dataloader.
+
+ """
+ dataset = PredictionDataset(
+ manifest=self.manifest,
+ target_dir=self.target_dir,
+ msa_dir=self.msa_dir,
+ )
+ return DataLoader(
+ dataset,
+ batch_size=1,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ shuffle=False,
+ collate_fn=collate,
+ )
+
+ def transfer_batch_to_device(
+ self,
+ batch: dict,
+ device: torch.device,
+ dataloader_idx: int, # noqa: ARG002
+ ) -> dict:
+ """Transfer a batch to the given device.
+
+ Parameters
+ ----------
+ batch : Dict
+ The batch to transfer.
+ device : torch.device
+ The device to transfer to.
+ dataloader_idx : int
+ The dataloader index.
+
+ Returns
+ -------
+ np.Any
+ The transferred batch.
+
+ """
+ for key in batch:
+ if key not in [
+ "all_coords",
+ "all_resolved_mask",
+ "crop_to_all_atom_map",
+ "chain_symmetries",
+ "amino_acids_symmetries",
+ "ligand_symmetries",
+ "record",
+ ]:
+ batch[key] = batch[key].to(device)
+ return batch
diff --git a/src/boltz/data/module/training.py b/src/boltz/data/module/training.py
new file mode 100644
index 0000000..aaee6f7
--- /dev/null
+++ b/src/boltz/data/module/training.py
@@ -0,0 +1,660 @@
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+from boltz.data.crop.cropper import Cropper
+from boltz.data.feature.featurizer import BoltzFeaturizer
+from boltz.data.feature.pad import pad_to_max
+from boltz.data.feature.symmetry import get_symmetries
+from boltz.data.filter.dynamic.filter import DynamicFilter
+from boltz.data.sample.sampler import Sample, Sampler
+from boltz.data.tokenize.tokenizer import Tokenizer
+from boltz.data.types import MSA, Input, Manifest, Record, Structure
+
+
+@dataclass
+class DatasetConfig:
+ """Dataset configuration."""
+
+ target_dir: str
+ msa_dir: str
+ prob: float
+ sampler: Sampler
+ cropper: Cropper
+ filters: Optional[list] = None
+ split: Optional[str] = None
+ manifest_path: Optional[str] = None
+
+
+@dataclass
+class DataConfig:
+ """Data configuration."""
+
+ datasets: list[DatasetConfig]
+ filters: list[DynamicFilter]
+ featurizer: BoltzFeaturizer
+ tokenizer: Tokenizer
+ max_atoms: int
+ max_tokens: int
+ max_seqs: int
+ samples_per_epoch: int
+ batch_size: int
+ num_workers: int
+ random_seed: int
+ pin_memory: bool
+ symmetries: str
+ atoms_per_window_queries: int
+ min_dist: float
+ max_dist: float
+ num_bins: int
+ overfit: Optional[int] = None
+ pad_to_max_tokens: bool = False
+ pad_to_max_atoms: bool = False
+ pad_to_max_seqs: bool = False
+ crop_validation: bool = False
+ return_train_symmetries: bool = False
+ return_val_symmetries: bool = True
+ train_binder_pocket_conditioned_prop: float = 0.0
+ val_binder_pocket_conditioned_prop: float = 0.0
+ binder_pocket_cutoff: float = 6.0
+ binder_pocket_sampling_geometric_p: float = 0.0
+ val_batch_size: int = 1
+
+
+@dataclass
+class Dataset:
+ """Data holder."""
+
+ target_dir: Path
+ msa_dir: Path
+ manifest: Manifest
+ prob: float
+ sampler: Sampler
+ cropper: Cropper
+ tokenizer: Tokenizer
+ featurizer: BoltzFeaturizer
+
+
+def load_input(record: Record, target_dir: Path, msa_dir: Path) -> Input:
+ """Load the given input data.
+
+ Parameters
+ ----------
+ record : Record
+ The record to load.
+ target_dir : Path
+ The path to the data directory.
+ msa_dir : Path
+ The path to msa directory.
+
+ Returns
+ -------
+ Input
+ The loaded input.
+
+ """
+ # Load the structure
+ structure = np.load(target_dir / "structures" / f"{record.id}.npz")
+ structure = Structure(
+ atoms=structure["atoms"],
+ bonds=structure["bonds"],
+ residues=structure["residues"],
+ chains=structure["chains"],
+ connections=structure["connections"],
+ interfaces=structure["interfaces"],
+ mask=structure["mask"],
+ )
+
+ msas = {}
+ for chain in record.chains:
+ msa_id = chain.msa_id
+ # Load the MSA for this chain, if any
+ if msa_id != -1:
+ msa = np.load(msa_dir / f"{msa_id}.npz")
+ msas[chain.chain_id] = MSA(**msa)
+
+ return Input(structure, msas)
+
+
+def collate(data: list[dict[str, Tensor]]) -> dict[str, Tensor]:
+ """Collate the data.
+
+ Parameters
+ ----------
+ data : list[dict[str, Tensor]]
+ The data to collate.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The collated data.
+
+ """
+ # Get the keys
+ keys = data[0].keys()
+
+ # Collate the data
+ collated = {}
+ for key in keys:
+ values = [d[key] for d in data]
+
+ if key not in [
+ "all_coords",
+ "all_resolved_mask",
+ "crop_to_all_atom_map",
+ "chain_symmetries",
+ "amino_acids_symmetries",
+ "ligand_symmetries",
+ ]:
+ # Check if all have the same shape
+ shape = values[0].shape
+ if not all(v.shape == shape for v in values):
+ values, _ = pad_to_max(values, 0)
+ else:
+ values = torch.stack(values, dim=0)
+
+ # Stack the values
+ collated[key] = values
+
+ return collated
+
+
+class TrainingDataset(torch.utils.data.Dataset):
+ """Base iterable dataset."""
+
+ def __init__(
+ self,
+ datasets: list[Dataset],
+ samples_per_epoch: int,
+ symmetries: dict,
+ max_atoms: int,
+ max_tokens: int,
+ max_seqs: int,
+ pad_to_max_atoms: bool = False,
+ pad_to_max_tokens: bool = False,
+ pad_to_max_seqs: bool = False,
+ atoms_per_window_queries: int = 32,
+ min_dist: float = 2.0,
+ max_dist: float = 22.0,
+ num_bins: int = 64,
+ overfit: Optional[int] = None,
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
+ binder_pocket_cutoff: Optional[float] = 6.0,
+ binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
+ return_symmetries: Optional[bool] = False,
+ ) -> None:
+ """Initialize the training dataset."""
+ super().__init__()
+ self.datasets = datasets
+ self.probs = [d.prob for d in datasets]
+ self.samples_per_epoch = samples_per_epoch
+ self.symmetries = symmetries
+ self.max_tokens = max_tokens
+ self.max_seqs = max_seqs
+ self.max_atoms = max_atoms
+ self.pad_to_max_tokens = pad_to_max_tokens
+ self.pad_to_max_atoms = pad_to_max_atoms
+ self.pad_to_max_seqs = pad_to_max_seqs
+ self.atoms_per_window_queries = atoms_per_window_queries
+ self.min_dist = min_dist
+ self.max_dist = max_dist
+ self.num_bins = num_bins
+ self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
+ self.binder_pocket_cutoff = binder_pocket_cutoff
+ self.binder_pocket_sampling_geometric_p = binder_pocket_sampling_geometric_p
+ self.return_symmetries = return_symmetries
+ self.samples = []
+ for dataset in datasets:
+ records = dataset.manifest.records
+ if overfit is not None:
+ records = records[:overfit]
+ iterator = dataset.sampler.sample(records, np.random)
+ self.samples.append(iterator)
+
+ def __getitem__(self, idx: int) -> dict[str, Tensor]:
+ """Get an item from the dataset.
+
+ Parameters
+ ----------
+ idx : int
+ The data index.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The sampled data features.
+
+ """
+ # Pick a random dataset
+ dataset_idx = np.random.choice(
+ len(self.datasets),
+ p=self.probs,
+ )
+ dataset = self.datasets[dataset_idx]
+
+ # Get a sample from the dataset
+ sample: Sample = next(self.samples[dataset_idx])
+
+ # Get the structure
+ try:
+ input_data = load_input(sample.record, dataset.target_dir, dataset.msa_dir)
+ except Exception as e:
+ print(
+ f"Failed to load input for {sample.record.id} with error {e}. Skipping."
+ )
+ return self.__getitem__(idx)
+
+ # Tokenize structure
+ try:
+ tokenized = dataset.tokenizer.tokenize(input_data)
+ except Exception as e:
+ print(f"Tokenizer failed on {sample.record.id} with error {e}. Skipping.")
+ return self.__getitem__(idx)
+
+ # Compute crop
+ try:
+ if self.max_tokens is not None:
+ tokenized = dataset.cropper.crop(
+ tokenized,
+ max_atoms=self.max_atoms,
+ max_tokens=self.max_tokens,
+ random=np.random,
+ chain_id=sample.chain_id,
+ interface_id=sample.interface_id,
+ )
+ except Exception as e:
+ print(f"Cropper failed on {sample.record.id} with error {e}. Skipping.")
+ return self.__getitem__(idx)
+
+ # Check if there are tokens
+ if len(tokenized.tokens) == 0:
+ msg = "No tokens in cropped structure."
+ raise ValueError(msg)
+
+ # Compute features
+ try:
+ features = dataset.featurizer.process(
+ tokenized,
+ training=True,
+ max_atoms=self.max_atoms if self.pad_to_max_atoms else None,
+ max_tokens=self.max_tokens if self.pad_to_max_tokens else None,
+ max_seqs=self.max_seqs,
+ pad_to_max_seqs=self.pad_to_max_seqs,
+ symmetries=self.symmetries,
+ atoms_per_window_queries=self.atoms_per_window_queries,
+ min_dist=self.min_dist,
+ max_dist=self.max_dist,
+ num_bins=self.num_bins,
+ compute_symmetries=self.return_symmetries,
+ binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
+ binder_pocket_cutoff=self.binder_pocket_cutoff,
+ binder_pocket_sampling_geometric_p=self.binder_pocket_sampling_geometric_p,
+ )
+ except Exception as e:
+ print(f"Featurizer failed on {sample.record.id} with error {e}. Skipping.")
+ return self.__getitem__(idx)
+
+ return features
+
+ def __len__(self) -> int:
+ """Get the length of the dataset.
+
+ Returns
+ -------
+ int
+ The length of the dataset.
+
+ """
+ return self.samples_per_epoch
+
+
+class ValidationDataset(torch.utils.data.Dataset):
+ """Base iterable dataset."""
+
+ def __init__(
+ self,
+ datasets: list[Dataset],
+ seed: int,
+ symmetries: dict,
+ max_atoms: Optional[int] = None,
+ max_tokens: Optional[int] = None,
+ max_seqs: Optional[int] = None,
+ pad_to_max_atoms: bool = False,
+ pad_to_max_tokens: bool = False,
+ pad_to_max_seqs: bool = False,
+ atoms_per_window_queries: int = 32,
+ min_dist: float = 2.0,
+ max_dist: float = 22.0,
+ num_bins: int = 64,
+ overfit: Optional[int] = None,
+ crop_validation: bool = False,
+ return_symmetries: Optional[bool] = False,
+ binder_pocket_conditioned_prop: Optional[float] = 0.0,
+ binder_pocket_cutoff: Optional[float] = 6.0,
+ ) -> None:
+ """Initialize the validation dataset."""
+ super().__init__()
+ self.datasets = datasets
+ self.max_atoms = max_atoms
+ self.max_tokens = max_tokens
+ self.max_seqs = max_seqs
+ self.seed = seed
+ self.symmetries = symmetries
+ self.random = np.random if overfit else np.random.RandomState(self.seed)
+ self.pad_to_max_tokens = pad_to_max_tokens
+ self.pad_to_max_atoms = pad_to_max_atoms
+ self.pad_to_max_seqs = pad_to_max_seqs
+ self.overfit = overfit
+ self.crop_validation = crop_validation
+ self.atoms_per_window_queries = atoms_per_window_queries
+ self.min_dist = min_dist
+ self.max_dist = max_dist
+ self.num_bins = num_bins
+ self.return_symmetries = return_symmetries
+ self.binder_pocket_conditioned_prop = binder_pocket_conditioned_prop
+ self.binder_pocket_cutoff = binder_pocket_cutoff
+
+ def __getitem__(self, idx: int) -> dict[str, Tensor]:
+ """Get an item from the dataset.
+
+ Parameters
+ ----------
+ idx : int
+ The data index.
+
+ Returns
+ -------
+ dict[str, Tensor]
+ The sampled data features.
+
+ """
+ # Pick dataset based on idx
+ for dataset in self.datasets:
+ size = len(dataset.manifest.records)
+ if self.overfit is not None:
+ size = min(size, self.overfit)
+ if idx < size:
+ break
+ idx -= size
+
+ # Get a sample from the dataset
+ record = dataset.manifest.records[idx]
+
+ # Get the structure
+ try:
+ input_data = load_input(record, dataset.target_dir, dataset.msa_dir)
+ except Exception as e:
+ print(f"Failed to load input for {record.id} with error {e}. Skipping.")
+ return self.__getitem__(0)
+
+ # Tokenize structure
+ try:
+ tokenized = dataset.tokenizer.tokenize(input_data)
+ except Exception as e:
+ print(f"Tokenizer failed on {record.id} with error {e}. Skipping.")
+ return self.__getitem__(0)
+
+ # Compute crop
+ try:
+ if self.crop_validation and (self.max_tokens is not None):
+ tokenized = dataset.cropper.crop(
+ tokenized,
+ max_tokens=self.max_tokens,
+ random=self.random,
+ max_atoms=self.max_atoms,
+ )
+ except Exception as e:
+ print(f"Cropper failed on {record.id} with error {e}. Skipping.")
+ return self.__getitem__(0)
+
+ # Check if there are tokens
+ if len(tokenized.tokens) == 0:
+ msg = "No tokens in cropped structure."
+ raise ValueError(msg)
+
+ # Compute features
+ try:
+ pad_atoms = self.crop_validation and self.pad_to_max_atoms
+ pad_tokens = self.crop_validation and self.pad_to_max_tokens
+
+ features = dataset.featurizer.process(
+ tokenized,
+ training=False,
+ max_atoms=self.max_atoms if pad_atoms else None,
+ max_tokens=self.max_tokens if pad_tokens else None,
+ max_seqs=self.max_seqs,
+ pad_to_max_seqs=self.pad_to_max_seqs,
+ symmetries=self.symmetries,
+ atoms_per_window_queries=self.atoms_per_window_queries,
+ min_dist=self.min_dist,
+ max_dist=self.max_dist,
+ num_bins=self.num_bins,
+ compute_symmetries=self.return_symmetries,
+ binder_pocket_conditioned_prop=self.binder_pocket_conditioned_prop,
+ binder_pocket_cutoff=self.binder_pocket_cutoff,
+ binder_pocket_sampling_geometric_p=1.0, # this will only sample a single pocket token
+ only_ligand_binder_pocket=True,
+ )
+ except Exception as e:
+ print(f"Featurizer failed on {record.id} with error {e}. Skipping.")
+ return self.__getitem__(0)
+
+ return features
+
+ def __len__(self) -> int:
+ """Get the length of the dataset.
+
+ Returns
+ -------
+ int
+ The length of the dataset.
+
+ """
+ if self.overfit is not None:
+ length = sum(len(d.manifest.records[: self.overfit]) for d in self.datasets)
+ else:
+ length = sum(len(d.manifest.records) for d in self.datasets)
+
+ return length
+
+
+class BoltzTrainingDataModule(pl.LightningDataModule):
+ """DataModule for boltz."""
+
+ def __init__(self, cfg: DataConfig) -> None:
+ """Initialize the DataModule.
+
+ Parameters
+ ----------
+ config : DataConfig
+ The data configuration.
+
+ """
+ super().__init__()
+ self.cfg = cfg
+
+ assert self.cfg.val_batch_size == 1, "Validation only works with batch size=1."
+
+ # Load symmetries
+ symmetries = get_symmetries(cfg.symmetries)
+
+ # Load datasets
+ train: list[Dataset] = []
+ val: list[Dataset] = []
+
+ for data_config in cfg.datasets:
+ # Set target_dir
+ target_dir = Path(data_config.target_dir)
+ msa_dir = Path(data_config.msa_dir)
+
+ # Load manifest
+ if data_config.manifest_path is not None:
+ path = Path(data_config.manifest_path)
+ else:
+ path = target_dir / "manifest.json"
+ manifest: Manifest = Manifest.load(path)
+
+ # Split records if given
+ if data_config.split is not None:
+ with Path(data_config.split).open("r") as f:
+ split = {x.lower() for x in f.read().splitlines()}
+
+ train_records = []
+ val_records = []
+ for record in manifest.records:
+ if record.id.lower() in split:
+ val_records.append(record)
+ else:
+ train_records.append(record)
+ else:
+ train_records = manifest.records
+ val_records = []
+
+ # Filter training records
+ train_records = [
+ record
+ for record in train_records
+ if all(f.filter(record) for f in cfg.filters)
+ ]
+ # Filter training records
+ if data_config.filters is not None:
+ train_records = [
+ record
+ for record in train_records
+ if all(f.filter(record) for f in data_config.filters)
+ ]
+
+ # Create train dataset
+ train_manifest = Manifest(train_records)
+ train.append(
+ Dataset(
+ target_dir,
+ msa_dir,
+ train_manifest,
+ data_config.prob,
+ data_config.sampler,
+ data_config.cropper,
+ cfg.tokenizer,
+ cfg.featurizer,
+ )
+ )
+
+ # Create validation dataset
+ if val_records:
+ val_manifest = Manifest(val_records)
+ val.append(
+ Dataset(
+ target_dir,
+ msa_dir,
+ val_manifest,
+ data_config.prob,
+ data_config.sampler,
+ data_config.cropper,
+ cfg.tokenizer,
+ cfg.featurizer,
+ )
+ )
+
+ # Print dataset sizes
+ for dataset in train:
+ dataset: Dataset
+ print(f"Training dataset size: {len(dataset.manifest.records)}")
+
+ for dataset in val:
+ dataset: Dataset
+ print(f"Validation dataset size: {len(dataset.manifest.records)}")
+
+ # Create wrapper datasets
+ self._train_set = TrainingDataset(
+ datasets=train,
+ samples_per_epoch=cfg.samples_per_epoch,
+ max_atoms=cfg.max_atoms,
+ max_tokens=cfg.max_tokens,
+ max_seqs=cfg.max_seqs,
+ pad_to_max_atoms=cfg.pad_to_max_atoms,
+ pad_to_max_tokens=cfg.pad_to_max_tokens,
+ pad_to_max_seqs=cfg.pad_to_max_seqs,
+ symmetries=symmetries,
+ atoms_per_window_queries=cfg.atoms_per_window_queries,
+ min_dist=cfg.min_dist,
+ max_dist=cfg.max_dist,
+ num_bins=cfg.num_bins,
+ overfit=cfg.overfit,
+ binder_pocket_conditioned_prop=cfg.train_binder_pocket_conditioned_prop,
+ binder_pocket_cutoff=cfg.binder_pocket_cutoff,
+ binder_pocket_sampling_geometric_p=cfg.binder_pocket_sampling_geometric_p,
+ return_symmetries=cfg.return_train_symmetries,
+ )
+ self._val_set = ValidationDataset(
+ datasets=train if cfg.overfit is not None else val,
+ seed=cfg.random_seed,
+ max_atoms=cfg.max_atoms,
+ max_tokens=cfg.max_tokens,
+ max_seqs=cfg.max_seqs,
+ pad_to_max_atoms=cfg.pad_to_max_atoms,
+ pad_to_max_tokens=cfg.pad_to_max_tokens,
+ pad_to_max_seqs=cfg.pad_to_max_seqs,
+ symmetries=symmetries,
+ atoms_per_window_queries=cfg.atoms_per_window_queries,
+ min_dist=cfg.min_dist,
+ max_dist=cfg.max_dist,
+ num_bins=cfg.num_bins,
+ overfit=cfg.overfit,
+ crop_validation=cfg.crop_validation,
+ return_symmetries=cfg.return_val_symmetries,
+ binder_pocket_conditioned_prop=cfg.val_binder_pocket_conditioned_prop,
+ binder_pocket_cutoff=cfg.binder_pocket_cutoff,
+ )
+
+ def setup(self, stage: Optional[str] = None) -> None:
+ """Run the setup for the DataModule.
+
+ Parameters
+ ----------
+ stage : str, optional
+ The stage, one of 'fit', 'validate', 'test'.
+
+ """
+ return
+
+ def train_dataloader(self) -> DataLoader:
+ """Get the training dataloader.
+
+ Returns
+ -------
+ DataLoader
+ The training dataloader.
+
+ """
+ return DataLoader(
+ self._train_set,
+ batch_size=self.cfg.batch_size,
+ num_workers=self.cfg.num_workers,
+ pin_memory=self.cfg.pin_memory,
+ shuffle=False,
+ collate_fn=collate,
+ )
+
+ def val_dataloader(self) -> DataLoader:
+ """Get the validation dataloader.
+
+ Returns
+ -------
+ DataLoader
+ The validation dataloader.
+
+ """
+ return DataLoader(
+ self._val_set,
+ batch_size=self.cfg.val_batch_size,
+ num_workers=self.cfg.num_workers,
+ pin_memory=self.cfg.pin_memory,
+ shuffle=False,
+ collate_fn=collate,
+ )
diff --git a/src/boltz/data/parse/__init__.py b/src/boltz/data/parse/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/parse/a3m.py b/src/boltz/data/parse/a3m.py
new file mode 100644
index 0000000..9df9714
--- /dev/null
+++ b/src/boltz/data/parse/a3m.py
@@ -0,0 +1,134 @@
+import gzip
+from pathlib import Path
+from typing import Optional, TextIO
+
+import numpy as np
+
+from boltz.data import const
+from boltz.data.types import MSA, MSADeletion, MSAResidue, MSASequence
+
+
+def _parse_a3m( # noqa: C901
+ lines: TextIO,
+ taxonomy: Optional[dict[str, str]],
+ max_seqs: Optional[int] = None,
+) -> MSA:
+ """Process an MSA file.
+
+ Parameters
+ ----------
+ lines : TextIO
+ The lines of the MSA file.
+ taxonomy : dict[str, str]
+ The taxonomy database, if available.
+ max_seqs : int, optional
+ The maximum number of sequences.
+
+ Returns
+ -------
+ MSA
+ The MSA object.
+
+ """
+ visited = set()
+ sequences = []
+ deletions = []
+ residues = []
+
+ seq_idx = 0
+ for line in lines:
+ line: str
+ line = line.strip() # noqa: PLW2901
+ if not line or line.startswith("#"):
+ continue
+
+ # Get taxonomy, if annotated
+ if line.startswith(">"):
+ header = line.split()[0]
+ if taxonomy and header.startswith(">UniRef100"):
+ uniref_id = header.split("_")[1]
+ taxonomy_id = taxonomy.get(uniref_id)
+ if taxonomy_id is None:
+ taxonomy_id = -1
+ else:
+ taxonomy_id = -1
+ continue
+
+ # Skip if duplicate sequence
+ str_seq = line.replace("-", "").upper()
+ if str_seq not in visited:
+ visited.add(str_seq)
+ else:
+ continue
+
+ # Process sequence
+ residue = []
+ deletion = []
+ count = 0
+ res_idx = 0
+ for c in line:
+ if c != "-" and c.islower():
+ count += 1
+ continue
+ token = const.prot_letter_to_token[c]
+ token = const.token_ids[token]
+ residue.append(token)
+ if count > 0:
+ deletion.append((res_idx, count))
+ count = 0
+ res_idx += 1
+
+ res_start = len(residues)
+ res_end = res_start + len(residue)
+
+ del_start = len(deletions)
+ del_end = del_start + len(deletion)
+
+ sequences.append((seq_idx, taxonomy_id, res_start, res_end, del_start, del_end))
+ residues.extend(residue)
+ deletions.extend(deletion)
+
+ seq_idx += 1
+ if (max_seqs is not None) and (seq_idx >= max_seqs):
+ break
+
+ # Create MSA object
+ msa = MSA(
+ residues=np.array(residues, dtype=MSAResidue),
+ deletions=np.array(deletions, dtype=MSADeletion),
+ sequences=np.array(sequences, dtype=MSASequence),
+ )
+ return msa
+
+
+def parse_a3m(
+ path: Path,
+ taxonomy: Optional[dict[str, str]],
+ max_seqs: Optional[int] = None,
+) -> MSA:
+ """Process an A3M file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the a3m(.gz) file.
+ taxonomy : Redis
+ The taxonomy database.
+ max_seqs : int, optional
+ The maximum number of sequences.
+
+ Returns
+ -------
+ MSA
+ The MSA object.
+
+ """
+ # Read the file
+ if path.suffix == ".gz":
+ with gzip.open(str(path), "rt") as f:
+ msa = _parse_a3m(f, taxonomy, max_seqs)
+ else:
+ with path.open("r") as f:
+ msa = _parse_a3m(f, taxonomy, max_seqs)
+
+ return msa
diff --git a/src/boltz/data/parse/fasta.py b/src/boltz/data/parse/fasta.py
new file mode 100644
index 0000000..d6b0753
--- /dev/null
+++ b/src/boltz/data/parse/fasta.py
@@ -0,0 +1,127 @@
+from collections.abc import Mapping
+from pathlib import Path
+
+from Bio import SeqIO
+from rdkit.Chem.rdchem import Mol
+
+from boltz.data.parse.yaml import parse_boltz_schema
+from boltz.data.types import Target
+
+
+def parse_fasta(path: Path, ccd: Mapping[str, Mol]) -> Target: # noqa: C901
+ """Parse a fasta file.
+
+ The name of the fasta file is used as the name of this job.
+ We rely on the fasta record id to determine the entity type.
+
+ > CHAIN_ID|ENTITY_TYPE|MSA_ID
+ SEQUENCE
+ > CHAIN_ID|ENTITY_TYPE|MSA_ID
+ ...
+
+ Where ENTITY_TYPE is either protein, rna, dna, ccd or smiles,
+ and CHAIN_ID is the chain identifier, which should be unique.
+ The MSA_ID is optional and should only be used on proteins.
+
+ Parameters
+ ----------
+ fasta_file : Path
+ Path to the fasta file.
+ ccd : Dict
+ Dictionary of CCD components.
+
+ Returns
+ -------
+ Target
+ The parsed target.
+
+ """
+ # Read fasta file
+ with path.open("r") as f:
+ records = list(SeqIO.parse(f, "fasta"))
+
+ # Make sure all records have a chain id and entity
+ for seq_record in records:
+ if "|" not in seq_record.id:
+ msg = f"Invalid record id: {seq_record.id}"
+ raise ValueError(msg)
+
+ header = seq_record.id.split("|")
+ assert len(header) >= 2, f"Invalid record id: {seq_record.id}"
+
+ chain_id, entity_type = header[:2]
+ if entity_type.lower() not in {"protein", "dna", "rna", "ccd", "smiles"}:
+ msg = f"Invalid entity type: {entity_type}"
+ raise ValueError(msg)
+ if chain_id == "":
+ msg = "Empty chain id in input fasta!"
+ raise ValueError(msg)
+ if entity_type == "":
+ msg = "Empty entity type in input fasta!"
+ raise ValueError(msg)
+
+ # Convert to yaml format
+ sequences = []
+ for seq_record in records:
+ # Get chain id, entity type and sequence
+ header = seq_record.id.split("|")
+ chain_id, entity_type = header[:2]
+ if len(header) == 3 and header[2] != "":
+ assert (
+ entity_type.lower() == "protein"
+ ), "MSA_ID is only allowed for proteins"
+ msa_id = header[2]
+
+ entity_type = entity_type.upper()
+ seq = str(seq_record.seq)
+
+ if entity_type == "PROTEIN":
+ molecule = {
+ "protein": {
+ "id": chain_id,
+ "sequence": seq,
+ "modifications": [],
+ "msa": msa_id,
+ },
+ }
+ elif entity_type == "RNA":
+ molecule = {
+ "rna": {
+ "id": chain_id,
+ "sequence": seq,
+ "modifications": [],
+ },
+ }
+ elif entity_type == "DNA":
+ molecule = {
+ "dna": {
+ "id": chain_id,
+ "sequence": seq,
+ "modifications": [],
+ }
+ }
+ elif entity_type.upper() == "CCD":
+ molecule = {
+ "ligand": {
+ "id": chain_id,
+ "ccd": seq,
+ }
+ }
+ elif entity_type.upper() == "SMILES":
+ molecule = {
+ "ligand": {
+ "id": chain_id,
+ "smiles": seq,
+ }
+ }
+
+ sequences.append(molecule)
+
+ data = {
+ "sequences": sequences,
+ "bonds": [],
+ "version": 1,
+ }
+
+ name = path.stem
+ return parse_boltz_schema(name, data, ccd)
diff --git a/src/boltz/data/parse/schema.py b/src/boltz/data/parse/schema.py
new file mode 100644
index 0000000..bbb4b86
--- /dev/null
+++ b/src/boltz/data/parse/schema.py
@@ -0,0 +1,789 @@
+from collections.abc import Mapping
+from dataclasses import dataclass
+from typing import Optional
+
+import numpy as np
+from rdkit import rdBase
+from rdkit.Chem import AllChem
+from rdkit.Chem.rdchem import Conformer, Mol
+
+from boltz.data import const
+from boltz.data.types import (
+ Atom,
+ Bond,
+ Chain,
+ ChainInfo,
+ Connection,
+ Interface,
+ Record,
+ Residue,
+ Structure,
+ StructureInfo,
+ Target,
+)
+
+####################################################################################################
+# DATACLASSES
+####################################################################################################
+
+
+@dataclass(frozen=True)
+class ParsedAtom:
+ """A parsed atom object."""
+
+ name: str
+ element: int
+ charge: int
+ coords: tuple[float, float, float]
+ conformer: tuple[float, float, float]
+ is_present: bool
+ chirality: int
+
+
+@dataclass(frozen=True)
+class ParsedBond:
+ """A parsed bond object."""
+
+ atom_1: int
+ atom_2: int
+ type: int
+
+
+@dataclass(frozen=True)
+class ParsedResidue:
+ """A parsed residue object."""
+
+ name: str
+ type: int
+ idx: int
+ atoms: list[ParsedAtom]
+ bonds: list[ParsedBond]
+ orig_idx: Optional[int]
+ atom_center: int
+ atom_disto: int
+ is_standard: bool
+ is_present: bool
+
+
+@dataclass(frozen=True)
+class ParsedChain:
+ """A parsed chain object."""
+
+ entity: str
+ type: str
+ residues: list[ParsedResidue]
+
+
+####################################################################################################
+# HELPERS
+####################################################################################################
+
+
+def convert_atom_name(name: str) -> tuple[int, int, int, int]:
+ """Convert an atom name to a standard format.
+
+ Parameters
+ ----------
+ name : str
+ The atom name.
+
+ Returns
+ -------
+ Tuple[int, int, int, int]
+ The converted atom name.
+
+ """
+ name = name.strip()
+ name = [ord(c) - 32 for c in name]
+ name = name + [0] * (4 - len(name))
+ return tuple(name)
+
+
+def compute_3d_conformer(mol: Mol, version: str = "v3") -> bool:
+ """Generate 3D coordinates using EKTDG method.
+
+ Taken from `pdbeccdutils.core.component.Component`.
+
+ Parameters
+ ----------
+ mol: Mol
+ The RDKit molecule to process
+ version: str, optional
+ The ETKDG version, defaults ot v3
+
+ Returns
+ -------
+ bool
+ Whether computation was successful.
+
+ """
+ if version == "v3":
+ options = AllChem.ETKDGv3()
+ elif version == "v2":
+ options = AllChem.ETKDGv2()
+ else:
+ options = AllChem.ETKDGv2()
+
+ options.clearConfs = False
+ conf_id = -1
+
+ try:
+ conf_id = AllChem.EmbedMolecule(mol, options)
+ AllChem.UFFOptimizeMolecule(mol, confId=conf_id, maxIters=1000)
+
+ except RuntimeError:
+ pass # Force field issue here
+ except ValueError:
+ pass # sanitization issue here
+
+ if conf_id != -1:
+ conformer = mol.GetConformer(conf_id)
+ conformer.SetProp("name", "Computed")
+ conformer.SetProp("coord_generation", f"ETKDG{version}")
+
+ return True
+
+ return False
+
+
+def get_conformer(mol: Mol) -> Conformer:
+ """Retrieve an rdkit object for a deemed conformer.
+
+ Inspired by `pdbeccdutils.core.component.Component`.
+
+ Parameters
+ ----------
+ mol: Mol
+ The molecule to process.
+
+ Returns
+ -------
+ Conformer
+ The desired conformer, if any.
+
+ Raises
+ ------
+ ValueError
+ If there are no conformers of the given tyoe.
+
+ """
+ # Try using the computed conformer
+ for c in mol.GetConformers():
+ try:
+ if c.GetProp("name") == "Computed":
+ return c
+ except KeyError: # noqa: PERF203
+ pass
+
+ # Fallback to the ideal coordinates
+ for c in mol.GetConformers():
+ try:
+ if c.GetProp("name") == "Ideal":
+ return c
+ except KeyError: # noqa: PERF203
+ pass
+
+ msg = "Conformer does not exist."
+ raise ValueError(msg)
+
+
+####################################################################################################
+# PARSING
+####################################################################################################
+
+
+def parse_ccd_residue(
+ name: str,
+ ref_mol: Mol,
+ res_idx: int,
+) -> Optional[ParsedResidue]:
+ """Parse an MMCIF ligand.
+
+ First tries to get the SMILES string from the RCSB.
+ Then, tries to infer atom ordering using RDKit.
+
+ Parameters
+ ----------
+ name: str
+ The name of the molecule to parse.
+ ref_mol: Mol
+ The reference molecule to parse.
+ res_idx : int
+ The residue index.
+
+ Returns
+ -------
+ ParsedResidue, optional
+ The output ParsedResidue, if successful.
+
+ """
+ unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
+
+ # Remove hydrogens
+ ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
+
+ # Check if this is a single atom CCD residue
+ if ref_mol.GetNumAtoms() == 1:
+ pos = (0, 0, 0)
+ ref_atom = ref_mol.GetAtoms()[0]
+ chirality_type = const.chirality_type_ids.get(
+ ref_atom.GetChiralTag(), unk_chirality
+ )
+ atom = ParsedAtom(
+ name=ref_atom.GetProp("name"),
+ element=ref_atom.GetAtomicNum(),
+ charge=ref_atom.GetFormalCharge(),
+ coords=pos,
+ conformer=(0, 0, 0),
+ is_present=True,
+ chirality=chirality_type,
+ )
+ unk_prot_id = const.unk_token_ids["PROTEIN"]
+ residue = ParsedResidue(
+ name=name,
+ type=unk_prot_id,
+ atoms=[atom],
+ bonds=[],
+ idx=res_idx,
+ orig_idx=None,
+ atom_center=0, # Placeholder, no center
+ atom_disto=0, # Placeholder, no center
+ is_standard=False,
+ is_present=True,
+ )
+ return residue
+
+ # Get reference conformer coordinates
+ conformer = get_conformer(ref_mol)
+
+ # Parse each atom in order of the reference mol
+ atoms = []
+ atom_idx = 0
+ idx_map = {} # Used for bonds later
+
+ for i, atom in enumerate(ref_mol.GetAtoms()):
+ # Get atom name, charge, element and reference coordinates
+ atom_name = atom.GetProp("name")
+ charge = atom.GetFormalCharge()
+ element = atom.GetAtomicNum()
+ ref_coords = conformer.GetAtomPosition(atom.GetIdx())
+ ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)
+ chirality_type = const.chirality_type_ids.get(
+ atom.GetChiralTag(), unk_chirality
+ )
+
+ # Get PDB coordinates, if any
+ coords = (0, 0, 0)
+ atom_is_present = True
+
+ # Add atom to list
+ atoms.append(
+ ParsedAtom(
+ name=atom_name,
+ element=element,
+ charge=charge,
+ coords=coords,
+ conformer=ref_coords,
+ is_present=atom_is_present,
+ chirality=chirality_type,
+ )
+ )
+ idx_map[i] = atom_idx
+ atom_idx += 1 # noqa: SIM113
+
+ # Load bonds
+ bonds = []
+ unk_bond = const.bond_type_ids[const.unk_bond_type]
+ for bond in ref_mol.GetBonds():
+ idx_1 = bond.GetBeginAtomIdx()
+ idx_2 = bond.GetEndAtomIdx()
+
+ # Skip bonds with atoms ignored
+ if (idx_1 not in idx_map) or (idx_2 not in idx_map):
+ continue
+
+ idx_1 = idx_map[idx_1]
+ idx_2 = idx_map[idx_2]
+ start = min(idx_1, idx_2)
+ end = max(idx_1, idx_2)
+ bond_type = bond.GetBondType().name
+ bond_type = const.bond_type_ids.get(bond_type, unk_bond)
+ bonds.append(ParsedBond(start, end, bond_type))
+
+ unk_prot_id = const.unk_token_ids["PROTEIN"]
+ return ParsedResidue(
+ name=name,
+ type=unk_prot_id,
+ atoms=atoms,
+ bonds=bonds,
+ idx=res_idx,
+ atom_center=0,
+ atom_disto=0,
+ orig_idx=None,
+ is_standard=False,
+ is_present=True,
+ )
+
+
+def parse_polymer(
+ sequence: list[str],
+ entity: str,
+ entity_type: str,
+ components: dict[str, Mol],
+) -> Optional[ParsedChain]:
+ """Process a sequence into a chain object.
+
+ Performs alignment of the full sequence to the polymer
+ residues. Loads coordinates and masks for the atoms in
+ the polymer, following the ordering in const.atom_order.
+
+ Parameters
+ ----------
+ sequence : list[str]
+ The full sequence of the polymer.
+ entity : str
+ The entity id.
+ entity_type : str
+ The entity type.
+ components : dict[str, Mol]
+ The preprocessed PDB components dictionary.
+
+ Returns
+ -------
+ ParsedChain, optional
+ The output chain, if successful.
+
+ Raises
+ ------
+ ValueError
+ If the alignment fails.
+
+ """
+ unk_chirality = const.chirality_type_ids[const.unk_chirality_type]
+
+ # Check what type of sequence this is
+ if entity_type == "rna":
+ chain_type = const.chain_type_ids["RNA"]
+ token_map = const.rna_letter_to_token
+ elif entity_type == "dna":
+ chain_type = const.chain_type_ids["DNA"]
+ token_map = const.dna_letter_to_token
+ elif entity_type == "protein":
+ chain_type = const.chain_type_ids["PROTEIN"]
+ token_map = const.prot_letter_to_token
+ else:
+ msg = f"Unknown polymer type: {entity_type}"
+ raise ValueError(msg)
+
+ # Get coordinates and masks
+ parsed = []
+ for res_idx, res_code in enumerate(sequence):
+ # Load ref residue
+ res_name = token_map[res_code]
+ ref_mol = components[res_name]
+ ref_mol = AllChem.RemoveHs(ref_mol, sanitize=False)
+ ref_conformer = get_conformer(ref_mol)
+
+ # Only use reference atoms set in constants
+ ref_name_to_atom = {a.GetProp("name"): a for a in ref_mol.GetAtoms()}
+ ref_atoms = [ref_name_to_atom[a] for a in const.ref_atoms[res_name]]
+
+ # Iterate, always in the same order
+ atoms: list[ParsedAtom] = []
+
+ for ref_atom in ref_atoms:
+ # Get atom name
+ atom_name = ref_atom.GetProp("name")
+ idx = ref_atom.GetIdx()
+
+ # Get conformer coordinates
+ ref_coords = ref_conformer.GetAtomPosition(idx)
+ ref_coords = (ref_coords.x, ref_coords.y, ref_coords.z)
+
+ # Set 0 coordinate
+ atom_is_present = True
+ coords = (0, 0, 0)
+
+ # Add atom to list
+ atoms.append(
+ ParsedAtom(
+ name=atom_name,
+ element=ref_atom.GetAtomicNum(),
+ charge=ref_atom.GetFormalCharge(),
+ coords=coords,
+ conformer=ref_coords,
+ is_present=atom_is_present,
+ chirality=const.chirality_type_ids.get(
+ ref_atom.GetChiralTag(), unk_chirality
+ ),
+ )
+ )
+
+ atom_center = const.res_to_center_atom_id[res_name]
+ atom_disto = const.res_to_disto_atom_id[res_name]
+ parsed.append(
+ ParsedResidue(
+ name=res_name,
+ type=const.token_ids[res_name],
+ atoms=atoms,
+ bonds=[],
+ idx=res_idx,
+ atom_center=atom_center,
+ atom_disto=atom_disto,
+ is_standard=True,
+ is_present=True,
+ orig_idx=None,
+ )
+ )
+
+ # Return polymer object
+ return ParsedChain(
+ entity=entity,
+ residues=parsed,
+ type=chain_type,
+ )
+
+
+def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912
+ name: str,
+ schema: dict,
+ ccd: Mapping[str, Mol],
+) -> Target:
+ """Parse a Boltz input yaml / json.
+
+ The input file should be a dictionary with the following format:
+
+ version: 1
+ sequences:
+ - protein:
+ id: A
+ sequence: "MADQLTEEQIAEFKEAFSLF"
+ msa: path/to/msa1.a3m
+ - protein:
+ id: [B, C]
+ sequence: "AKLSILPWGHC"
+ msa: path/to/msa2.a3m
+ - rna:
+ id: D
+ sequence: "GCAUAGC"
+ - ligand:
+ id: E
+ smiles: "CC1=CC=CC=C1"
+ - ligand:
+ id: [F, G]
+ ccd: []
+ constraints:
+ - bond:
+ atom1: [A, 1, CA]
+ atom2: [A, 2, N]
+ - pocket:
+ binder: E
+ contacts: [[B, 1], [B, 2]]
+
+ Parameters
+ ----------
+ name : str
+ A name for the input.
+ schema : dict
+ The input schema.
+ components : dict
+ Dictionary of CCD components.
+
+ Returns
+ -------
+ Target
+ The parsed target.
+
+ """
+ # Assert version 1
+ version = schema.get("version", 1)
+ if version != 1:
+ msg = f"Invalid version {version} in input!"
+ raise ValueError(msg)
+
+ # Disable rdkit warnings
+ blocker = rdBase.BlockLogs() # noqa: F841
+
+ # First group items that have the same type, sequence and modifications
+ items_to_group = {}
+ for item in schema["sequences"]:
+ # Get entity type
+ entity_type = next(iter(item.keys())).lower()
+ if entity_type not in {"protein", "dna", "rna", "ligand"}:
+ msg = f"Invalid entity type: {entity_type}"
+ raise ValueError(msg)
+
+ # Get sequence
+ if entity_type in {"protein", "dna", "rna"}:
+ seq = str(item[entity_type]["sequence"])
+ elif entity_type == "ligand":
+ assert "smiles" in item[entity_type] or "ccd" in item[entity_type]
+ assert "smiles" not in item[entity_type] or "ccd" not in item[entity_type]
+ if "smiles" in item[entity_type]:
+ seq = str(item[entity_type]["smiles"])
+ else:
+ seq = str(item[entity_type]["ccd"])
+ items_to_group.setdefault((entity_type, seq), []).append(item)
+
+ # Go through entities and parse them
+ chains: dict[str, ParsedChain] = {}
+ chain_to_msa: dict[str, str] = {}
+ chain_to_moltype: dict[str, int] = {}
+ for entity_id, items in enumerate(items_to_group.values()):
+ # Get entity type and sequence
+ entity_type = next(iter(items[0].keys())).lower()
+
+ # Ensure all the items share the same msa
+ msa = -1
+ if entity_type == "protein":
+ if ("msa" not in items[0][entity_type]) or (
+ items[0][entity_type]["msa"] is None
+ ):
+ msg = """
+ Proteins must have an MSA. If you wish to run the model in
+ single sequence mode, please explicitely pass an empty string.
+ """
+ raise ValueError(msg)
+ msa = items[0][entity_type]["msa"]
+ if not all(item[entity_type]["msa"] == msa for item in items):
+ msg = "All proteins with the same sequence must share the same MSA!"
+ raise ValueError(msg)
+
+ # Parse a polymer
+ if entity_type in {"protein", "dna", "rna"}:
+ seq = list(items[0][entity_type]["sequence"])
+ # Apply modifications
+ for modification in items[0][entity_type].get("modifications", []):
+ code = modification["ccd"]
+ idx = modification["position"] - 1 # 1-indexed
+ seq[idx] = code
+
+ # Parse a polymer
+ parsed_chain = parse_polymer(
+ sequence=seq,
+ entity=entity_id,
+ entity_type=entity_type,
+ components=ccd,
+ )
+
+ # Parse a non-polymer
+ elif entity_type == "ligand" and "ccd" in items[0][entity_type]:
+ seq = items[0][entity_type]["ccd"]
+ if isinstance(seq, str):
+ seq = [seq]
+
+ residues = []
+ for code in seq:
+ if code not in ccd:
+ msg = f"CCD component {code} not found!"
+ raise ValueError(msg)
+
+ # Parse residue
+ residue = parse_ccd_residue(
+ name=code,
+ ref_mol=ccd[code],
+ res_idx=0,
+ )
+ residues.append(residue)
+
+ # Create multi ligand chain
+ parsed_chain = ParsedChain(
+ entity=entity_id,
+ residues=residues,
+ type=const.chain_type_ids["NONPOLYMER"],
+ )
+ elif entity_type == "ligand" and "smiles" in items[0][entity_type]:
+ seq = items[0][entity_type]["smiles"]
+ mol = AllChem.MolFromSmiles(seq)
+ mol = AllChem.AddHs(mol)
+
+ # Set atom names
+ canonical_order = AllChem.CanonicalRankAtoms(mol)
+ for atom, can_idx in zip(mol.GetAtoms(), canonical_order):
+ atom.SetProp("name", atom.GetSymbol().upper() + str(can_idx + 1))
+
+ success = compute_3d_conformer(mol)
+ if not success:
+ msg = f"Failed to compute 3D conformer for {seq}"
+ raise ValueError(msg)
+
+ mol_no_h = AllChem.RemoveHs(mol)
+ residue = parse_ccd_residue(
+ name="LIG",
+ ref_mol=mol_no_h,
+ res_idx=0,
+ )
+ parsed_chain = ParsedChain(
+ entity=entity_id,
+ residues=[residue],
+ type=const.chain_type_ids["NONPOLYMER"],
+ )
+ else:
+ msg = f"Invalid entity type: {entity_type}"
+ raise ValueError(msg)
+
+ # Convert entity_type to mol_type_id
+ mol_type_id = entity_type.upper()
+ mol_type_id = mol_type_id.replace("LIGAND", "NONPOLYMER")
+ mol_type_id = const.chain_type_ids[mol_type_id]
+
+ for item in items:
+ ids = item[entity_type]["id"]
+ if isinstance(ids, str):
+ ids = [ids]
+ for chain_name in ids:
+ chains[chain_name] = parsed_chain
+ chain_to_msa[chain_name] = msa
+ chain_to_moltype[chain_name] = mol_type_id
+
+ # If no chains parsed fail
+ if not chains:
+ msg = "No chains parsed!"
+ raise ValueError(msg)
+
+ # Create tables
+ atom_data = []
+ bond_data = []
+ res_data = []
+ chain_data = []
+
+ # Convert parsed chains to tables
+ atom_idx = 0
+ res_idx = 0
+ asym_id = 0
+ sym_count = {}
+ chain_to_idx = {}
+
+ # Keep a mapping of (chain_name, residue_idx, atom_name) to atom_idx
+ atom_idx_map = {}
+
+ for asym_id, (chain_name, chain) in enumerate(chains.items()):
+ # Compute number of atoms and residues
+ res_num = len(chain.residues)
+ atom_num = sum(len(res.atoms) for res in chain.residues)
+
+ # Find all copies of this chain in the assembly
+ entity_id = int(chain.entity)
+ sym_id = sym_count.get(entity_id, 0)
+ chain_data.append(
+ (
+ chain_name,
+ chain.type,
+ entity_id,
+ sym_id,
+ asym_id,
+ atom_idx,
+ atom_num,
+ res_idx,
+ res_num,
+ )
+ )
+ chain_to_idx[chain_name] = asym_id
+ sym_count[entity_id] = sym_id + 1
+
+ # Add residue, atom, bond, data
+ for res in chain.residues:
+ atom_center = atom_idx + res.atom_center
+ atom_disto = atom_idx + res.atom_disto
+ res_data.append(
+ (
+ res.name,
+ res.type,
+ res.idx,
+ atom_idx,
+ len(res.atoms),
+ atom_center,
+ atom_disto,
+ res.is_standard,
+ res.is_present,
+ )
+ )
+
+ for bond in res.bonds:
+ atom_1 = atom_idx + bond.atom_1
+ atom_2 = atom_idx + bond.atom_2
+ bond_data.append((atom_1, atom_2, bond.type))
+
+ for atom in res.atoms:
+ # Add atom to map
+ atom_idx_map[(chain_name, res.idx, atom.name)] = (
+ asym_id,
+ res_idx,
+ atom_idx,
+ )
+
+ # Add atom to data
+ atom_data.append(
+ (
+ convert_atom_name(atom.name),
+ atom.element,
+ atom.charge,
+ atom.coords,
+ atom.conformer,
+ atom.is_present,
+ atom.chirality,
+ )
+ )
+ atom_idx += 1
+
+ res_idx += 1
+
+ # Parse constraints
+ connections = []
+ constraints = schema.get("constraints", [])
+ for constraint in constraints:
+ if "bond" in constraint:
+ c1, r1, a1 = atom_idx_map[tuple(constraint["bond"]["atom1"])]
+ c2, r2, a2 = atom_idx_map[tuple(constraint["bond"]["atom2"])]
+ connections.append((c1, c2, r1 - 1, r2 - 1, a1, a2)) # 1-indexed
+
+ elif "pocket" in constraint:
+ binder = constraint["pocket"]["binder"]
+ contacts = constraint["pocket"]["contacts"]
+ msg = f"Pocket constraints not implemented yet: {binder} - {contacts}"
+ raise NotImplementedError(msg)
+ else:
+ msg = f"Invalid constraint: {constraint}"
+ raise ValueError(msg)
+
+ # Convert into datatypes
+ atoms = np.array(atom_data, dtype=Atom)
+ bonds = np.array(bond_data, dtype=Bond)
+ residues = np.array(res_data, dtype=Residue)
+ chains = np.array(chain_data, dtype=Chain)
+ interfaces = np.array([], dtype=Interface)
+ connections = np.array(connections, dtype=Connection)
+ mask = np.ones(len(chain_data), dtype=bool)
+
+ data = Structure(
+ atoms=atoms,
+ bonds=bonds,
+ residues=residues,
+ chains=chains,
+ connections=connections,
+ interfaces=interfaces,
+ mask=mask,
+ )
+
+ # Create metadata
+ struct_info = StructureInfo(num_chains=len(chains))
+ chain_infos = []
+ for chain_id, chain in enumerate(chains):
+ chain_info = ChainInfo(
+ chain_id=chain_id,
+ chain_name=chain["name"],
+ mol_type=chain_to_moltype[chain["name"]],
+ cluster_id=-1,
+ msa_id=chain_to_msa[chain["name"]],
+ num_residues=int(chain["res_num"]),
+ valid=True,
+ )
+ chain_infos.append(chain_info)
+
+ record = Record(
+ id=name,
+ structure=struct_info,
+ chains=chain_infos,
+ interfaces=[],
+ )
+ return Target(record=record, structure=data)
diff --git a/src/boltz/data/parse/yaml.py b/src/boltz/data/parse/yaml.py
new file mode 100644
index 0000000..c596d61
--- /dev/null
+++ b/src/boltz/data/parse/yaml.py
@@ -0,0 +1,57 @@
+from pathlib import Path
+
+import yaml
+from rdkit.Chem.rdchem import Mol
+
+from boltz.data.parse.schema import parse_boltz_schema
+from boltz.data.types import Target
+
+
+def parse_yaml(path: Path, ccd: dict[str, Mol]) -> Target:
+ """Parse a Boltz input yaml / json.
+
+ The input file should be a yaml file with the following format:
+
+ sequences:
+ - protein:
+ id: A
+ sequence: "MADQLTEEQIAEFKEAFSLF"
+ - protein:
+ id: [B, C]
+ sequence: "AKLSILPWGHC"
+ - rna:
+ id: D
+ sequence: "GCAUAGC"
+ - ligand:
+ id: E
+ smiles: "CC1=CC=CC=C1"
+ - ligand:
+ id: [F, G]
+ ccd: []
+ constraints:
+ - bond:
+ atom1: [A, 1, CA]
+ atom2: [A, 2, N]
+ - pocket:
+ binder: E
+ contacts: [[B, 1], [B, 2]]
+ version: 1
+
+ Parameters
+ ----------
+ path : Path
+ Path to the YAML input format.
+ components : Dict
+ Dictionary of CCD components.
+
+ Returns
+ -------
+ Target
+ The parsed target.
+
+ """
+ with path.open("r") as file:
+ data = yaml.safe_load(file)
+
+ name = path.stem
+ return parse_boltz_schema(name, data, ccd)
diff --git a/src/boltz/data/sample/__init__.py b/src/boltz/data/sample/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/sample/cluster.py b/src/boltz/data/sample/cluster.py
new file mode 100644
index 0000000..fb5c2e6
--- /dev/null
+++ b/src/boltz/data/sample/cluster.py
@@ -0,0 +1,283 @@
+from typing import Dict, Iterator, List
+
+import numpy as np
+from numpy.random import RandomState
+
+from boltz.data import const
+from boltz.data.types import ChainInfo, InterfaceInfo, Record
+from boltz.data.sample.sampler import Sample, Sampler
+
+
+def get_chain_cluster(chain: ChainInfo, record: Record) -> str: # noqa: ARG001
+ """Get the cluster id for a chain.
+
+ Parameters
+ ----------
+ chain : ChainInfo
+ The chain id to get the cluster id for.
+ record : Record
+ The record the interface is part of.
+
+ Returns
+ -------
+ str
+ The cluster id of the chain.
+
+ """
+ return chain.cluster_id
+
+
+def get_interface_cluster(interface: InterfaceInfo, record: Record) -> str:
+ """Get the cluster id for an interface.
+
+ Parameters
+ ----------
+ interface : InterfaceInfo
+ The interface to get the cluster id for.
+ record : Record
+ The record the interface is part of.
+
+ Returns
+ -------
+ str
+ The cluster id of the interface.
+
+ """
+ chain1 = record.chains[interface.chain_1]
+ chain2 = record.chains[interface.chain_2]
+
+ cluster_1 = str(chain1.cluster_id)
+ cluster_2 = str(chain2.cluster_id)
+
+ cluster_id = (cluster_1, cluster_2)
+ cluster_id = tuple(sorted(cluster_id))
+
+ return cluster_id
+
+
+def get_chain_weight(
+ chain: ChainInfo,
+ record: Record, # noqa: ARG001
+ clusters: Dict[str, int],
+ beta_chain: float,
+ alpha_prot: float,
+ alpha_nucl: float,
+ alpha_ligand: float,
+) -> float:
+ """Get the weight of a chain.
+
+ Parameters
+ ----------
+ chain : ChainInfo
+ The chain to get the weight for.
+ record : Record
+ The record the chain is part of.
+ clusters : Dict[str, int]
+ The cluster sizes.
+ beta_chain : float
+ The beta value for chains.
+ alpha_prot : float
+ The alpha value for proteins.
+ alpha_nucl : float
+ The alpha value for nucleic acids.
+ alpha_ligand : float
+ The alpha value for ligands.
+
+ Returns
+ -------
+ float
+ The weight of the chain.
+
+ """
+ prot_id = const.chain_type_ids["PROTEIN"]
+ rna_id = const.chain_type_ids["RNA"]
+ dna_id = const.chain_type_ids["DNA"]
+ ligand_id = const.chain_type_ids["NONPOLYMER"]
+
+ weight = beta_chain / clusters[chain.cluster_id]
+ if chain.mol_type == prot_id:
+ weight *= alpha_prot
+ elif chain.mol_type in [rna_id, dna_id]:
+ weight *= alpha_nucl
+ elif chain.mol_type == ligand_id:
+ weight *= alpha_ligand
+
+ return weight
+
+
+def get_interface_weight(
+ interface: InterfaceInfo,
+ record: Record,
+ clusters: Dict[str, int],
+ beta_interface: float,
+ alpha_prot: float,
+ alpha_nucl: float,
+ alpha_ligand: float,
+) -> float:
+ """Get the weight of an interface.
+
+ Parameters
+ ----------
+ interface : InterfaceInfo
+ The interface to get the weight for.
+ record : Record
+ The record the interface is part of.
+ clusters : Dict[str, int]
+ The cluster sizes.
+ beta_interface : float
+ The beta value for interfaces.
+ alpha_prot : float
+ The alpha value for proteins.
+ alpha_nucl : float
+ The alpha value for nucleic acids.
+ alpha_ligand : float
+ The alpha value for ligands.
+
+ Returns
+ -------
+ float
+ The weight of the interface.
+
+ """
+ prot_id = const.chain_type_ids["PROTEIN"]
+ rna_id = const.chain_type_ids["RNA"]
+ dna_id = const.chain_type_ids["DNA"]
+ ligand_id = const.chain_type_ids["NONPOLYMER"]
+
+ chain1 = record.chains[interface.chain_1]
+ chain2 = record.chains[interface.chain_2]
+
+ n_prot = (chain1.mol_type) == prot_id
+ n_nuc = chain1.mol_type in [rna_id, dna_id]
+ n_ligand = chain1.mol_type == ligand_id
+
+ n_prot += chain2.mol_type == prot_id
+ n_nuc += chain2.mol_type in [rna_id, dna_id]
+ n_ligand += chain2.mol_type == ligand_id
+
+ weight = beta_interface / clusters[get_interface_cluster(interface, record)]
+ weight *= alpha_prot * n_prot + alpha_nucl * n_nuc + alpha_ligand * n_ligand
+ return weight
+
+
+class ClusterSampler(Sampler):
+ """The weighted sampling approach, as described in AF3.
+
+ Each chain / interface is given a weight according
+ to the following formula, and sampled accordingly:
+
+ w = b / n_clust *(a_prot * n_prot + a_nuc * n_nuc
+ + a_ligand * n_ligand)
+
+ """
+
+ def __init__(
+ self,
+ alpha_prot: float = 3.0,
+ alpha_nucl: float = 3.0,
+ alpha_ligand: float = 1.0,
+ beta_chain: float = 0.5,
+ beta_interface: float = 1.0,
+ ) -> None:
+ """Initialize the sampler.
+
+ Parameters
+ ----------
+ alpha_prot : float, optional
+ The alpha value for proteins.
+ alpha_nucl : float, optional
+ The alpha value for nucleic acids.
+ alpha_ligand : float, optional
+ The alpha value for ligands.
+ beta_chain : float, optional
+ The beta value for chains.
+ beta_interface : float, optional
+ The beta value for interfaces.
+
+ """
+ self.alpha_prot = alpha_prot
+ self.alpha_nucl = alpha_nucl
+ self.alpha_ligand = alpha_ligand
+ self.beta_chain = beta_chain
+ self.beta_interface = beta_interface
+
+ def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]: # noqa: C901, PLR0912
+ """Sample a structure from the dataset infinitely.
+
+ Parameters
+ ----------
+ records : List[Record]
+ The records to sample from.
+ random : RandomState
+ The random state for reproducibility.
+
+ Yields
+ ------
+ Sample
+ A data sample.
+
+ """
+ # Compute chain cluster sizes
+ chain_clusters: Dict[str, int] = {}
+ for record in records:
+ for chain in record.chains:
+ if not chain.valid:
+ continue
+ cluster_id = get_chain_cluster(chain, record)
+ if cluster_id not in chain_clusters:
+ chain_clusters[cluster_id] = 0
+ chain_clusters[cluster_id] += 1
+
+ # Compute interface clusters sizes
+ interface_clusters: Dict[str, int] = {}
+ for record in records:
+ for interface in record.interfaces:
+ if not interface.valid:
+ continue
+ cluster_id = get_interface_cluster(interface, record)
+ if cluster_id not in interface_clusters:
+ interface_clusters[cluster_id] = 0
+ interface_clusters[cluster_id] += 1
+
+ # Compute weights
+ items, weights = [], []
+ for record in records:
+ for chain_id, chain in enumerate(record.chains):
+ if not chain.valid:
+ continue
+ weight = get_chain_weight(
+ chain,
+ record,
+ chain_clusters,
+ self.beta_chain,
+ self.alpha_prot,
+ self.alpha_nucl,
+ self.alpha_ligand,
+ )
+ items.append((record, 0, chain_id))
+ weights.append(weight)
+
+ for int_id, interface in enumerate(record.interfaces):
+ if not interface.valid:
+ continue
+ weight = get_interface_weight(
+ interface,
+ record,
+ interface_clusters,
+ self.beta_interface,
+ self.alpha_prot,
+ self.alpha_nucl,
+ self.alpha_ligand,
+ )
+ items.append((record, 1, int_id))
+ weights.append(weight)
+
+ # Sample infinitely
+ weights = np.array(weights) / np.sum(weights)
+ while True:
+ item_idx = random.choice(len(items), p=weights)
+ record, kind, index = items[item_idx]
+ if kind == 0:
+ yield Sample(record=record, chain_id=index)
+ else:
+ yield Sample(record=record, interface_id=index)
diff --git a/src/boltz/data/sample/distillation.py b/src/boltz/data/sample/distillation.py
new file mode 100644
index 0000000..9314f51
--- /dev/null
+++ b/src/boltz/data/sample/distillation.py
@@ -0,0 +1,57 @@
+from typing import Iterator, List
+
+from numpy.random import RandomState
+
+from boltz.data.types import Record
+from boltz.data.sample.sampler import Sample, Sampler
+
+
+class DistillationSampler(Sampler):
+ """A sampler for monomer distillation data."""
+
+ def __init__(self, small_size: int = 200, small_prob: float = 0.01) -> None:
+ """Initialize the sampler.
+
+ Parameters
+ ----------
+ small_size : int, optional
+ The maximum size to be considered small.
+ small_prob : float, optional
+ The probability of sampling a small item.
+
+ """
+ self._size = small_size
+ self._prob = small_prob
+
+ def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
+ """Sample a structure from the dataset infinitely.
+
+ Parameters
+ ----------
+ records : List[Record]
+ The records to sample from.
+ random : RandomState
+ The random state for reproducibility.
+
+ Yields
+ ------
+ Sample
+ A data sample.
+
+ """
+ # Remove records with invalid chains
+ records = [r for r in records if r.chains[0].valid]
+
+ # Split in small and large proteins. We assume that there is only
+ # one chain per record, as is the case for monomer distillation
+ small = [r for r in records if r.chains[0].num_residues <= self._size]
+ large = [r for r in records if r.chains[0].num_residues > self._size]
+
+ # Sample infinitely
+ while True:
+ # Sample small or large
+ samples = small if random.rand() < self._prob else large
+
+ # Sample item from the list
+ index = random.randint(0, len(samples))
+ yield Sample(record=samples[index])
diff --git a/src/boltz/data/sample/random.py b/src/boltz/data/sample/random.py
new file mode 100644
index 0000000..e2ee231
--- /dev/null
+++ b/src/boltz/data/sample/random.py
@@ -0,0 +1,39 @@
+from dataclasses import replace
+from typing import Iterator, List
+
+from numpy.random import RandomState
+
+from boltz.data.types import Record
+from boltz.data.sample.sampler import Sample, Sampler
+
+
+class RandomSampler(Sampler):
+ """A simple random sampler with replacement."""
+
+ def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
+ """Sample a structure from the dataset infinitely.
+
+ Parameters
+ ----------
+ records : List[Record]
+ The records to sample from.
+ random : RandomState
+ The random state for reproducibility.
+
+ Yields
+ ------
+ Sample
+ A data sample.
+
+ """
+ while True:
+ # Sample item from the list
+ index = random.randint(0, len(records))
+ record = records[index]
+
+ # Remove invalid chains and interfaces
+ chains = [c for c in record.chains if c.valid]
+ interfaces = [i for i in record.interfaces if i.valid]
+ record = replace(record, chains=chains, interfaces=interfaces)
+
+ yield Sample(record=record)
diff --git a/src/boltz/data/sample/sampler.py b/src/boltz/data/sample/sampler.py
new file mode 100644
index 0000000..6c6ab6d
--- /dev/null
+++ b/src/boltz/data/sample/sampler.py
@@ -0,0 +1,49 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Iterator, List, Optional
+
+from numpy.random import RandomState
+
+from boltz.data.types import Record
+
+
+@dataclass
+class Sample:
+ """A sample with optional chain and interface IDs.
+
+ Attributes
+ ----------
+ record : Record
+ The record.
+ chain_id : Optional[int]
+ The chain ID.
+ interface_id : Optional[int]
+ The interface ID.
+ """
+
+ record: Record
+ chain_id: Optional[int] = None
+ interface_id: Optional[int] = None
+
+
+class Sampler(ABC):
+ """Abstract base class for samplers."""
+
+ @abstractmethod
+ def sample(self, records: List[Record], random: RandomState) -> Iterator[Sample]:
+ """Sample a structure from the dataset infinitely.
+
+ Parameters
+ ----------
+ records : List[Record]
+ The records to sample from.
+ random : RandomState
+ The random state for reproducibility.
+
+ Yields
+ ------
+ Sample
+ A data sample.
+
+ """
+ raise NotImplementedError
diff --git a/src/boltz/data/tokenize/__init__.py b/src/boltz/data/tokenize/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/tokenize/boltz.py b/src/boltz/data/tokenize/boltz.py
new file mode 100644
index 0000000..621c653
--- /dev/null
+++ b/src/boltz/data/tokenize/boltz.py
@@ -0,0 +1,191 @@
+from dataclasses import astuple, dataclass
+
+import numpy as np
+
+from boltz.data import const
+from boltz.data.tokenize.tokenizer import Tokenizer
+from boltz.data.types import Input, Token, TokenBond, Tokenized
+
+
+@dataclass
+class TokenData:
+ """TokenData datatype."""
+
+ token_idx: int
+ atom_idx: int
+ atom_num: int
+ res_idx: int
+ res_type: int
+ sym_id: int
+ asym_id: int
+ entity_id: int
+ mol_type: int
+ center_idx: int
+ disto_idx: int
+ center_coords: np.ndarray
+ disto_coords: np.ndarray
+ resolved_mask: bool
+ disto_mask: bool
+
+
+class BoltzTokenizer(Tokenizer):
+ """Tokenize an input structure for training."""
+
+ def tokenize(self, data: Input) -> Tokenized:
+ """Tokenize the input data.
+
+ Parameters
+ ----------
+ data : Inpput
+ The input data.
+
+ Returns
+ -------
+ Tokenized
+ The tokenized data.
+
+ """
+ # Get structure data
+ struct = data.structure
+
+ # Create token data
+ token_data = []
+
+ # Keep track of atom_idx to token_idx
+ token_idx = 0
+ atom_to_token = {}
+
+ # Filter to valid chains only
+ chains = struct.chains[struct.mask]
+
+ for chain in chains:
+ # Get residue indices
+ res_start = chain["res_idx"]
+ res_end = chain["res_idx"] + chain["res_num"]
+
+ for res in struct.residues[res_start:res_end]:
+ # Get atom indices
+ atom_start = res["atom_idx"]
+ atom_end = res["atom_idx"] + res["atom_num"]
+
+ # Standard residues are tokens
+ if res["is_standard"]:
+ # Get center and disto atoms
+ center = struct.atoms[res["atom_center"]]
+ disto = struct.atoms[res["atom_disto"]]
+
+ # Token is present if centers are
+ is_present = res["is_present"] & center["is_present"]
+ is_disto_present = res["is_present"] & disto["is_present"]
+
+ # Apply chain transformation
+ c_coords = center["coords"]
+ d_coords = disto["coords"]
+
+ # Create token
+ token = TokenData(
+ token_idx=token_idx,
+ atom_idx=res["atom_idx"],
+ atom_num=res["atom_num"],
+ res_idx=res["res_idx"],
+ res_type=res["res_type"],
+ sym_id=chain["sym_id"],
+ asym_id=chain["asym_id"],
+ entity_id=chain["entity_id"],
+ mol_type=chain["mol_type"],
+ center_idx=res["atom_center"],
+ disto_idx=res["atom_disto"],
+ center_coords=c_coords,
+ disto_coords=d_coords,
+ resolved_mask=is_present,
+ disto_mask=is_disto_present,
+ )
+ token_data.append(astuple(token))
+
+ # Update atom_idx to token_idx
+ for atom_idx in range(atom_start, atom_end):
+ atom_to_token[atom_idx] = token_idx
+
+ token_idx += 1
+
+ # Non-standard are tokenized per atom
+ else:
+ # We use the unk protein token as res_type
+ unk_token = const.unk_token["PROTEIN"]
+ unk_id = const.token_ids[unk_token]
+
+ # Get atom coordinates
+ atom_data = struct.atoms[atom_start:atom_end]
+ atom_coords = atom_data["coords"]
+
+ # Tokenize each atom
+ for i, atom in enumerate(atom_data):
+ # Token is present if atom is
+ is_present = res["is_present"] & atom["is_present"]
+ index = atom_start + i
+
+ # Create token
+ token = TokenData(
+ token_idx=token_idx,
+ atom_idx=index,
+ atom_num=1,
+ res_idx=res["res_idx"],
+ res_type=unk_id,
+ sym_id=chain["sym_id"],
+ asym_id=chain["asym_id"],
+ entity_id=chain["entity_id"],
+ mol_type=chain["mol_type"],
+ center_idx=index,
+ disto_idx=index,
+ center_coords=atom_coords[i],
+ disto_coords=atom_coords[i],
+ resolved_mask=is_present,
+ disto_mask=is_present,
+ )
+ token_data.append(astuple(token))
+
+ # Update atom_idx to token_idx
+ atom_to_token[index] = token_idx
+ token_idx += 1
+
+ # Create token bonds
+ token_bonds = []
+
+ # Add atom-atom bonds from ligands
+ for bond in struct.bonds:
+ if (
+ bond["atom_1"] not in atom_to_token
+ or bond["atom_2"] not in atom_to_token
+ ):
+ continue
+ token_bond = (
+ atom_to_token[bond["atom_1"]],
+ atom_to_token[bond["atom_2"]],
+ )
+ token_bonds.append(token_bond)
+
+ # Add connection bonds (covalent)
+ for conn in struct.connections:
+ if (
+ conn["atom_1"] not in atom_to_token
+ or conn["atom_2"] not in atom_to_token
+ ):
+ continue
+ token_bond = (
+ atom_to_token[conn["atom_1"]],
+ atom_to_token[conn["atom_2"]],
+ )
+ token_bonds.append(token_bond)
+
+ # Consider adding missing bond for modified residues to standard?
+ # I'm not sure it's necessary because the bond is probably always
+ # the same and the model can use the residue indices to infer it
+ token_data = np.array(token_data, dtype=Token)
+ token_bonds = np.array(token_bonds, dtype=TokenBond)
+ tokenized = Tokenized(
+ token_data,
+ token_bonds,
+ data.structure,
+ data.msa,
+ )
+ return tokenized
diff --git a/src/boltz/data/tokenize/tokenizer.py b/src/boltz/data/tokenize/tokenizer.py
new file mode 100644
index 0000000..5304cd9
--- /dev/null
+++ b/src/boltz/data/tokenize/tokenizer.py
@@ -0,0 +1,24 @@
+from abc import ABC, abstractmethod
+
+from boltz.data.types import Input, Tokenized
+
+
+class Tokenizer(ABC):
+ """Tokenize an input structure for training."""
+
+ @abstractmethod
+ def tokenize(self, data: Input) -> Tokenized:
+ """Tokenize the input data.
+
+ Parameters
+ ----------
+ data : Inpput
+ The input data.
+
+ Returns
+ -------
+ Tokenized
+ The tokenized data.
+
+ """
+ raise NotImplementedError
diff --git a/src/boltz/data/types.py b/src/boltz/data/types.py
new file mode 100644
index 0000000..e3b4afb
--- /dev/null
+++ b/src/boltz/data/types.py
@@ -0,0 +1,478 @@
+import json
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Optional, Union
+
+import numpy as np
+from mashumaro.mixins.dict import DataClassDictMixin
+
+####################################################################################################
+# SERIALIZABLE
+####################################################################################################
+
+
+class NumpySerializable:
+ """Serializable datatype."""
+
+ @classmethod
+ def load(cls: "NumpySerializable", path: Path) -> "NumpySerializable":
+ """Load the object from an NPZ file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the file.
+
+ Returns
+ -------
+ Serializable
+ The loaded object.
+
+ """
+ return cls(**np.load(path))
+
+ def dump(self, path: Path) -> None:
+ """Dump the object to an NPZ file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the file.
+
+ """
+ np.savez_compressed(str(path), **asdict(self))
+
+
+class JSONSerializable(DataClassDictMixin):
+ """Serializable datatype."""
+
+ @classmethod
+ def load(cls: "JSONSerializable", path: Path) -> "JSONSerializable":
+ """Load the object from a JSON file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the file.
+
+ Returns
+ -------
+ Serializable
+ The loaded object.
+
+ """
+ with path.open("r") as f:
+ return cls.from_dict(json.load(f))
+
+ def dump(self, path: Path) -> None:
+ """Dump the object to a JSON file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the file.
+
+ """
+ with path.open("w") as f:
+ json.dump(self.to_dict(), f)
+
+
+####################################################################################################
+# STRUCTURE
+####################################################################################################
+
+Atom = [
+ ("name", np.dtype("4i1")),
+ ("element", np.dtype("i1")),
+ ("charge", np.dtype("i1")),
+ ("coords", np.dtype("3f4")),
+ ("conformer", np.dtype("3f4")),
+ ("is_present", np.dtype("?")),
+ ("chirality", np.dtype("i1")),
+]
+
+Bond = [
+ ("atom_1", np.dtype("i4")),
+ ("atom_2", np.dtype("i4")),
+ ("type", np.dtype("i1")),
+]
+
+Residue = [
+ ("name", np.dtype(" "Structure":
+ """Load a structure from an NPZ file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the file.
+
+ Returns
+ -------
+ Structure
+ The loaded structure.
+
+ """
+ structure = np.load(path)
+ return cls(
+ atoms=structure["atoms"],
+ bonds=structure["bonds"],
+ residues=structure["residues"],
+ chains=structure["chains"],
+ connections=structure["connections"].astype(Connection),
+ interfaces=structure["interfaces"],
+ mask=structure["mask"],
+ )
+
+ def remove_invalid_chains(self) -> "Structure": # noqa: PLR0915
+ """Remove invalid chains.
+
+ Parameters
+ ----------
+ structure : Structure
+ The structure to process.
+
+ Returns
+ -------
+ Structure
+ The structure with masked chains removed.
+
+ """
+ entity_counter = {}
+ atom_idx, res_idx, chain_idx = 0, 0, 0
+ atoms, residues, chains = [], [], []
+ atom_map, res_map, chain_map = {}, {}, {}
+ for i, chain in enumerate(self.chains):
+ # Skip masked chains
+ if not self.mask[i]:
+ continue
+
+ # Update entity counter
+ entity_id = chain["entity_id"]
+ if entity_id not in entity_counter:
+ entity_counter[entity_id] = 0
+ else:
+ entity_counter[entity_id] += 1
+
+ # Update the chain
+ new_chain = chain.copy()
+ new_chain["atom_idx"] = atom_idx
+ new_chain["res_idx"] = res_idx
+ new_chain["asym_id"] = chain_idx
+ new_chain["sym_id"] = entity_counter[entity_id]
+ chains.append(new_chain)
+ chain_map[i] = chain_idx
+ chain_idx += 1
+
+ # Add the chain residues
+ res_start = chain["res_idx"]
+ res_end = chain["res_idx"] + chain["res_num"]
+ for j, res in enumerate(self.residues[res_start:res_end]):
+ # Update the residue
+ new_res = res.copy()
+ new_res["atom_idx"] = atom_idx
+ new_res["atom_center"] = (
+ atom_idx + new_res["atom_center"] - res["atom_idx"]
+ )
+ new_res["atom_disto"] = (
+ atom_idx + new_res["atom_disto"] - res["atom_idx"]
+ )
+ residues.append(new_res)
+ res_map[res_start + j] = res_idx
+ res_idx += 1
+
+ # Update the atoms
+ start = res["atom_idx"]
+ end = res["atom_idx"] + res["atom_num"]
+ atoms.append(self.atoms[start:end])
+ atom_map.update({k: atom_idx + k - start for k in range(start, end)})
+ atom_idx += res["atom_num"]
+
+ # Concatenate the tables
+ atoms = np.concatenate(atoms, dtype=Atom)
+ residues = np.array(residues, dtype=Residue)
+ chains = np.array(chains, dtype=Chain)
+
+ # Update bonds
+ bonds = []
+ for bond in self.bonds:
+ atom_1 = bond["atom_1"]
+ atom_2 = bond["atom_2"]
+ if (atom_1 in atom_map) and (atom_2 in atom_map):
+ new_bond = bond.copy()
+ new_bond["atom_1"] = atom_map[atom_1]
+ new_bond["atom_2"] = atom_map[atom_2]
+ bonds.append(new_bond)
+
+ # Update connections
+ connections = []
+ for connection in self.connections:
+ chain_1 = connection["chain_1"]
+ chain_2 = connection["chain_2"]
+ res_1 = connection["res_1"]
+ res_2 = connection["res_2"]
+ atom_1 = connection["atom_1"]
+ atom_2 = connection["atom_2"]
+ if (atom_1 in atom_map) and (atom_2 in atom_map):
+ new_connection = connection.copy()
+ new_connection["chain_1"] = chain_map[chain_1]
+ new_connection["chain_2"] = chain_map[chain_2]
+ new_connection["res_1"] = res_map[res_1]
+ new_connection["res_2"] = res_map[res_2]
+ new_connection["atom_1"] = atom_map[atom_1]
+ new_connection["atom_2"] = atom_map[atom_2]
+ connections.append(new_connection)
+
+ # Create arrays
+ bonds = np.array(bonds, dtype=Bond)
+ connections = np.array(connections, dtype=Connection)
+ interfaces = np.array([], dtype=Interface)
+ mask = np.ones(len(chains), dtype=bool)
+
+ return Structure(
+ atoms=atoms,
+ bonds=bonds,
+ residues=residues,
+ chains=chains,
+ connections=connections,
+ interfaces=interfaces,
+ mask=mask,
+ )
+
+
+####################################################################################################
+# MSA
+####################################################################################################
+
+
+MSAResidue = [
+ ("res_type", np.dtype("i1")),
+]
+
+MSADeletion = [
+ ("res_idx", np.dtype("i2")),
+ ("deletion", np.dtype("i2")),
+]
+
+MSASequence = [
+ ("seq_idx", np.dtype("i2")),
+ ("taxonomy", np.dtype("i4")),
+ ("res_start", np.dtype("i4")),
+ ("res_end", np.dtype("i4")),
+ ("del_start", np.dtype("i4")),
+ ("del_end", np.dtype("i4")),
+]
+
+
+@dataclass(frozen=True)
+class MSA(NumpySerializable):
+ """MSA datatype."""
+
+ sequences: np.ndarray
+ deletions: np.ndarray
+ residues: np.ndarray
+
+
+####################################################################################################
+# RECORD
+####################################################################################################
+
+
+@dataclass(frozen=True)
+class StructureInfo:
+ """StructureInfo datatype."""
+
+ resolution: Optional[float] = None
+ method: Optional[str] = None
+ deposited: Optional[str] = None
+ released: Optional[str] = None
+ revised: Optional[str] = None
+ num_chains: Optional[int] = None
+ num_interfaces: Optional[int] = None
+
+
+@dataclass(frozen=False)
+class ChainInfo:
+ """ChainInfo datatype."""
+
+ chain_id: int
+ chain_name: str
+ mol_type: int
+ cluster_id: Union[str, int]
+ msa_id: Union[str, int]
+ num_residues: int
+ valid: bool = True
+
+
+@dataclass(frozen=True)
+class InterfaceInfo:
+ """InterfaceInfo datatype."""
+
+ chain_1: int
+ chain_2: int
+ valid: bool = True
+
+
+@dataclass(frozen=True)
+class Record(JSONSerializable):
+ """Record datatype."""
+
+ id: str
+ structure: StructureInfo
+ chains: list[ChainInfo]
+ interfaces: list[InterfaceInfo]
+
+
+####################################################################################################
+# TARGET
+####################################################################################################
+
+
+@dataclass(frozen=True)
+class Target:
+ """Target datatype."""
+
+ record: Record
+ structure: Structure
+
+
+@dataclass(frozen=True)
+class Manifest(JSONSerializable):
+ """Manifest datatype."""
+
+ records: list[Record]
+
+ @classmethod
+ def load(cls: "JSONSerializable", path: Path) -> "JSONSerializable":
+ """Load the object from a JSON file.
+
+ Parameters
+ ----------
+ path : Path
+ The path to the file.
+
+ Returns
+ -------
+ Serializable
+ The loaded object.
+
+ Raises
+ ------
+ TypeError
+ If the file is not a valid manifest file.
+
+ """
+ with path.open("r") as f:
+ data = json.load(f)
+ if isinstance(data, dict):
+ manifest = cls.from_dict(data)
+ elif isinstance(data, list):
+ records = [Record.from_dict(r) for r in data]
+ manifest = cls(records=records)
+ else:
+ msg = "Invalid manifest file."
+ raise TypeError(msg)
+
+ return manifest
+
+
+####################################################################################################
+# INPUT
+####################################################################################################
+
+
+@dataclass(frozen=True)
+class Input:
+ """Input datatype."""
+
+ structure: Structure
+ msa: dict[str, MSA]
+ record: Optional[Record] = None
+
+
+####################################################################################################
+# TOKENS
+####################################################################################################
+
+Token = [
+ ("token_idx", np.dtype("i4")),
+ ("atom_idx", np.dtype("i4")),
+ ("atom_num", np.dtype("i4")),
+ ("res_idx", np.dtype("i4")),
+ ("res_type", np.dtype("i1")),
+ ("sym_id", np.dtype("i4")),
+ ("asym_id", np.dtype("i4")),
+ ("entity_id", np.dtype("i4")),
+ ("mol_type", np.dtype("i1")),
+ ("center_idx", np.dtype("i4")),
+ ("disto_idx", np.dtype("i4")),
+ ("center_coords", np.dtype("3f4")),
+ ("disto_coords", np.dtype("3f4")),
+ ("resolved_mask", np.dtype("?")),
+ ("disto_mask", np.dtype("?")),
+]
+
+TokenBond = [
+ ("token_1", np.dtype("i4")),
+ ("token_2", np.dtype("i4")),
+]
+
+
+@dataclass(frozen=True)
+class Tokenized:
+ """Tokenized datatype."""
+
+ tokens: np.ndarray
+ bonds: np.ndarray
+ structure: Structure
+ msa: dict[str, MSA]
diff --git a/src/boltz/data/write/__init__.py b/src/boltz/data/write/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/data/write/mmcif.py b/src/boltz/data/write/mmcif.py
new file mode 100644
index 0000000..26d525a
--- /dev/null
+++ b/src/boltz/data/write/mmcif.py
@@ -0,0 +1,190 @@
+import io
+from typing import Iterator
+
+import ihm
+from modelcif import Assembly, AsymUnit, Entity, System, dumper
+from modelcif.model import AbInitioModel, Atom, ModelGroup
+from rdkit import Chem
+
+from boltz.data import const
+from boltz.data.types import Structure
+from boltz.data.write.utils import generate_tags
+
+
+def to_mmcif(structure: Structure) -> str: # noqa: C901
+ """Write a structure into an MMCIF file.
+
+ Parameters
+ ----------
+ structure : Structure
+ The input structure
+
+ Returns
+ -------
+ str
+ the output MMCIF file
+
+ """
+ system = System()
+
+ # Load periodic table for element mapping
+ periodic_table = Chem.GetPeriodicTable()
+
+ # Map entities to chain_ids
+ entity_to_chains = {}
+ entity_to_moltype = {}
+
+ for chain in structure.chains:
+ entity_id = chain["entity_id"]
+ mol_type = chain["mol_type"]
+ entity_to_chains.setdefault(entity_id, []).append(chain)
+ entity_to_moltype[entity_id] = mol_type
+
+ # Map entities to sequences
+ sequences = {}
+ for entity in entity_to_chains:
+ # Get the first chain
+ chain = entity_to_chains[entity][0]
+
+ # Get the sequence
+ res_start = chain["res_idx"]
+ res_end = chain["res_idx"] + chain["res_num"]
+ residues = structure.residues[res_start:res_end]
+ sequence = [str(res["name"]) for res in residues]
+ sequences[entity] = sequence
+
+ # Create entity objects
+ entities_map = {}
+ for entity, sequence in sequences.items():
+ mol_type = entity_to_moltype[entity]
+
+ if mol_type == const.chain_type_ids["PROTEIN"]:
+ alphabet = ihm.LPeptideAlphabet()
+ chem_comp = lambda x: ihm.LPeptideChemComp(id=x, code=x, code_canonical="X") # noqa: E731
+ elif mol_type == const.chain_type_ids["DNA"]:
+ alphabet = ihm.DNAAlphabet()
+ chem_comp = lambda x: ihm.DNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731
+ elif mol_type == const.chain_type_ids["RNA"]:
+ alphabet = ihm.RNAAlphabet()
+ chem_comp = lambda x: ihm.RNAChemComp(id=x, code=x, code_canonical="N") # noqa: E731
+ elif len(sequence) > 1:
+ alphabet = {}
+ chem_comp = lambda x: ihm.SaccharideChemComp(id=x) # noqa: E731
+ else:
+ alphabet = {}
+ chem_comp = lambda x: ihm.NonPolymerChemComp(id=x) # noqa: E731
+
+ seq = [
+ alphabet[item] if item in alphabet else chem_comp(item) for item in sequence
+ ]
+ model_e = Entity(seq)
+ for chain in entity_to_chains[entity]:
+ chain_idx = chain["asym_id"]
+ entities_map[chain_idx] = model_e
+
+ # We don't assume that symmetry is perfect, so we dump everything
+ # into the asymmetric unit, and produce just a single assembly
+ chain_tags = generate_tags()
+ asym_unit_map = {}
+ for chain in structure.chains:
+ # Define the model assembly
+ chain_idx = chain["asym_id"]
+ chain_tag = next(chain_tags)
+ asym = AsymUnit(
+ entities_map[chain_idx],
+ details="Model subunit %s" % chain_tag,
+ id=chain_tag,
+ )
+ asym_unit_map[chain_idx] = asym
+ modeled_assembly = Assembly(asym_unit_map.values(), name="Modeled assembly")
+
+ # class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
+ # name = "pLDDT"
+ # software = None
+ # description = "Predicted lddt"
+
+ # class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
+ # name = "pLDDT"
+ # software = None
+ # description = "Global pLDDT, mean of per-residue pLDDTs"
+
+ class _MyModel(AbInitioModel):
+ def get_atoms(self) -> Iterator[Atom]:
+ # Add all atom sites.
+ for chain in structure.chains:
+ # We rename the chains in alphabetical order
+ het = chain["mol_type"] == const.chain_type_ids["NONPOLYMER"]
+ chain_idx = chain["asym_id"]
+ res_start = chain["res_idx"]
+ res_end = chain["res_idx"] + chain["res_num"]
+
+ residues = structure.residues[res_start:res_end]
+ for residue in residues:
+ atom_start = residue["atom_idx"]
+ atom_end = residue["atom_idx"] + residue["atom_num"]
+ atoms = structure.atoms[atom_start:atom_end]
+ atom_coords = atoms["coords"]
+ for i, atom in enumerate(atoms):
+ # This should not happen on predictions, but just in case.
+ if not atom["is_present"]:
+ continue
+
+ name = atom["name"]
+ name = [chr(c + 32) for c in name if c != 0]
+ name = "".join(name)
+ element = periodic_table.GetElementSymbol(
+ atom["element"].item()
+ )
+ element = element.upper()
+ residue_index = residue["res_idx"] + 1
+ pos = atom_coords[i]
+ yield Atom(
+ asym_unit=asym_unit_map[chain_idx],
+ type_symbol=element,
+ seq_id=residue_index,
+ atom_id=name,
+ x=pos[0],
+ y=pos[1],
+ z=pos[2],
+ het=het,
+ biso=1.00,
+ occupancy=1.00,
+ )
+
+ def add_scores(self):
+ return
+ # local scores
+ # plddt_per_residue = {}
+ # for i in range(n):
+ # for mask, b_factor in zip(atom_mask[i], b_factors[i]):
+ # if mask < 0.5:
+ # continue
+ # # add 1 per residue, not 1 per atom
+ # if chain_index[i] not in plddt_per_residue:
+ # # first time a chain index is seen: add the key and start the residue dict
+ # plddt_per_residue[chain_index[i]] = {residue_index[i]: b_factor}
+ # if residue_index[i] not in plddt_per_residue[chain_index[i]]:
+ # plddt_per_residue[chain_index[i]][residue_index[i]] = b_factor
+ # plddts = []
+ # for chain_idx in plddt_per_residue:
+ # for residue_idx in plddt_per_residue[chain_idx]:
+ # plddt = plddt_per_residue[chain_idx][residue_idx]
+ # plddts.append(plddt)
+ # self.qa_metrics.append(
+ # _LocalPLDDT(
+ # asym_unit_map[chain_idx].residue(residue_idx), plddt
+ # )
+ # )
+ # # global score
+ # self.qa_metrics.append((_GlobalPLDDT(np.mean(plddts))))
+
+ # Add the model and modeling protocol to the file and write them out:
+ model = _MyModel(assembly=modeled_assembly, name="Model")
+ # model.add_scores()
+
+ model_group = ModelGroup([model], name="All models")
+ system.model_groups.append(model_group)
+
+ fh = io.StringIO()
+ dumper.write(fh, [system])
+ return fh.getvalue()
diff --git a/src/boltz/data/write/pdb.py b/src/boltz/data/write/pdb.py
new file mode 100644
index 0000000..fc3beac
--- /dev/null
+++ b/src/boltz/data/write/pdb.py
@@ -0,0 +1,113 @@
+from rdkit import Chem
+
+from boltz.data import const
+from boltz.data.types import Structure
+from boltz.data.write.utils import generate_tags
+
+
+def to_pdb(structure: Structure) -> str: # noqa: PLR0915
+ """Write a structure into a PDB file.
+
+ Parameters
+ ----------
+ structure : Structure
+ The input structure
+
+ Returns
+ -------
+ str
+ the output PDB file
+
+ """
+ pdb_lines = []
+
+ atom_index = 1
+ atom_reindex_ter = []
+ chain_tags = generate_tags()
+
+ # Load periodic table for element mapping
+ periodic_table = Chem.GetPeriodicTable()
+
+ # Add all atom sites.
+ for chain in structure.chains:
+ # We rename the chains in alphabetical order
+ chain_idx = chain["asym_id"]
+ chain_tag = next(chain_tags)
+
+ res_start = chain["res_idx"]
+ res_end = chain["res_idx"] + chain["res_num"]
+
+ residues = structure.residues[res_start:res_end]
+ for residue in residues:
+ atom_start = residue["atom_idx"]
+ atom_end = residue["atom_idx"] + residue["atom_num"]
+ atoms = structure.atoms[atom_start:atom_end]
+ atom_coords = atoms["coords"]
+ for i, atom in enumerate(atoms):
+ # This should not happen on predictions, but just in case.
+ if not atom["is_present"]:
+ continue
+
+ record_type = (
+ "ATOM"
+ if chain["mol_type"] != const.chain_type_ids["NONPOLYMER"]
+ else "HETATM"
+ )
+ name = atom["name"]
+ name = [chr(c + 32) for c in name if c != 0]
+ name = "".join(name)
+ name = name if len(name) == 4 else f" {name}" # noqa: PLR2004
+ alt_loc = ""
+ insertion_code = ""
+ occupancy = 1.00
+ element = periodic_table.GetElementSymbol(atom["element"].item())
+ element = element.upper()
+ charge = ""
+ residue_index = residue["res_idx"] + 1
+ pos = atom_coords[i]
+ res_name_3 = (
+ "LIG" if record_type == "HETATM" else str(residue["name"][:3])
+ )
+ b_factor = 1.00
+
+ # PDB is a columnar format, every space matters here!
+ atom_line = (
+ f"{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}"
+ f"{res_name_3:>3} {chain_tag:>1}"
+ f"{residue_index:>4}{insertion_code:>1} "
+ f"{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}"
+ f"{occupancy:>6.2f}{b_factor:>6.2f} "
+ f"{element:>2}{charge:>2}"
+ )
+ pdb_lines.append(atom_line)
+ atom_reindex_ter.append(atom_index)
+ atom_index += 1
+
+ should_terminate = chain_idx < (len(structure.chains) - 1)
+ if should_terminate:
+ # Close the chain.
+ chain_end = "TER"
+ chain_termination_line = (
+ f"{chain_end:<6}{atom_index:>5} "
+ f"{res_name_3:>3} "
+ f"{chain_tag:>1}{residue_index:>4}"
+ )
+ pdb_lines.append(chain_termination_line)
+ atom_index += 1
+
+ # Dump CONECT records.
+ for bonds in [structure.bonds, structure.connections]:
+ for bond in bonds:
+ atom1 = structure.atoms[bond["atom_1"]]
+ atom2 = structure.atoms[bond["atom_2"]]
+ if not atom1["is_present"] or not atom2["is_present"]:
+ continue
+ atom1_idx = atom_reindex_ter[bond["atom_1"]]
+ atom2_idx = atom_reindex_ter[bond["atom_2"]]
+ conect_line = f"CONECT{atom1_idx:>5}{atom2_idx:>5}"
+ pdb_lines.append(conect_line)
+
+ pdb_lines.append("END")
+ pdb_lines.append("")
+ pdb_lines = [line.ljust(80) for line in pdb_lines]
+ return "\n".join(pdb_lines)
diff --git a/src/boltz/data/write/utils.py b/src/boltz/data/write/utils.py
new file mode 100644
index 0000000..617d871
--- /dev/null
+++ b/src/boltz/data/write/utils.py
@@ -0,0 +1,23 @@
+import string
+from collections.abc import Iterator
+
+
+def generate_tags() -> Iterator[str]:
+ """Generate chain tags.
+
+ Yields
+ ------
+ str
+ The next chain tag
+
+ """
+ for i in range(1, 4):
+ for j in range(len(string.ascii_uppercase) ** i):
+ tag = ""
+ for k in range(i):
+ tag += string.ascii_uppercase[
+ j
+ // (len(string.ascii_uppercase) ** k)
+ % len(string.ascii_uppercase)
+ ]
+ yield tag
diff --git a/src/boltz/data/write/writer.py b/src/boltz/data/write/writer.py
new file mode 100644
index 0000000..6f4aa59
--- /dev/null
+++ b/src/boltz/data/write/writer.py
@@ -0,0 +1,155 @@
+from dataclasses import asdict, replace
+from pathlib import Path
+from typing import Literal
+
+import numpy as np
+from pytorch_lightning import LightningModule, Trainer
+from pytorch_lightning.callbacks import BasePredictionWriter
+from torch import Tensor
+
+from boltz.data.types import (
+ Interface,
+ Record,
+ Structure,
+)
+from boltz.data.write.mmcif import to_mmcif
+from boltz.data.write.pdb import to_pdb
+
+
+class BoltzWriter(BasePredictionWriter):
+ """Custom writer for predictions."""
+
+ def __init__(
+ self,
+ data_dir: str,
+ output_dir: str,
+ output_format: Literal["pdb", "mmcif"] = "mmcif",
+ ) -> None:
+ """Initialize the writer.
+
+ Parameters
+ ----------
+ output_dir : str
+ The directory to save the predictions.
+
+ """
+ super().__init__(write_interval="batch")
+ if output_format not in ["pdb", "mmcif"]:
+ msg = f"Invalid output format: {output_format}"
+ raise ValueError(msg)
+
+ self.data_dir = Path(data_dir)
+ self.output_dir = Path(output_dir)
+ self.output_format = output_format
+ self.failed = 0
+
+ # Create the output directories
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ def write_on_batch_end(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule, # noqa: ARG002
+ prediction: dict[str, Tensor],
+ batch_indices: list[int], # noqa: ARG002
+ batch: dict[str, Tensor],
+ batch_idx: int, # noqa: ARG002
+ dataloader_idx: int, # noqa: ARG002
+ ) -> None:
+ """Write the predictions to disk."""
+ if prediction["exception"]:
+ self.failed += 1
+ return
+
+ # Get the records
+ records: list[Record] = batch["record"]
+
+ # Get the predictions
+ coords = prediction["coords"]
+ coords = coords.unsqueeze(0)
+
+ pad_masks = prediction["masks"]
+ if prediction.get("confidence") is not None:
+ confidences = prediction["confidence"]
+ confidences = confidences.reshape(len(records), -1).tolist()
+ else:
+ confidences = [0.0 for _ in range(len(records))]
+
+ # Iterate over the records
+ for record, coord, pad_mask, _confidence in zip(
+ records, coords, pad_masks, confidences
+ ):
+ # Load the structure
+ path = self.data_dir / f"{record.id}.npz"
+ structure: Structure = Structure.load(path)
+
+ # Compute chain map with masked removed, to be used later
+ chain_map = {}
+ for i, mask in enumerate(structure.mask):
+ if mask:
+ chain_map[len(chain_map)] = i
+
+ # Remove masked chains completely
+ structure = structure.remove_invalid_chains()
+
+ for model_idx in range(coord.shape[0]):
+ # Get model coord
+ model_coord = coord[model_idx]
+ # Unpad
+ coord_unpad = model_coord[pad_mask.bool()]
+ coord_unpad = coord_unpad.cpu().numpy()
+
+ # New atom table
+ atoms = structure.atoms
+ atoms["coords"] = coord_unpad
+ atoms["is_present"] = True
+
+ # Mew residue table
+ residues = structure.residues
+ residues["is_present"] = True
+
+ # Update the structure
+ interfaces = np.array([], dtype=Interface)
+ new_structure: Structure = replace(
+ structure,
+ atoms=atoms,
+ residues=residues,
+ interfaces=interfaces,
+ )
+
+ # Update chain info
+ chain_info = []
+ for chain in new_structure.chains:
+ old_chain_idx = chain_map[chain["asym_id"]]
+ old_chain_info = record.chains[old_chain_idx]
+ new_chain_info = replace(
+ old_chain_info,
+ chain_id=int(chain["asym_id"]),
+ valid=True,
+ )
+ chain_info.append(new_chain_info)
+
+ # Save the structure
+ struct_dir = self.output_dir / record.id
+ struct_dir.mkdir(exist_ok=True)
+
+ if self.output_format == "pdb":
+ path = struct_dir / f"{record.id}_model_{model_idx}.pdb"
+ with path.open("w") as f:
+ f.write(to_pdb(new_structure))
+ elif self.output_format == "mmcif":
+ path = struct_dir / f"{record.id}_model_{model_idx}.cif"
+ with path.open("w") as f:
+ f.write(to_mmcif(new_structure))
+ else:
+ path = struct_dir / f"{record.id}_model_{model_idx}.npz"
+ np.savez_compressed(path, **asdict(new_structure))
+
+ def on_predict_epoch_end(
+ self,
+ trainer: Trainer, # noqa: ARG00s2
+ pl_module: LightningModule, # noqa: ARG002
+ ) -> None:
+ """Print the number of failed examples."""
+ # Print number of failed examples
+ print(f"Number of failed examples: {self.failed}") # noqa: T201
diff --git a/src/boltz/main.py b/src/boltz/main.py
new file mode 100644
index 0000000..b1c43f3
--- /dev/null
+++ b/src/boltz/main.py
@@ -0,0 +1,403 @@
+import pickle
+import urllib.request
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Literal, Optional
+
+import click
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.strategies import DDPStrategy
+from rdkit.Chem.rdchem import Mol
+from tqdm import tqdm
+
+from boltz.data.module.inference import BoltzInferenceDataModule
+from boltz.data.parse.a3m import parse_a3m
+from boltz.data.parse.fasta import parse_fasta
+from boltz.data.parse.yaml import parse_yaml
+from boltz.data.types import MSA, Manifest, Record
+from boltz.data.write.writer import BoltzWriter
+from boltz.model.model import Boltz1
+
+CCD_URL = "https://www.dropbox.com/scl/fi/h4mjhcbhzzkkj4piu1k6x/ccd.pkl?rlkey=p43trjrs9ots4qk84ygk24seu&st=bymcsoqe&dl=1"
+MODEL_URL = "https://www.dropbox.com/scl/fi/8qo9aryyttzp97z74dchn/boltz1.ckpt?rlkey=jvxl2jsn0kajnyfmesbj4lb89&st=dipi1sbw&dl=1"
+
+
+@dataclass
+class BoltzProcessedInput:
+ """Processed input data."""
+
+ manifest: Manifest
+ targets_dir: Path
+ msa_dir: Path
+
+
+@dataclass
+class BoltzDiffusionParams:
+ """Diffusion process parameters."""
+
+ gamma_0: float = 0.605
+ gamma_min: float = 1.107
+ noise_scale: float = 0.901
+ rho: float = 8
+ step_scale: float = 1.638
+ sigma_min: float = 0.0004
+ sigma_max: float = 160.0
+ sigma_data: float = 16.0
+ P_mean: float = -1.2
+ P_std: float = 1.5
+ coordinate_augmentation: bool = True
+ alignment_reverse_diff: bool = True
+ synchronize_sigmas: bool = True
+ use_inference_model_cache: bool = True
+
+
+def download(cache: Path) -> None:
+ """Download all the required data.
+
+ Parameters
+ ----------
+ cache : Path
+ The cache directory.
+
+ """
+ # Warn user, just in case
+ click.echo(
+ f"Downloading data and model to {cache}. "
+ "You may change this by setting the --cache flag."
+ )
+
+ # Download CCD, while capturing the output
+ ccd = cache / "ccd.pkl"
+ if not ccd.exists():
+ click.echo(f"Downloading CCD to {ccd}.")
+ urllib.request.urlretrieve(CCD_URL, str(ccd)) # noqa: S310
+
+ # Download model
+ model = cache / "boltz1.ckpt"
+ if not model.exists():
+ click.echo(f"Downloading model to {model}")
+ urllib.request.urlretrieve(MODEL_URL, str(model)) # noqa: S310
+
+
+def check_inputs(
+ data: Path,
+ outdir: Path,
+ override: bool = False,
+) -> list[Path]:
+ """Check the input data and output directory.
+
+ If the input data is a directory, it will be expanded
+ to all files in this directory. Then, we check if there
+ are any existing predictions and remove them from the
+ list of input data, unless the override flag is set.
+
+ Parameters
+ ----------
+ data : Path
+ The input data.
+ outdir : Path
+ The output directory.
+ override: bool
+ Whether to override existing predictions.
+
+ Returns
+ -------
+ list[Path]
+ The list of input data.
+
+ """
+ click.echo("Checking input data.")
+
+ # Check if data is a directory
+ if data.is_dir():
+ data = list(data.glob("*"))
+ data = [d for d in data if d.suffix in [".fasta", ".yaml"]]
+ else:
+ data = [data]
+
+ # Check if existing predictions are found
+ existing = (outdir / "predictions").rglob("*")
+ existing = {e.name for e in existing if e.is_dir()}
+
+ # Remove them from the input data
+ if existing and not override:
+ data = [d for d in data if d.stem not in existing]
+ msg = "Found existing predictions, skipping and running only the missing ones."
+ click.echo(msg)
+ elif existing and override:
+ msg = "Found existing predictions, will override."
+ click.echo(msg)
+
+ return data
+
+
+def process_inputs(
+ data: list[Path],
+ out_dir: Path,
+ ccd: dict[str, Mol],
+ max_msa_seqs: int = 4096,
+) -> BoltzProcessedInput:
+ """Process the input data and output directory.
+
+ Parameters
+ ----------
+ data : list[Path]
+ The input data.
+ out_dir : Path
+ The output directory.
+ ccd : dict[str, Mol]
+ The CCD dictionary.
+ max_msa_seqs : int, optional
+ Max number of MSA seuqneces, by default 4096.
+
+ Returns
+ -------
+ BoltzProcessedInput
+ The processed input data.
+
+ """
+ click.echo("Processing input data.")
+
+ # Create output directories
+ structure_dir = out_dir / "processed" / "structures"
+ processed_msa_dir = out_dir / "processed" / "msa"
+ predictions_dir = out_dir / "predictions"
+
+ out_dir.mkdir(parents=True, exist_ok=True)
+ structure_dir.mkdir(parents=True, exist_ok=True)
+ predictions_dir.mkdir(parents=True, exist_ok=True)
+ processed_msa_dir.mkdir(parents=True, exist_ok=True)
+
+ # Parse input data
+ records: list[Record] = []
+ for path in tqdm(data):
+ # Parse data
+ if path.suffix == ".fasta":
+ target = parse_fasta(path, ccd)
+ elif path.suffix == ".yaml":
+ target = parse_yaml(path, ccd)
+
+ # Keep record
+ records.append(target.record)
+
+ # Dump structure
+ struct_path = structure_dir / f"{target.record.id}.npz"
+ target.structure.dump(struct_path)
+
+ # Parse MSA data
+ msas = {chain.msa_id for r in records for chain in r.chains if chain.msa_id != -1}
+ msa_id_map = {}
+ for msa_idx, msa_id in enumerate(msas):
+ # Check that raw MSA exists
+ msa_path = Path(msa_id)
+ if not msa_path.exists():
+ msg = f"MSA file {msa_path} not found."
+ raise FileNotFoundError(msg)
+
+ # Dump processed MSA
+ processed = processed_msa_dir / f"{msa_idx}.npz"
+ msa_id_map[msa_id] = msa_idx
+ if not processed.exists():
+ msa: MSA = parse_a3m(
+ msa_path,
+ taxonomy=None,
+ max_seqs=max_msa_seqs,
+ )
+ msa.dump(processed)
+
+ # Modify records to point to processed MSA
+ for record in records:
+ for c in record.chains:
+ if c.msa_id != -1 and c.msa_id in msa_id_map:
+ c.msa_id = msa_id_map[c.msa_id]
+
+ # Dump manifest
+ manifest = Manifest(records)
+ manifest.dump(out_dir / "processed" / "manifest.json")
+
+ return BoltzProcessedInput(
+ manifest=manifest,
+ targets_dir=structure_dir,
+ msa_dir=processed_msa_dir,
+ )
+
+
+@click.group()
+def cli() -> None:
+ """Boltz1."""
+ return
+
+
+@cli.command()
+@click.argument("data", type=click.Path(exists=True))
+@click.option(
+ "--out_dir",
+ type=click.Path(exists=False),
+ help="The path where to save the predictions.",
+ default="./",
+)
+@click.option(
+ "--cache",
+ type=click.Path(exists=False),
+ help="The directory where to download the data and model. Default is ~/.boltz.",
+ default="~/.boltz",
+)
+@click.option(
+ "--checkpoint",
+ type=click.Path(exists=True),
+ help="An optional checkpoint, will use the provided Boltz-1 model by default.",
+ default=None,
+)
+@click.option(
+ "--devices",
+ type=int,
+ help="The number of devices to use for prediction. Default is 1.",
+ default=1,
+)
+@click.option(
+ "--accelerator",
+ type=click.Choice(["gpu", "cpu", "tpu"]),
+ help="The accelerator to use for prediction. Default is gpu.",
+ default="gpu",
+)
+@click.option(
+ "--recycling_steps",
+ type=int,
+ help="The number of recycling steps to use for prediction. Default is 3.",
+ default=3,
+)
+@click.option(
+ "--sampling_steps",
+ type=int,
+ help="The number of sampling steps to use for prediction. Default is 200.",
+ default=200,
+)
+@click.option(
+ "--diffusion_samples",
+ type=int,
+ help="The number of diffusion samples to use for prediction. Default is 1.",
+ default=1,
+)
+@click.option(
+ "--output_format",
+ type=click.Choice(["pdb", "mmcif"]),
+ help="The output format to use for the predictions. Default is mmcif.",
+ default="mmcif",
+)
+@click.option(
+ "--num_workers",
+ type=int,
+ help="The number of dataloader workers to use for prediction. Default is 2.",
+ default=2,
+)
+@click.option(
+ "--override",
+ is_flag=True,
+ help="Whether to override existing found predictions. Default is False.",
+)
+def predict(
+ data: str,
+ out_dir: str,
+ cache: str = "~/.boltz",
+ checkpoint: Optional[str] = None,
+ devices: int = 1,
+ accelerator: str = "gpu",
+ recycling_steps: int = 3,
+ sampling_steps: int = 200,
+ diffusion_samples: int = 1,
+ output_format: Literal["pdb", "mmcif"] = "mmcif",
+ num_workers: int = 2,
+ override: bool = False,
+) -> None:
+ """Run predictions with Boltz-1."""
+ # If cpu, write a friendly warning
+ if accelerator == "cpu":
+ msg = "Running on CPU, this will be slow. Consider using a GPU."
+ click.echo(msg)
+
+ # Set no grad
+ torch.set_grad_enabled(False)
+
+ # Set cache path
+ cache = Path(cache).expanduser()
+ cache.mkdir(parents=True, exist_ok=True)
+
+ # Create output directories
+ data = Path(data).expanduser()
+ out_dir = Path(out_dir).expanduser()
+ out_dir = out_dir / f"boltz_results_{data.stem}"
+ out_dir.mkdir(parents=True, exist_ok=True)
+
+ # Download necessary data and model
+ download(cache)
+
+ # Load CCD
+ ccd_path = cache / "ccd.pkl"
+ with ccd_path.open("rb") as file:
+ ccd = pickle.load(file) # noqa: S301
+
+ # Set checkpoint
+ if checkpoint is None:
+ checkpoint = cache / "boltz1.ckpt"
+
+ # Check if data is a directory
+ data = check_inputs(data, out_dir, override)
+ processed = process_inputs(data, out_dir, ccd)
+
+ # Create data module
+ data_module = BoltzInferenceDataModule(
+ manifest=processed.manifest,
+ target_dir=processed.targets_dir,
+ msa_dir=processed.msa_dir,
+ num_workers=num_workers,
+ )
+
+ # Load model
+ predict_args = {
+ "recycling_steps": recycling_steps,
+ "sampling_steps": sampling_steps,
+ "diffusion_samples": diffusion_samples,
+ }
+ model_module: Boltz1 = Boltz1.load_from_checkpoint(
+ checkpoint,
+ strict=True,
+ predict_args=predict_args,
+ map_location="cpu",
+ diffusion_process_args=asdict(BoltzDiffusionParams()),
+ )
+ model_module.eval()
+
+ # Create prediction writer
+ pred_writer = BoltzWriter(
+ data_dir=processed.targets_dir,
+ output_dir=out_dir / "predictions",
+ output_format=output_format,
+ )
+
+ # Set up trainer
+ strategy = "auto"
+ if (isinstance(devices, int) and devices > 1) or (
+ isinstance(devices, list) and len(devices) > 1
+ ):
+ strategy = DDPStrategy()
+
+ trainer = Trainer(
+ default_root_dir=out_dir,
+ strategy=strategy,
+ callbacks=[pred_writer],
+ accelerator=accelerator,
+ devices=devices,
+ precision=32,
+ )
+
+ # Compute predictions
+ trainer.predict(
+ model_module,
+ datamodule=data_module,
+ return_predictions=False,
+ )
+
+
+if __name__ == "__main__":
+ cli()
diff --git a/src/boltz/model/__init__.py b/src/boltz/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/model/layers/__init__.py b/src/boltz/model/layers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/model/layers/attention.py b/src/boltz/model/layers/attention.py
new file mode 100644
index 0000000..6402627
--- /dev/null
+++ b/src/boltz/model/layers/attention.py
@@ -0,0 +1,131 @@
+from einops.layers.torch import Rearrange
+import torch
+from torch import Tensor, nn
+
+import boltz.model.layers.initialize as init
+
+
+class AttentionPairBias(nn.Module):
+ """Attention pair bias layer."""
+
+ def __init__(
+ self,
+ c_s: int,
+ c_z: int,
+ num_heads: int,
+ inf: float = 1e6,
+ initial_norm: bool = True,
+ ) -> None:
+ """Initialize the attention pair bias layer.
+
+ Parameters
+ ----------
+ c_s : int
+ The input sequence dimension.
+ c_z : int
+ The input pairwise dimension.
+ num_heads : int
+ The number of heads.
+ inf : float, optional
+ The inf value, by default 1e6
+ initial_norm: bool, optional
+ Whether to apply layer norm to the input, by default True
+
+ """
+ super().__init__()
+
+ assert c_s % num_heads == 0
+
+ self.c_s = c_s
+ self.num_heads = num_heads
+ self.head_dim = c_s // num_heads
+ self.inf = inf
+
+ self.initial_norm = initial_norm
+ if self.initial_norm:
+ self.norm_s = nn.LayerNorm(c_s)
+
+ self.proj_q = nn.Linear(c_s, c_s)
+ self.proj_k = nn.Linear(c_s, c_s, bias=False)
+ self.proj_v = nn.Linear(c_s, c_s, bias=False)
+ self.proj_g = nn.Linear(c_s, c_s, bias=False)
+
+ self.proj_z = nn.Sequential(
+ nn.LayerNorm(c_z),
+ nn.Linear(c_z, num_heads, bias=False),
+ Rearrange("b ... h -> b h ..."),
+ )
+
+ self.proj_o = nn.Linear(c_s, c_s, bias=False)
+ init.final_init_(self.proj_o.weight)
+
+ def forward(
+ self,
+ s: Tensor,
+ z: Tensor,
+ mask: Tensor,
+ multiplicity: int = 1,
+ to_keys=None,
+ model_cache=None,
+ ) -> Tensor:
+ """Forward pass.
+
+ Parameters
+ ----------
+ s : torch.Tensor
+ The input sequence tensor (B, S, D)
+ z : torch.Tensor
+ The input pairwise tensor (B, N, N, D)
+ mask : torch.Tensor
+ The pairwise mask tensor (B, N, N)
+ multiplicity : int, optional
+ The diffusion batch size, by default 1
+
+ Returns
+ -------
+ torch.Tensor
+ The output sequence tensor.
+
+ """
+ B = s.shape[0]
+
+ # Layer norms
+ if self.initial_norm:
+ s = self.norm_s(s)
+
+ if to_keys is not None:
+ k_in = to_keys(s)
+ mask = to_keys(mask.unsqueeze(-1)).squeeze(-1)
+ else:
+ k_in = s
+
+ # Compute projections
+ q = self.proj_q(s).view(B, -1, self.num_heads, self.head_dim)
+ k = self.proj_k(k_in).view(B, -1, self.num_heads, self.head_dim)
+ v = self.proj_v(k_in).view(B, -1, self.num_heads, self.head_dim)
+
+ # Caching z projection during diffusion roll-out
+ if model_cache is None or "z" not in model_cache:
+ z = self.proj_z(z)
+
+ if model_cache is not None:
+ model_cache["z"] = z
+ else:
+ z = model_cache["z"]
+ z = z.repeat_interleave(multiplicity, 0)
+
+ g = self.proj_g(s).sigmoid()
+
+ with torch.autocast("cuda", enabled=False):
+ # Compute attention weights
+ attn = torch.einsum("bihd,bjhd->bhij", q.float(), k.float())
+ attn = attn / (self.head_dim**0.5) + z.float()
+ attn = attn + (1 - mask[:, None, None].float()) * -self.inf
+ attn = attn.softmax(dim=-1)
+
+ # Compute output
+ o = torch.einsum("bhij,bjhd->bihd", attn, v.float()).to(v.dtype)
+ o = o.reshape(B, -1, self.c_s)
+ o = self.proj_o(g * o)
+
+ return o
diff --git a/src/boltz/model/layers/dropout.py b/src/boltz/model/layers/dropout.py
new file mode 100644
index 0000000..f417c86
--- /dev/null
+++ b/src/boltz/model/layers/dropout.py
@@ -0,0 +1,34 @@
+import torch
+from torch import Tensor
+
+
+def get_dropout_mask(
+ dropout: float,
+ z: Tensor,
+ training: bool,
+ columnwise: bool = False,
+) -> Tensor:
+ """Get the dropout mask.
+
+ Parameters
+ ----------
+ dropout : float
+ The dropout rate
+ z : torch.Tensor
+ The tensor to apply dropout to
+ training : bool
+ Whether the model is in training mode
+ columnwise : bool, optional
+ Whether to apply dropout columnwise
+
+ Returns
+ -------
+ torch.Tensor
+ The dropout mask
+
+ """
+ dropout = dropout * training
+ v = z[:, 0:1, :, 0:1] if columnwise else z[:, :, 0:1, 0:1]
+ d = torch.rand_like(v) > dropout
+ d = d * 1.0 / (1.0 - dropout)
+ return d
diff --git a/src/boltz/model/layers/initialize.py b/src/boltz/model/layers/initialize.py
new file mode 100644
index 0000000..db76a5a
--- /dev/null
+++ b/src/boltz/model/layers/initialize.py
@@ -0,0 +1,99 @@
+"""Utility functions for initializing weights and biases."""
+
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import numpy as np
+from scipy.stats import truncnorm
+import torch
+
+
+def _prod(nums):
+ out = 1
+ for n in nums:
+ out = out * n
+ return out
+
+
+def _calculate_fan(linear_weight_shape, fan="fan_in"):
+ fan_out, fan_in = linear_weight_shape
+
+ if fan == "fan_in":
+ f = fan_in
+ elif fan == "fan_out":
+ f = fan_out
+ elif fan == "fan_avg":
+ f = (fan_in + fan_out) / 2
+ else:
+ raise ValueError("Invalid fan option")
+
+ return f
+
+
+def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
+ shape = weights.shape
+ f = _calculate_fan(shape, fan)
+ scale = scale / max(1, f)
+ a = -2
+ b = 2
+ std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
+ size = _prod(shape)
+ samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
+ samples = np.reshape(samples, shape)
+ with torch.no_grad():
+ weights.copy_(torch.tensor(samples, device=weights.device))
+
+
+def lecun_normal_init_(weights):
+ trunc_normal_init_(weights, scale=1.0)
+
+
+def he_normal_init_(weights):
+ trunc_normal_init_(weights, scale=2.0)
+
+
+def glorot_uniform_init_(weights):
+ torch.nn.init.xavier_uniform_(weights, gain=1)
+
+
+def final_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+
+def gating_init_(weights):
+ with torch.no_grad():
+ weights.fill_(0.0)
+
+
+def bias_init_zero_(bias):
+ with torch.no_grad():
+ bias.fill_(0.0)
+
+
+def bias_init_one_(bias):
+ with torch.no_grad():
+ bias.fill_(1.0)
+
+
+def normal_init_(weights):
+ torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
+
+
+def ipa_point_weights_init_(weights):
+ with torch.no_grad():
+ softplus_inverse_1 = 0.541324854612918
+ weights.fill_(softplus_inverse_1)
diff --git a/src/boltz/model/layers/outer_product_mean.py b/src/boltz/model/layers/outer_product_mean.py
new file mode 100644
index 0000000..cf964eb
--- /dev/null
+++ b/src/boltz/model/layers/outer_product_mean.py
@@ -0,0 +1,90 @@
+import torch
+from torch import Tensor, nn
+
+import boltz.model.layers.initialize as init
+
+
+class OuterProductMean(nn.Module):
+ """Outer product mean layer."""
+
+ def __init__(
+ self, c_in: int, c_hidden: int, c_out: int, chunk_size: int = None
+ ) -> None:
+ """Initialize the outer product mean layer.
+
+ Parameters
+ ----------
+ c_in : int
+ The input dimension.
+ c_hidden : int
+ The hidden dimension.
+ c_out : int
+ The output dimension.
+ chunk_size : int, optional
+ The inference chunk size, by default None.
+
+ """
+ super().__init__()
+ self.chunk_size = chunk_size
+ self.c_hidden = c_hidden
+ self.norm = nn.LayerNorm(c_in)
+ self.proj_a = nn.Linear(c_in, c_hidden, bias=False)
+ self.proj_b = nn.Linear(c_in, c_hidden, bias=False)
+ self.proj_o = nn.Linear(c_hidden * c_hidden, c_out)
+ init.final_init_(self.proj_o.weight)
+ init.final_init_(self.proj_o.bias)
+
+ def forward(self, m: Tensor, mask: Tensor) -> Tensor:
+ """Forward pass.
+
+ Parameters
+ ----------
+ m : torch.Tensor
+ The sequence tensor (B, S, N, c_in).
+ mask : torch.Tensor
+ The mask tensor (B, S, N).
+
+ Returns
+ -------
+ torch.Tensor
+ The output tensor (B, N, N, c_out).
+
+ """
+ # Expand mask
+ mask = mask.unsqueeze(-1).to(m)
+
+ # Compute projections
+ m = self.norm(m)
+ a = self.proj_a(m) * mask
+ b = self.proj_b(m) * mask
+
+ # Compute pairwise mask
+ mask = mask[:, :, None, :] * mask[:, :, :, None]
+
+ # Compute outer product mean
+ if self.chunk_size is not None and not self.training:
+ # Compute squentially in chunks
+ for i in range(0, self.c_hidden, self.chunk_size):
+ a_chunk = a[:, :, :, i : i + self.chunk_size]
+ sliced_weight_proj_o = self.proj_o.weight[
+ :, i * self.c_hidden : (i + self.chunk_size) * self.c_hidden
+ ]
+
+ z = torch.einsum("bsic,bsjd->bijcd", a_chunk, b)
+ z = z.reshape(*z.shape[:3], -1)
+ z = z / mask.sum(dim=1).clamp(min=1)
+
+ # Project to output
+ if i == 0:
+ z_out = z.to(m) @ sliced_weight_proj_o.T
+ else:
+ z_out = z_out + z.to(m) @ sliced_weight_proj_o.T
+ return z_out
+ else:
+ z = torch.einsum("bsic,bsjd->bijcd", a.float(), b.float())
+ z = z.reshape(*z.shape[:3], -1)
+ z = z / mask.sum(dim=1).clamp(min=1)
+
+ # Project to output
+ z = self.proj_o(z.to(m))
+ return z
diff --git a/src/boltz/model/layers/pair_averaging.py b/src/boltz/model/layers/pair_averaging.py
new file mode 100644
index 0000000..43416cb
--- /dev/null
+++ b/src/boltz/model/layers/pair_averaging.py
@@ -0,0 +1,137 @@
+import torch
+from torch import Tensor, nn
+
+import boltz.model.layers.initialize as init
+
+
+class PairWeightedAveraging(nn.Module):
+ """Pair weighted averaging layer."""
+
+ def __init__(
+ self,
+ c_m: int,
+ c_z: int,
+ c_h: int,
+ num_heads: int,
+ inf: float = 1e6,
+ chunk_heads: bool = False,
+ ) -> None:
+ """Initialize the pair weighted averaging layer.
+
+ Parameters
+ ----------
+ c_m: int
+ The dimension of the input sequence.
+ c_z: int
+ The dimension of the input pairwise tensor.
+ c_h: int
+ The dimension of the hidden.
+ num_heads: int
+ The number of heads.
+ inf: float
+ The value to use for masking, default 1e6.
+ chunk_heads: bool
+ Whether to sequentially compute heads at inference, default False.
+
+ """
+ super().__init__()
+ self.c_m = c_m
+ self.c_z = c_z
+ self.c_h = c_h
+ self.num_heads = num_heads
+ self.inf = inf
+ self.chunk_heads = chunk_heads
+
+ self.norm_m = nn.LayerNorm(c_m)
+ self.norm_z = nn.LayerNorm(c_z)
+
+ self.proj_m = nn.Linear(c_m, c_h * num_heads, bias=False)
+ self.proj_g = nn.Linear(c_m, c_h * num_heads, bias=False)
+ self.proj_z = nn.Linear(c_z, num_heads, bias=False)
+ self.proj_o = nn.Linear(c_h * num_heads, c_m, bias=False)
+ init.final_init_(self.proj_o.weight)
+
+ def forward(self, m: Tensor, z: Tensor, mask: Tensor) -> Tensor:
+ """Forward pass.
+
+ Parameters
+ ----------
+ m : torch.Tensor
+ The input sequence tensor (B, S, N, D)
+ z : torch.Tensor
+ The input pairwise tensor (B, N, N, D)
+ mask : torch.Tensor
+ The pairwise mask tensor (B, N, N)
+
+ Returns
+ -------
+ torch.Tensor
+ The output sequence tensor (B, S, N, D)
+
+ """
+ # Compute layer norms
+ m = self.norm_m(m)
+ z = self.norm_z(z)
+
+ if self.chunk_heads and not self.training:
+ # Compute heads sequentially
+ o_chunks = []
+ for head_idx in range(self.num_heads):
+ sliced_weight_proj_m = self.proj_m.weight[
+ head_idx * self.c_h : (head_idx + 1) * self.c_h, :
+ ]
+ sliced_weight_proj_g = self.proj_g.weight[
+ head_idx * self.c_h : (head_idx + 1) * self.c_h, :
+ ]
+ sliced_weight_proj_z = self.proj_z.weight[head_idx : (head_idx + 1), :]
+ sliced_weight_proj_o = self.proj_o.weight[
+ :, head_idx * self.c_h : (head_idx + 1) * self.c_h
+ ]
+
+ # Project input tensors
+ v: Tensor = m @ sliced_weight_proj_m.T
+ v = v.reshape(*v.shape[:3], 1, self.c_h)
+ v = v.permute(0, 3, 1, 2, 4)
+
+ # Compute weights
+ b: Tensor = z @ sliced_weight_proj_z.T
+ b = b.permute(0, 3, 1, 2)
+ b = b + (1 - mask[:, None]) * -self.inf
+ w = torch.softmax(b, dim=-1)
+
+ # Compute gating
+ g: Tensor = m @ sliced_weight_proj_g.T
+ g = g.sigmoid()
+
+ # Compute output
+ o = torch.einsum("bhij,bhsjd->bhsid", w, v)
+ o = o.permute(0, 2, 3, 1, 4)
+ o = o.reshape(*o.shape[:3], 1 * self.c_h)
+ o_chunks = g * o
+ if head_idx == 0:
+ o_out = o_chunks @ sliced_weight_proj_o.T
+ else:
+ o_out += o_chunks @ sliced_weight_proj_o.T
+ return o_out
+ else:
+ # Project input tensors
+ v: Tensor = self.proj_m(m)
+ v = v.reshape(*v.shape[:3], self.num_heads, self.c_h)
+ v = v.permute(0, 3, 1, 2, 4)
+
+ # Compute weights
+ b: Tensor = self.proj_z(z)
+ b = b.permute(0, 3, 1, 2)
+ b = b + (1 - mask[:, None]) * -self.inf
+ w = torch.softmax(b, dim=-1)
+
+ # Compute gating
+ g: Tensor = self.proj_g(m)
+ g = g.sigmoid()
+
+ # Compute output
+ o = torch.einsum("bhij,bhsjd->bhsid", w, v)
+ o = o.permute(0, 2, 3, 1, 4)
+ o = o.reshape(*o.shape[:3], self.num_heads * self.c_h)
+ o = self.proj_o(g * o)
+ return o
diff --git a/src/boltz/model/layers/transition.py b/src/boltz/model/layers/transition.py
new file mode 100644
index 0000000..1ae464d
--- /dev/null
+++ b/src/boltz/model/layers/transition.py
@@ -0,0 +1,82 @@
+from typing import Optional
+
+from torch import Tensor, nn
+
+import boltz.model.layers.initialize as init
+
+
+class Transition(nn.Module):
+ """Perform a two-layer MLP."""
+
+ def __init__(
+ self,
+ dim: int = 128,
+ hidden: int = 512,
+ out_dim: Optional[int] = None,
+ chunk_size: int = None,
+ ) -> None:
+ """Initialize the TransitionUpdate module.
+
+ Parameters
+ ----------
+ dim: int
+ The dimension of the input, default 128
+ hidden: int
+ The dimension of the hidden, default 512
+ out_dim: Optional[int]
+ The dimension of the output, default None
+ chunk_size: int
+ The chunk size for inference, default None
+
+ """
+ super().__init__()
+ if out_dim is None:
+ out_dim = dim
+
+ self.norm = nn.LayerNorm(dim, eps=1e-5)
+ self.fc1 = nn.Linear(dim, hidden, bias=False)
+ self.fc2 = nn.Linear(dim, hidden, bias=False)
+ self.fc3 = nn.Linear(hidden, out_dim, bias=False)
+ self.silu = nn.SiLU()
+ self.hidden = hidden
+ self.chunk_size = chunk_size
+
+ init.bias_init_one_(self.norm.weight)
+ init.bias_init_zero_(self.norm.bias)
+
+ init.lecun_normal_init_(self.fc1.weight)
+ init.lecun_normal_init_(self.fc2.weight)
+ init.final_init_(self.fc3.weight)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Perform a forward pass.
+
+ Parameters
+ ----------
+ x: torch.Tensor
+ The input data of shape (..., D)
+
+ Returns
+ -------
+ x: torch.Tensor
+ The output data of shape (..., D)
+
+ """
+ x = self.norm(x)
+
+ if self.chunk_size is None or self.training:
+ x = self.silu(self.fc1(x)) * self.fc2(x)
+ x = self.fc3(x)
+ return x
+ else:
+ # Compute in chunks
+ for i in range(0, self.hidden, self.chunk_size):
+ fc1_slice = self.fc1.weight[i : i + self.chunk_size, :]
+ fc2_slice = self.fc2.weight[i : i + self.chunk_size, :]
+ fc3_slice = self.fc3.weight[:, i : i + self.chunk_size]
+ x_chunk = self.silu((x @ fc1_slice.T)) * (x @ fc2_slice.T)
+ if i == 0:
+ x_out = x_chunk @ fc3_slice.T
+ else:
+ x_out = x_out + x_chunk @ fc3_slice.T
+ return x_out
diff --git a/src/boltz/model/layers/triangular_attention/__init__.py b/src/boltz/model/layers/triangular_attention/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/model/layers/triangular_attention/attention.py b/src/boltz/model/layers/triangular_attention/attention.py
new file mode 100644
index 0000000..e315334
--- /dev/null
+++ b/src/boltz/model/layers/triangular_attention/attention.py
@@ -0,0 +1,165 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial, partialmethod
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+
+from boltz.model.layers.triangular_attention.primitives import (
+ Attention,
+ LayerNorm,
+ Linear,
+)
+from boltz.model.layers.triangular_attention.utils import (
+ chunk_layer,
+ permute_final_dims,
+)
+
+
+class TriangleAttention(nn.Module):
+ def __init__(self, c_in, c_hidden, no_heads, starting=True, inf=1e9):
+ """
+ Args:
+ c_in:
+ Input channel dimension
+ c_hidden:
+ Overall hidden channel dimension (not per-head)
+ no_heads:
+ Number of attention heads
+ """
+ super(TriangleAttention, self).__init__()
+
+ self.c_in = c_in
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.starting = starting
+ self.inf = inf
+
+ self.layer_norm = LayerNorm(self.c_in)
+
+ self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
+
+ self.mha = Attention(
+ self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
+ )
+
+ @torch.jit.ignore
+ def _chunk(
+ self,
+ x: torch.Tensor,
+ biases: List[torch.Tensor],
+ chunk_size: int,
+ use_memory_efficient_kernel: bool = False,
+ use_deepspeed_evo_attention: bool = False,
+ use_lma: bool = False,
+ inplace_safe: bool = False,
+ ) -> torch.Tensor:
+ "triangle! triangle!"
+ mha_inputs = {
+ "q_x": x,
+ "kv_x": x,
+ "biases": biases,
+ }
+
+ return chunk_layer(
+ partial(
+ self.mha,
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+ use_lma=use_lma,
+ ),
+ mha_inputs,
+ chunk_size=chunk_size,
+ no_batch_dims=len(x.shape[:-2]),
+ _out=x if inplace_safe else None,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ chunk_size: Optional[int] = None,
+ use_memory_efficient_kernel: bool = False,
+ use_deepspeed_evo_attention: bool = False,
+ use_lma: bool = False,
+ inplace_safe: bool = False,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x:
+ [*, I, J, C_in] input tensor (e.g. the pair representation)
+ Returns:
+ [*, I, J, C_in] output tensor
+ """
+ if mask is None:
+ # [*, I, J]
+ mask = x.new_ones(
+ x.shape[:-1],
+ )
+
+ if not self.starting:
+ x = x.transpose(-2, -3)
+ mask = mask.transpose(-1, -2)
+
+ # [*, I, J, C_in]
+ x = self.layer_norm(x)
+
+ # [*, I, 1, 1, J]
+ mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
+
+ # [*, H, I, J]
+ triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
+
+ # [*, 1, H, I, J]
+ triangle_bias = triangle_bias.unsqueeze(-4)
+
+ biases = [mask_bias, triangle_bias]
+
+ if chunk_size is not None:
+ x = self._chunk(
+ x,
+ biases,
+ chunk_size,
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+ use_lma=use_lma,
+ inplace_safe=inplace_safe,
+ )
+ else:
+ x = self.mha(
+ q_x=x,
+ kv_x=x,
+ biases=biases,
+ use_memory_efficient_kernel=use_memory_efficient_kernel,
+ use_deepspeed_evo_attention=use_deepspeed_evo_attention,
+ use_lma=use_lma,
+ )
+
+ if not self.starting:
+ x = x.transpose(-2, -3)
+
+ return x
+
+
+# Implements Algorithm 13
+TriangleAttentionStartingNode = TriangleAttention
+
+
+class TriangleAttentionEndingNode(TriangleAttention):
+ """Implement Algorithm 14."""
+
+ __init__ = partialmethod(TriangleAttention.__init__, starting=False)
diff --git a/src/boltz/model/layers/triangular_attention/primitives.py b/src/boltz/model/layers/triangular_attention/primitives.py
new file mode 100644
index 0000000..319cc58
--- /dev/null
+++ b/src/boltz/model/layers/triangular_attention/primitives.py
@@ -0,0 +1,701 @@
+# Copyright 2021 AlQuraishi Laboratory
+# Copyright 2021 DeepMind Technologies Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import math
+from typing import Optional, Callable, List, Tuple
+
+from boltz.model.layers.triangular_attention.utils import (
+ permute_final_dims,
+ flatten_final_dims,
+ is_fp16_enabled,
+)
+from boltz.model.layers import initialize
+
+deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+ds4s_is_installed = (
+ deepspeed_is_installed
+ and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
+)
+if deepspeed_is_installed:
+ import deepspeed
+
+if ds4s_is_installed:
+ from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
+
+fa_is_installed = importlib.util.find_spec("flash_attn") is not None
+if fa_is_installed:
+ from flash_attn.bert_padding import unpad_input
+ from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func
+
+
+import torch
+import torch.nn as nn
+
+DEFAULT_LMA_Q_CHUNK_SIZE = 1024
+DEFAULT_LMA_KV_CHUNK_SIZE = 4096
+
+
+class Linear(nn.Linear):
+ """
+ A Linear layer with built-in nonstandard initializations. Called just
+ like torch.nn.Linear.
+
+ Implements the initializers in 1.11.4, plus some additional ones found
+ in the code.
+ """
+
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ bias: bool = True,
+ init: str = "default",
+ init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
+ precision=None,
+ ):
+ """
+ Args:
+ in_dim:
+ The final dimension of inputs to the layer
+ out_dim:
+ The final dimension of layer outputs
+ bias:
+ Whether to learn an additive bias. True by default
+ init:
+ The initializer to use. Choose from:
+
+ "default": LeCun fan-in truncated normal initialization
+ "relu": He initialization w/ truncated normal distribution
+ "glorot": Fan-average Glorot uniform initialization
+ "gating": Weights=0, Bias=1
+ "normal": Normal initialization with std=1/sqrt(fan_in)
+ "final": Weights=0, Bias=0
+
+ Overridden by init_fn if the latter is not None.
+ init_fn:
+ A custom initializer taking weight and bias as inputs.
+ Overrides init if not None.
+ """
+ super(Linear, self).__init__(in_dim, out_dim, bias=bias)
+
+ if bias:
+ with torch.no_grad():
+ self.bias.fill_(0)
+
+ with torch.no_grad():
+ if init_fn is not None:
+ init_fn(self.weight, self.bias)
+ else:
+ if init == "default":
+ initialize.lecun_normal_init_(self.weight)
+ elif init == "relu":
+ initialize.he_normal_init_(self.weight)
+ elif init == "glorot":
+ initialize.glorot_uniform_init_(self.weight)
+ elif init == "gating":
+ initialize.gating_init_(self.weight)
+ if bias:
+ self.bias.fill_(1.0)
+ elif init == "normal":
+ initialize.normal_init_(self.weight)
+ elif init == "final":
+ initialize.final_init_(self.weight)
+ else:
+ raise ValueError("Invalid init string.")
+
+ self.precision = precision
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ d = input.dtype
+ deepspeed_is_initialized = (
+ deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
+ )
+ if self.precision is not None:
+ with torch.autocast("cuda", enabled=False):
+ bias = (
+ self.bias.to(dtype=self.precision)
+ if self.bias is not None
+ else None
+ )
+ return nn.functional.linear(
+ input.to(dtype=self.precision),
+ self.weight.to(dtype=self.precision),
+ bias,
+ ).to(dtype=d)
+
+ if d is torch.bfloat16 and not deepspeed_is_initialized:
+ with torch.autocast("cuda", enabled=False):
+ bias = self.bias.to(dtype=d) if self.bias is not None else None
+ return nn.functional.linear(input, self.weight.to(dtype=d), bias)
+
+ return nn.functional.linear(input, self.weight, self.bias)
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, c_in, eps=1e-5):
+ super(LayerNorm, self).__init__()
+
+ self.c_in = (c_in,)
+ self.eps = eps
+
+ self.weight = nn.Parameter(torch.ones(c_in))
+ self.bias = nn.Parameter(torch.zeros(c_in))
+
+ def forward(self, x):
+ d = x.dtype
+ deepspeed_is_initialized = (
+ deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
+ )
+ if d is torch.bfloat16 and not deepspeed_is_initialized:
+ with torch.autocast("cuda", enabled=False):
+ out = nn.functional.layer_norm(
+ x,
+ self.c_in,
+ self.weight.to(dtype=d),
+ self.bias.to(dtype=d),
+ self.eps,
+ )
+ else:
+ out = nn.functional.layer_norm(
+ x,
+ self.c_in,
+ self.weight,
+ self.bias,
+ self.eps,
+ )
+
+ return out
+
+
+@torch.jit.ignore
+def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
+ """
+ Softmax, but without automatic casting to fp32 when the input is of
+ type bfloat16
+ """
+ d = t.dtype
+ deepspeed_is_initialized = (
+ deepspeed_is_installed and deepspeed.comm.comm.is_initialized()
+ )
+ if d is torch.bfloat16 and not deepspeed_is_initialized:
+ with torch.autocast("cuda", enabled=False):
+ s = torch.nn.functional.softmax(t, dim=dim)
+ else:
+ s = torch.nn.functional.softmax(t, dim=dim)
+
+ return s
+
+
+# @torch.jit.script
+def _attention(
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ biases: List[torch.Tensor],
+) -> torch.Tensor:
+ # [*, H, C_hidden, K]
+ key = permute_final_dims(key, (1, 0))
+
+ # [*, H, Q, K]
+ a = torch.matmul(query, key)
+
+ for b in biases:
+ a += b
+
+ a = softmax_no_cast(a, -1)
+
+ # [*, H, Q, C_hidden]
+ a = torch.matmul(a, value)
+
+ return a
+
+
+def get_checkpoint_fn():
+ deepspeed_is_configured = (
+ deepspeed_is_installed and deepspeed.checkpointing.is_configured()
+ )
+ if deepspeed_is_configured:
+ checkpoint = deepspeed.checkpointing.checkpoint
+ else:
+ checkpoint = torch.utils.checkpoint.checkpoint
+
+ return checkpoint
+
+
+@torch.jit.ignore
+def _attention_chunked_trainable(
+ query,
+ key,
+ value,
+ biases,
+ chunk_size,
+ chunk_dim,
+ checkpoint,
+):
+ if checkpoint and len(biases) > 2:
+ raise ValueError("Checkpointed version permits only permits two bias terms")
+
+ def _checkpointable_attention(q, k, v, b1, b2):
+ bs = [b for b in [b1, b2] if b is not None]
+ a = _attention(q, k, v, bs)
+ return a
+
+ o_chunks = []
+ checkpoint_fn = get_checkpoint_fn()
+ count = query.shape[chunk_dim]
+ for start in range(0, count, chunk_size):
+ end = start + chunk_size
+ idx = [slice(None)] * len(query.shape)
+ idx[chunk_dim] = slice(start, end)
+ idx_tup = tuple(idx)
+ q_chunk = query[idx_tup]
+ k_chunk = key[idx_tup]
+ v_chunk = value[idx_tup]
+
+ def _slice_bias(b):
+ idx[chunk_dim] = (
+ slice(start, end) if b.shape[chunk_dim] != 1 else slice(None)
+ )
+ return b[tuple(idx)]
+
+ if checkpoint:
+ bias_1_chunk, bias_2_chunk = [
+ _slice_bias(b) if b is not None else None
+ for b in (biases + [None, None])[:2]
+ ]
+
+ o_chunk = checkpoint_fn(
+ _checkpointable_attention,
+ q_chunk,
+ k_chunk,
+ v_chunk,
+ bias_1_chunk,
+ bias_2_chunk,
+ )
+ else:
+ bias_chunks = [_slice_bias(b) for b in biases]
+
+ o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
+
+ o_chunk = o_chunk.transpose(-2, -3)
+ o_chunks.append(o_chunk)
+
+ o = torch.cat(o_chunks, dim=chunk_dim)
+ return o
+
+
+def attention_core(q, k, v, param):
+ pass
+
+
+class Attention(nn.Module):
+ """
+ Standard multi-head attention using AlphaFold's default layer
+ initialization. Allows multiple bias vectors.
+ """
+
+ def __init__(
+ self,
+ c_q: int,
+ c_k: int,
+ c_v: int,
+ c_hidden: int,
+ no_heads: int,
+ gating: bool = True,
+ ):
+ """
+ Args:
+ c_q:
+ Input dimension of query data
+ c_k:
+ Input dimension of key data
+ c_v:
+ Input dimension of value data
+ c_hidden:
+ Per-head hidden dimension
+ no_heads:
+ Number of attention heads
+ gating:
+ Whether the output should be gated using query data
+ """
+ super(Attention, self).__init__()
+
+ self.c_q = c_q
+ self.c_k = c_k
+ self.c_v = c_v
+ self.c_hidden = c_hidden
+ self.no_heads = no_heads
+ self.gating = gating
+
+ # DISCREPANCY: c_hidden is not the per-head channel dimension, as
+ # stated in the supplement, but the overall channel dimension.
+
+ self.linear_q = Linear(
+ self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot"
+ )
+ self.linear_k = Linear(
+ self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot"
+ )
+ self.linear_v = Linear(
+ self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot"
+ )
+ self.linear_o = Linear(
+ self.c_hidden * self.no_heads, self.c_q, bias=False, init="final"
+ )
+
+ self.linear_g = None
+ if self.gating:
+ self.linear_g = Linear(
+ self.c_q, self.c_hidden * self.no_heads, bias=False, init="gating"
+ )
+
+ self.sigmoid = nn.Sigmoid()
+
+ def _prep_qkv(
+ self, q_x: torch.Tensor, kv_x: torch.Tensor, apply_scale: bool = True
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ # [*, Q/K/V, H * C_hidden]
+ q = self.linear_q(q_x)
+ k = self.linear_k(kv_x)
+ v = self.linear_v(kv_x)
+
+ # [*, Q/K, H, C_hidden]
+ q = q.view(q.shape[:-1] + (self.no_heads, -1))
+ k = k.view(k.shape[:-1] + (self.no_heads, -1))
+ v = v.view(v.shape[:-1] + (self.no_heads, -1))
+
+ # [*, H, Q/K, C_hidden]
+ q = q.transpose(-2, -3)
+ k = k.transpose(-2, -3)
+ v = v.transpose(-2, -3)
+
+ if apply_scale:
+ q /= math.sqrt(self.c_hidden)
+
+ return q, k, v
+
+ def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
+ if self.linear_g is not None:
+ g = self.sigmoid(self.linear_g(q_x))
+
+ # [*, Q, H, C_hidden]
+ g = g.view(g.shape[:-1] + (self.no_heads, -1))
+ o = o * g
+
+ # [*, Q, H * C_hidden]
+ o = flatten_final_dims(o, 2)
+
+ # [*, Q, C_q]
+ o = self.linear_o(o)
+
+ return o
+
+ def forward(
+ self,
+ q_x: torch.Tensor,
+ kv_x: torch.Tensor,
+ biases: Optional[List[torch.Tensor]] = None,
+ use_memory_efficient_kernel: bool = False,
+ use_deepspeed_evo_attention: bool = False,
+ use_lma: bool = False,
+ lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
+ lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
+ use_flash: bool = False,
+ flash_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ q_x:
+ [*, Q, C_q] query data
+ kv_x:
+ [*, K, C_k] key data
+ biases:
+ List of biases that broadcast to [*, H, Q, K]
+ use_memory_efficient_kernel:
+ Whether to use a custom memory-efficient attention kernel.
+ This should be the default choice for most. If none of the
+ "use_<...>" flags are True, a stock PyTorch implementation
+ is used instead
+ use_deepspeed_evo_attention:
+ Whether to use DeepSpeed memory-efficient attention kernel.
+ If none of the "use_<...>" flags are True, a stock PyTorch
+ implementation is used instead
+ use_lma:
+ Whether to use low-memory attention (Staats & Rabe 2021). If
+ none of the "use_<...>" flags are True, a stock PyTorch
+ implementation is used instead
+ lma_q_chunk_size:
+ Query chunk size (for LMA)
+ lma_kv_chunk_size:
+ Key/Value chunk size (for LMA)
+ Returns
+ [*, Q, C_q] attention update
+ """
+ if use_lma and (lma_q_chunk_size is None or lma_kv_chunk_size is None):
+ raise ValueError(
+ "If use_lma is specified, lma_q_chunk_size and "
+ "lma_kv_chunk_size must be provided"
+ )
+
+ if use_flash and biases is not None:
+ raise ValueError(
+ "use_flash is incompatible with the bias option. For masking, "
+ "use flash_mask instead"
+ )
+
+ attn_options = [
+ use_memory_efficient_kernel,
+ use_deepspeed_evo_attention,
+ use_lma,
+ use_flash,
+ ]
+ if sum(attn_options) > 1:
+ raise ValueError("Choose at most one alternative attention algorithm")
+
+ if biases is None:
+ biases = []
+
+ # DeepSpeed attention kernel applies scaling internally
+ q, k, v = self._prep_qkv(q_x, kv_x, apply_scale=not use_deepspeed_evo_attention)
+
+ if is_fp16_enabled():
+ use_memory_efficient_kernel = False
+
+ if use_memory_efficient_kernel:
+ if len(biases) > 2:
+ raise ValueError(
+ "If use_memory_efficient_kernel is True, you may only "
+ "provide up to two bias terms"
+ )
+ o = attention_core(q, k, v, *((biases + [None] * 2)[:2]))
+ o = o.transpose(-2, -3)
+ elif use_deepspeed_evo_attention:
+ if len(biases) > 2:
+ raise ValueError(
+ "If use_deepspeed_evo_attention is True, you may only "
+ "provide up to two bias terms"
+ )
+ o = _deepspeed_evo_attn(q, k, v, biases)
+ elif use_lma:
+ biases = [
+ b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
+ for b in biases
+ ]
+ o = _lma(q, k, v, biases, lma_q_chunk_size, lma_kv_chunk_size)
+ o = o.transpose(-2, -3)
+ elif use_flash:
+ o = _flash_attn(q, k, v, flash_mask)
+ else:
+ o = _attention(q, k, v, biases)
+ o = o.transpose(-2, -3)
+
+ o = self._wrap_up(o, q_x)
+
+ return o
+
+
+@torch.jit.ignore
+def _deepspeed_evo_attn(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ biases: List[torch.Tensor],
+):
+ """ ""
+ Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
+
+ Args:
+ q:
+ [*, H, Q, C_hidden] query data
+ k:
+ [*, H, K, C_hidden] key data
+ v:
+ [*, H, V, C_hidden] value data
+ biases:
+ List of biases that broadcast to [*, H, Q, K]
+ """
+
+ if not ds4s_is_installed:
+ raise ValueError(
+ "_deepspeed_evo_attn requires that DeepSpeed be installed "
+ "and that the deepspeed.ops.deepspeed4science package exists"
+ )
+
+ def reshape_dims(x):
+ no_batch_dims = len(x.shape[:-3])
+ if no_batch_dims < 2:
+ return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
+ if no_batch_dims > 2:
+ return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
+ return x
+
+ # [*, Q/K, H, C_hidden]
+ q = q.transpose(-2, -3)
+ k = k.transpose(-2, -3)
+ v = v.transpose(-2, -3)
+
+ # Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
+ # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
+ orig_shape = q.shape
+ if len(orig_shape[:-3]) != 2:
+ q = reshape_dims(q)
+ k = reshape_dims(k)
+ v = reshape_dims(v)
+ biases = [reshape_dims(b) for b in biases]
+
+ # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
+ # Cast to bf16 so kernel can be used during inference
+ orig_dtype = q.dtype
+ if orig_dtype not in [torch.bfloat16, torch.float16]:
+ o = DS4Sci_EvoformerAttention(
+ q.to(dtype=torch.bfloat16),
+ k.to(dtype=torch.bfloat16),
+ v.to(dtype=torch.bfloat16),
+ [b.to(dtype=torch.bfloat16) for b in biases],
+ )
+
+ o = o.to(dtype=orig_dtype)
+ else:
+ o = DS4Sci_EvoformerAttention(q, k, v, biases)
+
+ o = o.reshape(orig_shape)
+ return o
+
+
+def _lma(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ biases: List[torch.Tensor],
+ q_chunk_size: int,
+ kv_chunk_size: int,
+):
+ no_q, no_kv = q.shape[-2], k.shape[-2]
+
+ # [*, H, Q, C_hidden]
+ o = q.new_zeros(q.shape)
+ for q_s in range(0, no_q, q_chunk_size):
+ q_chunk = q[..., q_s : q_s + q_chunk_size, :]
+ large_bias_chunks = [b[..., q_s : q_s + q_chunk_size, :] for b in biases]
+
+ maxes = []
+ weights = []
+ values = []
+ for kv_s in range(0, no_kv, kv_chunk_size):
+ k_chunk = k[..., kv_s : kv_s + kv_chunk_size, :]
+ v_chunk = v[..., kv_s : kv_s + kv_chunk_size, :]
+ small_bias_chunks = [
+ b[..., kv_s : kv_s + kv_chunk_size] for b in large_bias_chunks
+ ]
+
+ a = torch.einsum(
+ "...hqd,...hkd->...hqk",
+ q_chunk,
+ k_chunk,
+ )
+
+ for b in small_bias_chunks:
+ a += b
+
+ max_a = torch.max(a, dim=-1, keepdim=True)[0]
+ exp_a = torch.exp(a - max_a)
+ exp_v = torch.einsum("...hvf,...hqv->...hqf", v_chunk, exp_a)
+
+ maxes.append(max_a.detach().squeeze(-1))
+ weights.append(torch.sum(exp_a, dim=-1))
+ values.append(exp_v)
+
+ chunk_max = torch.stack(maxes, dim=-3)
+ chunk_weights = torch.stack(weights, dim=-3)
+ chunk_values = torch.stack(values, dim=-4)
+
+ global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
+ max_diffs = torch.exp(chunk_max - global_max)
+ chunk_values = chunk_values * max_diffs.unsqueeze(-1)
+ chunk_weights = chunk_weights * max_diffs
+
+ all_values = torch.sum(chunk_values, dim=-4)
+ all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
+
+ q_chunk_out = all_values / all_weights
+
+ o[..., q_s : q_s + q_chunk_size, :] = q_chunk_out
+
+ return o
+
+
+@torch.jit.ignore
+def _flash_attn(q, k, v, kv_mask):
+ if not fa_is_installed:
+ raise ValueError("_flash_attn requires that FlashAttention be installed")
+
+ batch_dims = q.shape[:-3]
+ no_heads, n, c = q.shape[-3:]
+ dtype = q.dtype
+
+ q = q.half()
+ k = k.half()
+ v = v.half()
+ kv_mask = kv_mask.half()
+
+ # [*, B, N, H, C]
+ q = q.transpose(-2, -3)
+ k = k.transpose(-2, -3)
+ v = v.transpose(-2, -3)
+
+ # [B_flat, N, H, C]
+ q = q.reshape(-1, *q.shape[-3:])
+ k = k.reshape(-1, *k.shape[-3:])
+ v = v.reshape(-1, *v.shape[-3:])
+
+ # Flattened batch size
+ batch_size = q.shape[0]
+
+ # [B_flat * N, H, C]
+ q = q.reshape(-1, *q.shape[-2:])
+
+ q_max_s = n
+ q_cu_seqlens = torch.arange(
+ 0, (batch_size + 1) * n, step=n, dtype=torch.int32, device=q.device
+ )
+
+ # [B_flat, N, 2, H, C]
+ kv = torch.stack([k, v], dim=-3)
+ kv_shape = kv.shape
+
+ # [B_flat, N, 2 * H * C]
+ kv = kv.reshape(*kv.shape[:-3], -1)
+
+ kv_unpad, _, kv_cu_seqlens, kv_max_s = unpad_input(kv, kv_mask)
+ kv_unpad = kv_unpad.reshape(-1, *kv_shape[-3:])
+
+ out = flash_attn_unpadded_kvpacked_func(
+ q,
+ kv_unpad,
+ q_cu_seqlens,
+ kv_cu_seqlens,
+ q_max_s,
+ kv_max_s,
+ dropout_p=0.0,
+ softmax_scale=1.0, # q has been scaled already
+ )
+
+ # [*, B, N, H, C]
+ out = out.reshape(*batch_dims, n, no_heads, c)
+
+ out = out.to(dtype=dtype)
+
+ return out
diff --git a/src/boltz/model/layers/triangular_attention/utils.py b/src/boltz/model/layers/triangular_attention/utils.py
new file mode 100644
index 0000000..89899da
--- /dev/null
+++ b/src/boltz/model/layers/triangular_attention/utils.py
@@ -0,0 +1,380 @@
+# Copyright 2021 AlQuraishi Laboratory
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
+
+import torch
+
+
+def add(m1, m2, inplace):
+ # The first operation in a checkpoint can't be in-place, but it's
+ # nice to have in-place addition during inference. Thus...
+ if not inplace:
+ m1 = m1 + m2
+ else:
+ m1 += m2
+
+ return m1
+
+
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
+ zero_index = -1 * len(inds)
+ first_inds = list(range(len(tensor.shape[:zero_index])))
+ return tensor.permute(first_inds + [zero_index + i for i in inds])
+
+
+def is_fp16_enabled():
+ # Autocast world
+ fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
+ fp16_enabled = fp16_enabled and torch.is_autocast_enabled()
+
+ return fp16_enabled
+
+
+# With tree_map, a poor man's JAX tree_map
+def dict_map(fn, dic, leaf_type):
+ new_dict = {}
+ for k, v in dic.items():
+ if type(v) is dict:
+ new_dict[k] = dict_map(fn, v, leaf_type)
+ else:
+ new_dict[k] = tree_map(fn, v, leaf_type)
+
+ return new_dict
+
+
+def tree_map(fn, tree, leaf_type):
+ if isinstance(tree, dict):
+ return dict_map(fn, tree, leaf_type)
+ elif isinstance(tree, list):
+ return [tree_map(fn, x, leaf_type) for x in tree]
+ elif isinstance(tree, tuple):
+ return tuple([tree_map(fn, x, leaf_type) for x in tree])
+ elif isinstance(tree, leaf_type):
+ return fn(tree)
+ else:
+ raise ValueError(f"Tree of type {type(tree)} not supported")
+
+
+tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
+
+
+def flatten_final_dims(t: torch.Tensor, no_dims: int):
+ return t.reshape(t.shape[:-no_dims] + (-1,))
+
+
+def _fetch_dims(tree):
+ shapes = []
+ tree_type = type(tree)
+ if tree_type is dict:
+ for v in tree.values():
+ shapes.extend(_fetch_dims(v))
+ elif tree_type is list or tree_type is tuple:
+ for t in tree:
+ shapes.extend(_fetch_dims(t))
+ elif tree_type is torch.Tensor:
+ shapes.append(tree.shape)
+ else:
+ raise ValueError("Not supported")
+
+ return shapes
+
+
+@torch.jit.ignore
+def _flat_idx_to_idx(
+ flat_idx: int,
+ dims: Tuple[int],
+) -> Tuple[int]:
+ idx = []
+ for d in reversed(dims):
+ idx.append(flat_idx % d)
+ flat_idx = flat_idx // d
+
+ return tuple(reversed(idx))
+
+
+@torch.jit.ignore
+def _get_minimal_slice_set(
+ start: Sequence[int],
+ end: Sequence[int],
+ dims: int,
+ start_edges: Optional[Sequence[bool]] = None,
+ end_edges: Optional[Sequence[bool]] = None,
+) -> Sequence[Tuple[int]]:
+ """
+ Produces an ordered sequence of tensor slices that, when used in
+ sequence on a tensor with shape dims, yields tensors that contain every
+ leaf in the contiguous range [start, end]. Care is taken to yield a
+ short sequence of slices, and perhaps even the shortest possible (I'm
+ pretty sure it's the latter).
+
+ end is INCLUSIVE.
+ """
+
+ # start_edges and end_edges both indicate whether, starting from any given
+ # dimension, the start/end index is at the top/bottom edge of the
+ # corresponding tensor, modeled as a tree
+ def reduce_edge_list(l):
+ tally = 1
+ for i in range(len(l)):
+ reversed_idx = -1 * (i + 1)
+ l[reversed_idx] *= tally
+ tally = l[reversed_idx]
+
+ if start_edges is None:
+ start_edges = [s == 0 for s in start]
+ reduce_edge_list(start_edges)
+ if end_edges is None:
+ end_edges = [e == (d - 1) for e, d in zip(end, dims)]
+ reduce_edge_list(end_edges)
+
+ # Base cases. Either start/end are empty and we're done, or the final,
+ # one-dimensional tensor can be simply sliced
+ if len(start) == 0:
+ return [tuple()]
+ elif len(start) == 1:
+ return [(slice(start[0], end[0] + 1),)]
+
+ slices = []
+ path = []
+
+ # Dimensions common to start and end can be selected directly
+ for s, e in zip(start, end):
+ if s == e:
+ path.append(slice(s, s + 1))
+ else:
+ break
+
+ path = tuple(path)
+ divergence_idx = len(path)
+
+ # start == end, and we're done
+ if divergence_idx == len(dims):
+ return [tuple(path)]
+
+ def upper():
+ sdi = start[divergence_idx]
+ return [
+ path + (slice(sdi, sdi + 1),) + s
+ for s in _get_minimal_slice_set(
+ start[divergence_idx + 1 :],
+ [d - 1 for d in dims[divergence_idx + 1 :]],
+ dims[divergence_idx + 1 :],
+ start_edges=start_edges[divergence_idx + 1 :],
+ end_edges=[1 for _ in end_edges[divergence_idx + 1 :]],
+ )
+ ]
+
+ def lower():
+ edi = end[divergence_idx]
+ return [
+ path + (slice(edi, edi + 1),) + s
+ for s in _get_minimal_slice_set(
+ [0 for _ in start[divergence_idx + 1 :]],
+ end[divergence_idx + 1 :],
+ dims[divergence_idx + 1 :],
+ start_edges=[1 for _ in start_edges[divergence_idx + 1 :]],
+ end_edges=end_edges[divergence_idx + 1 :],
+ )
+ ]
+
+ # If both start and end are at the edges of the subtree rooted at
+ # divergence_idx, we can just select the whole subtree at once
+ if start_edges[divergence_idx] and end_edges[divergence_idx]:
+ slices.append(path + (slice(start[divergence_idx], end[divergence_idx] + 1),))
+ # If just start is at the edge, we can grab almost all of the subtree,
+ # treating only the ragged bottom edge as an edge case
+ elif start_edges[divergence_idx]:
+ slices.append(path + (slice(start[divergence_idx], end[divergence_idx]),))
+ slices.extend(lower())
+ # Analogous to the previous case, but the top is ragged this time
+ elif end_edges[divergence_idx]:
+ slices.extend(upper())
+ slices.append(
+ path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
+ )
+ # If both sides of the range are ragged, we need to handle both sides
+ # separately. If there's contiguous meat in between them, we can index it
+ # in one big chunk
+ else:
+ slices.extend(upper())
+ middle_ground = end[divergence_idx] - start[divergence_idx]
+ if middle_ground > 1:
+ slices.append(
+ path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
+ )
+ slices.extend(lower())
+
+ return [tuple(s) for s in slices]
+
+
+@torch.jit.ignore
+def _chunk_slice(
+ t: torch.Tensor,
+ flat_start: int,
+ flat_end: int,
+ no_batch_dims: int,
+) -> torch.Tensor:
+ """
+ Equivalent to
+
+ t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
+
+ but without the need for the initial reshape call, which can be
+ memory-intensive in certain situations. The only reshape operations
+ in this function are performed on sub-tensors that scale with
+ (flat_end - flat_start), the chunk size.
+ """
+
+ batch_dims = t.shape[:no_batch_dims]
+ start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
+ # _get_minimal_slice_set is inclusive
+ end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
+
+ # Get an ordered list of slices to perform
+ slices = _get_minimal_slice_set(
+ start_idx,
+ end_idx,
+ batch_dims,
+ )
+
+ sliced_tensors = [t[s] for s in slices]
+
+ return torch.cat([s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors])
+
+
+def chunk_layer(
+ layer: Callable,
+ inputs: Dict[str, Any],
+ chunk_size: int,
+ no_batch_dims: int,
+ low_mem: bool = False,
+ _out: Any = None,
+ _add_into_out: bool = False,
+) -> Any:
+ """
+ Implements the "chunking" procedure described in section 1.11.8.
+
+ Layer outputs and inputs are assumed to be simple "pytrees,"
+ consisting only of (arbitrarily nested) lists, tuples, and dicts with
+ torch.Tensor leaves.
+
+ Args:
+ layer:
+ The layer to be applied chunk-wise
+ inputs:
+ A (non-nested) dictionary of keyworded inputs. All leaves must
+ be tensors and must share the same batch dimensions.
+ chunk_size:
+ The number of sub-batches per chunk. If multiple batch
+ dimensions are specified, a "sub-batch" is defined as a single
+ indexing of all batch dimensions simultaneously (s.t. the
+ number of sub-batches is the product of the batch dimensions).
+ no_batch_dims:
+ How many of the initial dimensions of each input tensor can
+ be considered batch dimensions.
+ low_mem:
+ Avoids flattening potentially large input tensors. Unnecessary
+ in most cases, and is ever so slightly slower than the default
+ setting.
+ Returns:
+ The reassembled output of the layer on the inputs.
+ """
+ if not (len(inputs) > 0):
+ raise ValueError("Must provide at least one input")
+
+ initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
+ orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
+
+ def _prep_inputs(t):
+ if not low_mem:
+ if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
+ t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+ t = t.reshape(-1, *t.shape[no_batch_dims:])
+ else:
+ t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
+ return t
+
+ prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
+ prepped_outputs = None
+ if _out is not None:
+ reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
+ prepped_outputs = tensor_tree_map(reshape_fn, _out)
+
+ flat_batch_dim = 1
+ for d in orig_batch_dims:
+ flat_batch_dim *= d
+
+ no_chunks = flat_batch_dim // chunk_size + (flat_batch_dim % chunk_size != 0)
+
+ i = 0
+ out = prepped_outputs
+ for _ in range(no_chunks):
+ # Chunk the input
+ if not low_mem:
+ select_chunk = lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
+ else:
+ select_chunk = partial(
+ _chunk_slice,
+ flat_start=i,
+ flat_end=min(flat_batch_dim, i + chunk_size),
+ no_batch_dims=len(orig_batch_dims),
+ )
+
+ chunks = tensor_tree_map(select_chunk, prepped_inputs)
+
+ # Run the layer on the chunk
+ output_chunk = layer(**chunks)
+
+ # Allocate space for the output
+ if out is None:
+ allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
+ out = tensor_tree_map(allocate, output_chunk)
+
+ # Put the chunk in its pre-allocated space
+ out_type = type(output_chunk)
+ if out_type is dict:
+
+ def assign(d1, d2):
+ for k, v in d1.items():
+ if type(v) is dict:
+ assign(v, d2[k])
+ else:
+ if _add_into_out:
+ v[i : i + chunk_size] += d2[k]
+ else:
+ v[i : i + chunk_size] = d2[k]
+
+ assign(out, output_chunk)
+ elif out_type is tuple:
+ for x1, x2 in zip(out, output_chunk):
+ if _add_into_out:
+ x1[i : i + chunk_size] += x2
+ else:
+ x1[i : i + chunk_size] = x2
+ elif out_type is torch.Tensor:
+ if _add_into_out:
+ out[i : i + chunk_size] += output_chunk
+ else:
+ out[i : i + chunk_size] = output_chunk
+ else:
+ raise ValueError("Not supported")
+
+ i += chunk_size
+
+ reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
+ out = tensor_tree_map(reshape, out)
+
+ return out
diff --git a/src/boltz/model/layers/triangular_mult.py b/src/boltz/model/layers/triangular_mult.py
new file mode 100644
index 0000000..ba06662
--- /dev/null
+++ b/src/boltz/model/layers/triangular_mult.py
@@ -0,0 +1,144 @@
+import torch
+from torch import Tensor, nn
+
+from boltz.model.layers import initialize as init
+
+
+class TriangleMultiplicationOutgoing(nn.Module):
+ """TriangleMultiplicationOutgoing."""
+
+ def __init__(self, dim: int = 128) -> None:
+ """Initialize the TriangularUpdate module.
+
+ Parameters
+ ----------
+ dim: int
+ The dimension of the input, default 128
+
+ """
+ super().__init__()
+
+ self.norm_in = nn.LayerNorm(dim, eps=1e-5)
+ self.p_in = nn.Linear(dim, 2 * dim, bias=False)
+ self.g_in = nn.Linear(dim, 2 * dim, bias=False)
+
+ self.norm_out = nn.LayerNorm(dim)
+ self.p_out = nn.Linear(dim, dim, bias=False)
+ self.g_out = nn.Linear(dim, dim, bias=False)
+
+ init.bias_init_one_(self.norm_in.weight)
+ init.bias_init_zero_(self.norm_in.bias)
+
+ init.lecun_normal_init_(self.p_in.weight)
+ init.gating_init_(self.g_in.weight)
+
+ init.bias_init_one_(self.norm_out.weight)
+ init.bias_init_zero_(self.norm_out.bias)
+
+ init.final_init_(self.p_out.weight)
+ init.gating_init_(self.g_out.weight)
+
+ def forward(self, x: Tensor, mask: Tensor) -> Tensor:
+ """Perform a forward pass.
+
+ Parameters
+ ----------
+ x: torch.Tensor
+ The input data of shape (B, N, N, D)
+ mask: torch.Tensor
+ The input mask of shape (B, N, N)
+
+ Returns
+ -------
+ x: torch.Tensor
+ The output data of shape (B, N, N, D)
+
+ """
+ # Input gating: D -> D
+ x = self.norm_in(x)
+ x_in = x
+ x = self.p_in(x) * self.g_in(x).sigmoid()
+
+ # Apply mask
+ x = x * mask.unsqueeze(-1)
+
+ # Split input and cast to float
+ a, b = torch.chunk(x.float(), 2, dim=-1)
+
+ # Triangular projection
+ x = torch.einsum("bikd,bjkd->bijd", a, b)
+
+ # Output gating
+ x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
+
+ return x
+
+
+class TriangleMultiplicationIncoming(nn.Module):
+ """TriangleMultiplicationIncoming."""
+
+ def __init__(self, dim: int = 128) -> None:
+ """Initialize the TriangularUpdate module.
+
+ Parameters
+ ----------
+ dim: int
+ The dimension of the input, default 128
+
+ """
+ super().__init__()
+
+ self.norm_in = nn.LayerNorm(dim, eps=1e-5)
+ self.p_in = nn.Linear(dim, 2 * dim, bias=False)
+ self.g_in = nn.Linear(dim, 2 * dim, bias=False)
+
+ self.norm_out = nn.LayerNorm(dim)
+ self.p_out = nn.Linear(dim, dim, bias=False)
+ self.g_out = nn.Linear(dim, dim, bias=False)
+
+ init.bias_init_one_(self.norm_in.weight)
+ init.bias_init_zero_(self.norm_in.bias)
+
+ init.lecun_normal_init_(self.p_in.weight)
+ init.gating_init_(self.g_in.weight)
+
+ init.bias_init_one_(self.norm_out.weight)
+ init.bias_init_zero_(self.norm_out.bias)
+
+ init.final_init_(self.p_out.weight)
+ init.gating_init_(self.g_out.weight)
+
+ def forward(self, x: Tensor, mask: Tensor) -> Tensor:
+ """Perform a forward pass.
+
+ Parameters
+ ----------
+ x: torch.Tensor
+ The input data of shape (B, N, N, D)
+ mask: torch.Tensor
+ The input mask of shape (B, N, N)
+
+ Returns
+ -------
+ x: torch.Tensor
+ The output data of shape (B, N, N, D)
+
+ """
+ # Input gating: D -> D
+ x = self.norm_in(x)
+ x_in = x
+ x = self.p_in(x) * self.g_in(x).sigmoid()
+
+ # Apply mask
+ x = x * mask.unsqueeze(-1)
+
+ # Split input and cast to float
+ a, b = torch.chunk(x.float(), 2, dim=-1)
+
+ # Triangular projection
+ x = torch.einsum("bkid,bkjd->bijd", a, b)
+
+ # Output gating
+ x = self.p_out(self.norm_out(x)) * self.g_out(x_in).sigmoid()
+
+ return x
diff --git a/src/boltz/model/loss/__init__.py b/src/boltz/model/loss/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/model/loss/confidence.py b/src/boltz/model/loss/confidence.py
new file mode 100644
index 0000000..7080c9d
--- /dev/null
+++ b/src/boltz/model/loss/confidence.py
@@ -0,0 +1,590 @@
+import torch
+from torch import nn
+
+from boltz.data import const
+
+
+def confidence_loss(
+ model_out,
+ feats,
+ true_coords,
+ true_coords_resolved_mask,
+ multiplicity=1,
+ alpha_pae=0.0,
+):
+ """Compute confidence loss.
+
+ Parameters
+ ----------
+ model_out: Dict[str, torch.Tensor]
+ Dictionary containing the model output
+ feats: Dict[str, torch.Tensor]
+ Dictionary containing the model input
+ true_coords: torch.Tensor
+ The atom coordinates after symmetry correction
+ true_coords_resolved_mask: torch.Tensor
+ The resolved mask after symmetry correction
+ multiplicity: int, optional
+ The diffusion batch size, by default 1
+ alpha_pae: float, optional
+ The weight of the pae loss, by default 0.0
+
+ Returns
+ -------
+ Dict[str, torch.Tensor]
+ Loss breakdown
+
+ """
+ # Compute losses
+ plddt = plddt_loss(
+ model_out["plddt_logits"],
+ model_out["sample_atom_coords"],
+ true_coords,
+ true_coords_resolved_mask,
+ feats,
+ multiplicity=multiplicity,
+ )
+ pde = pde_loss(
+ model_out["pde_logits"],
+ model_out["sample_atom_coords"],
+ true_coords,
+ true_coords_resolved_mask,
+ feats,
+ multiplicity,
+ )
+ resolved = resolved_loss(
+ model_out["resolved_logits"],
+ feats,
+ true_coords_resolved_mask,
+ multiplicity=multiplicity,
+ )
+
+ pae = 0.0
+ if alpha_pae > 0.0:
+ pae = pae_loss(
+ model_out["pae_logits"],
+ model_out["sample_atom_coords"],
+ true_coords,
+ true_coords_resolved_mask,
+ feats,
+ multiplicity,
+ )
+
+ loss = plddt + pde + resolved + alpha_pae * pae
+
+ dict_out = {
+ "loss": loss,
+ "loss_breakdown": {
+ "plddt_loss": plddt,
+ "pde_loss": pde,
+ "resolved_loss": resolved,
+ "pae_loss": pae,
+ },
+ }
+ return dict_out
+
+
+def resolved_loss(
+ pred_resolved,
+ feats,
+ true_coords_resolved_mask,
+ multiplicity=1,
+):
+ """Compute resolved loss.
+
+ Parameters
+ ----------
+ pred_resolved: torch.Tensor
+ The resolved logits
+ feats: Dict[str, torch.Tensor]
+ Dictionary containing the model input
+ true_coords_resolved_mask: torch.Tensor
+ The resolved mask after symmetry correction
+ multiplicity: int, optional
+ The diffusion batch size, by default 1
+
+ Returns
+ -------
+ torch.Tensor
+ Resolved loss
+
+ """
+
+ # extract necessary features
+ token_to_rep_atom = feats["token_to_rep_atom"]
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0).float()
+ ref_mask = torch.bmm(
+ token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
+ ).squeeze(-1)
+ pad_mask = feats["token_pad_mask"]
+ pad_mask = pad_mask.repeat_interleave(multiplicity, 0).float()
+
+ # compute loss
+ log_softmax_resolved = torch.nn.functional.log_softmax(pred_resolved, dim=-1)
+ errors = (
+ -ref_mask * log_softmax_resolved[:, :, 0]
+ - (1 - ref_mask) * log_softmax_resolved[:, :, 1]
+ )
+ loss = torch.sum(errors * pad_mask, dim=-1) / (1e-7 + torch.sum(pad_mask, dim=-1))
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def plddt_loss(
+ pred_lddt,
+ pred_atom_coords,
+ true_atom_coords,
+ true_coords_resolved_mask,
+ feats,
+ multiplicity=1,
+):
+ """Compute plddt loss.
+
+ Parameters
+ ----------
+ pred_lddt: torch.Tensor
+ The plddt logits
+ pred_atom_coords: torch.Tensor
+ The predicted atom coordinates
+ true_atom_coords: torch.Tensor
+ The atom coordinates after symmetry correction
+ true_coords_resolved_mask: torch.Tensor
+ The resolved mask after symmetry correction
+ feats: Dict[str, torch.Tensor]
+ Dictionary containing the model input
+ multiplicity: int, optional
+ The diffusion batch size, by default 1
+
+ Returns
+ -------
+ torch.Tensor
+ Plddt loss
+
+ """
+
+ # extract necessary features
+ atom_mask = true_coords_resolved_mask
+
+ R_set_to_rep_atom = feats["r_set_to_rep_atom"]
+ R_set_to_rep_atom = R_set_to_rep_atom.repeat_interleave(multiplicity, 0).float()
+
+ token_type = feats["mol_type"]
+ token_type = token_type.repeat_interleave(multiplicity, 0)
+ is_nucleotide_token = (token_type == const.chain_type_ids["DNA"]).float() + (
+ token_type == const.chain_type_ids["RNA"]
+ ).float()
+
+ B = true_atom_coords.shape[0]
+
+ atom_to_token = feats["atom_to_token"].float()
+ atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
+
+ token_to_rep_atom = feats["token_to_rep_atom"].float()
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
+
+ true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
+ pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)
+
+ # compute true lddt
+ true_d = torch.cdist(
+ true_token_coords,
+ torch.bmm(R_set_to_rep_atom, true_atom_coords),
+ )
+ pred_d = torch.cdist(
+ pred_token_coords,
+ torch.bmm(R_set_to_rep_atom, pred_atom_coords),
+ )
+
+ # compute mask
+ pair_mask = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2)
+ pair_mask = (
+ pair_mask
+ * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :]
+ )
+ pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, R_set_to_rep_atom)
+ pair_mask = torch.bmm(token_to_rep_atom, pair_mask)
+ atom_mask = torch.bmm(token_to_rep_atom, atom_mask.unsqueeze(-1).float())
+ is_nucleotide_R_element = torch.bmm(
+ R_set_to_rep_atom, torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1))
+ ).squeeze(-1)
+ cutoff = 15 + 15 * is_nucleotide_R_element.reshape(B, 1, -1).repeat(
+ 1, true_d.shape[1], 1
+ )
+
+ # compute lddt
+ target_lddt, mask_no_match = lddt_dist(
+ pred_d, true_d, pair_mask, cutoff, per_atom=True
+ )
+
+ # compute loss
+ num_bins = pred_lddt.shape[-1]
+ bin_index = torch.floor(target_lddt * num_bins).long()
+ bin_index = torch.clamp(bin_index, max=(num_bins - 1))
+ lddt_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
+ errors = -1 * torch.sum(
+ lddt_one_hot * torch.nn.functional.log_softmax(pred_lddt, dim=-1),
+ dim=-1,
+ )
+ atom_mask = atom_mask.squeeze(-1)
+ loss = torch.sum(errors * atom_mask * mask_no_match, dim=-1) / (
+ 1e-7 + torch.sum(atom_mask * mask_no_match, dim=-1)
+ )
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def pde_loss(
+ pred_pde,
+ pred_atom_coords,
+ true_atom_coords,
+ true_coords_resolved_mask,
+ feats,
+ multiplicity=1,
+ max_dist=32.0,
+):
+ """Compute pde loss.
+
+ Parameters
+ ----------
+ pred_pde: torch.Tensor
+ The pde logits
+ pred_atom_coords: torch.Tensor
+ The predicted atom coordinates
+ true_atom_coords: torch.Tensor
+ The atom coordinates after symmetry correction
+ true_coords_resolved_mask: torch.Tensor
+ The resolved mask after symmetry correction
+ feats: Dict[str, torch.Tensor]
+ Dictionary containing the model input
+ multiplicity: int, optional
+ The diffusion batch size, by default 1
+
+ Returns
+ -------
+ torch.Tensor
+ Pde loss
+
+ """
+
+ # extract necessary features
+ token_to_rep_atom = feats["token_to_rep_atom"]
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0).float()
+ token_mask = torch.bmm(
+ token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
+ ).squeeze(-1)
+ mask = token_mask.unsqueeze(-1) * token_mask.unsqueeze(-2)
+
+ # compute true pde
+ true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
+ pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)
+
+ true_d = torch.cdist(true_token_coords, true_token_coords)
+ pred_d = torch.cdist(pred_token_coords, pred_token_coords)
+ target_pde = torch.abs(true_d - pred_d)
+
+ # compute loss
+ num_bins = pred_pde.shape[-1]
+ bin_index = torch.floor(target_pde * num_bins / max_dist).long()
+ bin_index = torch.clamp(bin_index, max=(num_bins - 1))
+ pde_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
+ errors = -1 * torch.sum(
+ pde_one_hot * torch.nn.functional.log_softmax(pred_pde, dim=-1),
+ dim=-1,
+ )
+ loss = torch.sum(errors * mask, dim=(-2, -1)) / (
+ 1e-7 + torch.sum(mask, dim=(-2, -1))
+ )
+
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def pae_loss(
+ pred_pae,
+ pred_atom_coords,
+ true_atom_coords,
+ true_coords_resolved_mask,
+ feats,
+ multiplicity=1,
+ max_dist=32.0,
+):
+ """Compute pae loss.
+
+ Parameters
+ ----------
+ pred_pae: torch.Tensor
+ The pae logits
+ pred_atom_coords: torch.Tensor
+ The predicted atom coordinates
+ true_atom_coords: torch.Tensor
+ The atom coordinates after symmetry correction
+ true_coords_resolved_mask: torch.Tensor
+ The resolved mask after symmetry correction
+ feats: Dict[str, torch.Tensor]
+ Dictionary containing the model input
+ multiplicity: int, optional
+ The diffusion batch size, by default 1
+
+ Returns
+ -------
+ torch.Tensor
+ Pae loss
+
+ """
+ # Retrieve frames and resolved masks
+ frames_idx_original = feats["frames_idx"]
+ mask_frame_true = feats["frame_resolved_mask"]
+
+ # Adjust the frames for nonpolymers after symmetry correction!
+ # NOTE: frames of polymers do not change under symmetry!
+ frames_idx_true, mask_collinear_true = compute_frame_pred(
+ true_atom_coords,
+ frames_idx_original,
+ feats,
+ multiplicity,
+ resolved_mask=true_coords_resolved_mask,
+ )
+
+ frame_true_atom_a, frame_true_atom_b, frame_true_atom_c = (
+ frames_idx_true[:, :, :, 0],
+ frames_idx_true[:, :, :, 1],
+ frames_idx_true[:, :, :, 2],
+ )
+ # Compute token coords in true frames
+ B, N, _ = true_atom_coords.shape
+ true_atom_coords = true_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
+ true_coords_transformed = express_coordinate_in_frame(
+ true_atom_coords, frame_true_atom_a, frame_true_atom_b, frame_true_atom_c
+ )
+
+ # Compute pred frames and mask
+ frames_idx_pred, mask_collinear_pred = compute_frame_pred(
+ pred_atom_coords, frames_idx_original, feats, multiplicity
+ )
+ frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c = (
+ frames_idx_pred[:, :, :, 0],
+ frames_idx_pred[:, :, :, 1],
+ frames_idx_pred[:, :, :, 2],
+ )
+ # Compute token coords in pred frames
+ B, N, _ = pred_atom_coords.shape
+ pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
+ pred_coords_transformed = express_coordinate_in_frame(
+ pred_atom_coords, frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c
+ )
+
+ target_pae = torch.sqrt(
+ ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8
+ )
+
+ # Compute mask for the pae loss
+ b_true_resolved_mask = true_coords_resolved_mask[
+ torch.arange(B // multiplicity)[:, None, None].to(
+ pred_coords_transformed.device
+ ),
+ frame_true_atom_b,
+ ]
+
+ pair_mask = (
+ mask_frame_true[:, None, :, None] # if true frame is invalid
+ * mask_collinear_true[:, :, :, None] # if true frame is invalid
+ * mask_collinear_pred[:, :, :, None] # if pred frame is invalid
+ * b_true_resolved_mask[:, :, None, :] # If atom j is not resolved
+ * feats["token_pad_mask"][:, None, :, None]
+ * feats["token_pad_mask"][:, None, None, :]
+ )
+
+ # compute loss
+ num_bins = pred_pae.shape[-1]
+ bin_index = torch.floor(target_pae * num_bins / max_dist).long()
+ bin_index = torch.clamp(bin_index, max=(num_bins - 1))
+ pae_one_hot = nn.functional.one_hot(bin_index, num_classes=num_bins)
+ errors = -1 * torch.sum(
+ pae_one_hot
+ * torch.nn.functional.log_softmax(pred_pae.reshape(pae_one_hot.shape), dim=-1),
+ dim=-1,
+ )
+ loss = torch.sum(errors * pair_mask, dim=(-2, -1)) / (
+ 1e-7 + torch.sum(pair_mask, dim=(-2, -1))
+ )
+ # Average over the batch dimension
+ loss = torch.mean(loss)
+
+ return loss
+
+
+def lddt_dist(dmat_predicted, dmat_true, mask, cutoff=15.0, per_atom=False):
+ # NOTE: the mask is a pairwise mask which should have the identity elements already masked out
+ # Compute mask over distances
+ dists_to_score = (dmat_true < cutoff).float() * mask
+ dist_l1 = torch.abs(dmat_true - dmat_predicted)
+
+ score = 0.25 * (
+ (dist_l1 < 0.5).float()
+ + (dist_l1 < 1.0).float()
+ + (dist_l1 < 2.0).float()
+ + (dist_l1 < 4.0).float()
+ )
+
+ # Normalize over the appropriate axes.
+ if per_atom:
+ mask_no_match = torch.sum(dists_to_score, dim=-1) != 0
+ norm = 1.0 / (1e-10 + torch.sum(dists_to_score, dim=-1))
+ score = norm * (1e-10 + torch.sum(dists_to_score * score, dim=-1))
+ return score, mask_no_match.float()
+ else:
+ norm = 1.0 / (1e-10 + torch.sum(dists_to_score, dim=(-2, -1)))
+ score = norm * (1e-10 + torch.sum(dists_to_score * score, dim=(-2, -1)))
+ total = torch.sum(dists_to_score, dim=(-1, -2))
+ return score, total
+
+
+def express_coordinate_in_frame(atom_coords, frame_atom_a, frame_atom_b, frame_atom_c):
+ batch, multiplicity = atom_coords.shape[0], atom_coords.shape[1]
+ batch_indices0 = torch.arange(batch)[:, None, None].to(atom_coords.device)
+ batch_indices1 = torch.arange(multiplicity)[None, :, None].to(atom_coords.device)
+
+ # extract frame atoms
+ a, b, c = (
+ atom_coords[batch_indices0, batch_indices1, frame_atom_a],
+ atom_coords[batch_indices0, batch_indices1, frame_atom_b],
+ atom_coords[batch_indices0, batch_indices1, frame_atom_c],
+ )
+ w1 = (a - b) / (torch.norm(a - b, dim=-1, keepdim=True) + 1e-5)
+ w2 = (c - b) / (torch.norm(c - b, dim=-1, keepdim=True) + 1e-5)
+
+ # build orthogonal frame
+ e1 = (w1 + w2) / (torch.norm(w1 + w2, dim=-1, keepdim=True) + 1e-5)
+ e2 = (w2 - w1) / (torch.norm(w2 - w1, dim=-1, keepdim=True) + 1e-5)
+ e3 = torch.linalg.cross(e1, e2)
+
+ # project onto frame basis
+ d = b[:, :, None, :, :] - b[:, :, :, None, :]
+ x_transformed = torch.cat(
+ [
+ torch.sum(d * e1[:, :, :, None, :], dim=-1, keepdim=True),
+ torch.sum(d * e2[:, :, :, None, :], dim=-1, keepdim=True),
+ torch.sum(d * e3[:, :, :, None, :], dim=-1, keepdim=True),
+ ],
+ dim=-1,
+ )
+ return x_transformed
+
+
+def compute_collinear_mask(v1, v2):
+ # Compute the mask for collinear or overlapping atoms
+ norm1 = torch.norm(v1, dim=1, keepdim=True)
+ norm2 = torch.norm(v2, dim=1, keepdim=True)
+ v1 = v1 / (norm1 + 1e-6)
+ v2 = v2 / (norm2 + 1e-6)
+ mask_angle = torch.abs(torch.sum(v1 * v2, dim=1)) < 0.9063
+ mask_overlap1 = norm1.reshape(-1) > 1e-2
+ mask_overlap2 = norm2.reshape(-1) > 1e-2
+ return mask_angle & mask_overlap1 & mask_overlap2
+
+
+def compute_frame_pred(
+ pred_atom_coords,
+ frames_idx_true,
+ feats,
+ multiplicity,
+ resolved_mask=None,
+ inference=False,
+):
+ # extract necessary features
+ asym_id_token = feats["asym_id"]
+ asym_id_atom = torch.bmm(
+ feats["atom_to_token"].float(), asym_id_token.unsqueeze(-1).float()
+ ).squeeze(-1)
+ B, N, _ = pred_atom_coords.shape
+ pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
+ frames_idx_pred = (
+ frames_idx_true.clone()
+ .repeat_interleave(multiplicity, 0)
+ .reshape(B // multiplicity, multiplicity, -1, 3)
+ )
+
+ # Iterate through the batch and update the frames for nonpolymers
+ for i, pred_atom_coord in enumerate(pred_atom_coords):
+ token_idx = 0
+ atom_idx = 0
+ for id in torch.unique(asym_id_token[i]):
+ mask_chain_token = (asym_id_token[i] == id) * feats["token_pad_mask"][i]
+ mask_chain_atom = (asym_id_atom[i] == id) * feats["atom_pad_mask"][i]
+ num_tokens = int(mask_chain_token.sum().item())
+ num_atoms = int(mask_chain_atom.sum().item())
+ if (
+ feats["mol_type"][i, token_idx] != const.chain_type_ids["NONPOLYMER"]
+ or num_atoms < 3
+ ):
+ token_idx += num_tokens
+ atom_idx += num_atoms
+ continue
+ dist_mat = (
+ (
+ pred_atom_coord[:, mask_chain_atom.bool()][:, None, :, :]
+ - pred_atom_coord[:, mask_chain_atom.bool()][:, :, None, :]
+ )
+ ** 2
+ ).sum(-1) ** 0.5
+
+ # Sort the atoms by distance
+ if inference:
+ resolved_pair = 1 - (
+ feats["atom_pad_mask"][i][mask_chain_atom.bool()][None, :]
+ * feats["atom_pad_mask"][i][mask_chain_atom.bool()][:, None]
+ ).to(torch.float32)
+ resolved_pair[resolved_pair == 1] = torch.inf
+ indices = torch.sort(dist_mat + resolved_pair, axis=2).indices
+ else:
+ if resolved_mask is None:
+ resolved_mask = feats["atom_resolved_mask"]
+ resolved_pair = 1 - (
+ resolved_mask[i][mask_chain_atom.bool()][None, :]
+ * resolved_mask[i][mask_chain_atom.bool()][:, None]
+ ).to(torch.float32)
+ resolved_pair[resolved_pair == 1] = torch.inf
+ indices = torch.sort(dist_mat + resolved_pair, axis=2).indices
+
+ # Compute the frames
+ frames = (
+ torch.cat(
+ [
+ indices[:, :, 1:2],
+ indices[:, :, 0:1],
+ indices[:, :, 2:3],
+ ],
+ dim=2,
+ )
+ + atom_idx
+ )
+ frames_idx_pred[i, :, token_idx : token_idx + num_atoms, :] = frames
+ token_idx += num_tokens
+ atom_idx += num_atoms
+
+ # Expand the frames with the multiplicity
+ frames_expanded = pred_atom_coords[
+ torch.arange(0, B // multiplicity, 1)[:, None, None, None].to(
+ frames_idx_pred.device
+ ),
+ torch.arange(0, multiplicity, 1)[None, :, None, None].to(
+ frames_idx_pred.device
+ ),
+ frames_idx_pred,
+ ].reshape(-1, 3, 3)
+
+ # Compute masks for collinear or overlapping atoms in the frame
+ mask_collinear_pred = compute_collinear_mask(
+ frames_expanded[:, 1] - frames_expanded[:, 0],
+ frames_expanded[:, 1] - frames_expanded[:, 2],
+ ).reshape(B // multiplicity, multiplicity, -1)
+
+ return frames_idx_pred, mask_collinear_pred * feats["token_pad_mask"][:, None, :]
diff --git a/src/boltz/model/loss/diffusion.py b/src/boltz/model/loss/diffusion.py
new file mode 100644
index 0000000..3433e42
--- /dev/null
+++ b/src/boltz/model/loss/diffusion.py
@@ -0,0 +1,171 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+from einops import einsum
+import torch
+import torch.nn.functional as F
+
+
+def weighted_rigid_align(
+ true_coords,
+ pred_coords,
+ weights,
+ mask,
+):
+ """Compute weighted alignment.
+
+ Parameters
+ ----------
+ true_coords: torch.Tensor
+ The ground truth atom coordinates
+ pred_coords: torch.Tensor
+ The predicted atom coordinates
+ weights: torch.Tensor
+ The weights for alignment
+ mask: torch.Tensor
+ The atoms mask
+
+ Returns
+ -------
+ torch.Tensor
+ Aligned coordinates
+
+ """
+
+ batch_size, num_points, dim = true_coords.shape
+ weights = (mask * weights).unsqueeze(-1)
+
+ # Compute weighted centroids
+ true_centroid = (true_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+ dim=1, keepdim=True
+ )
+ pred_centroid = (pred_coords * weights).sum(dim=1, keepdim=True) / weights.sum(
+ dim=1, keepdim=True
+ )
+
+ # Center the coordinates
+ true_coords_centered = true_coords - true_centroid
+ pred_coords_centered = pred_coords - pred_centroid
+
+ if num_points < (dim + 1):
+ print(
+ "Warning: The size of one of the point clouds is <= dim+1. "
+ + "`WeightedRigidAlign` cannot return a unique rotation."
+ )
+
+ # Compute the weighted covariance matrix
+ cov_matrix = einsum(
+ weights * pred_coords_centered, true_coords_centered, "b n i, b n j -> b i j"
+ )
+
+ # Compute the SVD of the covariance matrix, required float32 for svd and determinant
+ original_dtype = cov_matrix.dtype
+ cov_matrix_32 = cov_matrix.to(dtype=torch.float32)
+ U, S, V = torch.linalg.svd(
+ cov_matrix_32, driver="gesvd" if cov_matrix_32.is_cuda else None
+ )
+ V = V.mH
+
+ # Catch ambiguous rotation by checking the magnitude of singular values
+ if (S.abs() <= 1e-15).any() and not (num_points < (dim + 1)):
+ print(
+ "Warning: Excessively low rank of "
+ + "cross-correlation between aligned point clouds. "
+ + "`WeightedRigidAlign` cannot return a unique rotation."
+ )
+
+ # Compute the rotation matrix
+ rot_matrix = torch.einsum("b i j, b k j -> b i k", U, V).to(dtype=torch.float32)
+
+ # Ensure proper rotation matrix with determinant 1
+ F = torch.eye(dim, dtype=cov_matrix_32.dtype, device=cov_matrix.device)[
+ None
+ ].repeat(batch_size, 1, 1)
+ F[:, -1, -1] = torch.det(rot_matrix)
+ rot_matrix = einsum(U, F, V, "b i j, b j k, b l k -> b i l")
+ rot_matrix = rot_matrix.to(dtype=original_dtype)
+
+ # Apply the rotation and translation
+ aligned_coords = (
+ einsum(true_coords_centered, rot_matrix, "b n i, b j i -> b n j")
+ + pred_centroid
+ )
+ aligned_coords.detach_()
+
+ return aligned_coords
+
+
+def smooth_lddt_loss(
+ pred_coords,
+ true_coords,
+ is_nucleotide,
+ coords_mask,
+ nucleic_acid_cutoff: float = 30.0,
+ other_cutoff: float = 15.0,
+ multiplicity: int = 1,
+):
+ """Compute weighted alignment.
+
+ Parameters
+ ----------
+ pred_coords: torch.Tensor
+ The predicted atom coordinates
+ true_coords: torch.Tensor
+ The ground truth atom coordinates
+ is_nucleotide: torch.Tensor
+ The weights for alignment
+ coords_mask: torch.Tensor
+ The atoms mask
+ nucleic_acid_cutoff: float
+ The nucleic acid cutoff
+ other_cutoff: float
+ The non nucleic acid cutoff
+ multiplicity: int
+ The multiplicity
+ Returns
+ -------
+ torch.Tensor
+ Aligned coordinates
+
+ """
+ B, N, _ = true_coords.shape
+ true_dists = torch.cdist(true_coords, true_coords)
+ is_nucleotide = is_nucleotide.repeat_interleave(multiplicity, 0)
+
+ coords_mask = coords_mask.repeat_interleave(multiplicity, 0)
+ is_nucleotide_pair = is_nucleotide.unsqueeze(-1).expand(
+ -1, -1, is_nucleotide.shape[-1]
+ )
+
+ mask = (
+ is_nucleotide_pair * (true_dists < nucleic_acid_cutoff).float()
+ + (1 - is_nucleotide_pair) * (true_dists < other_cutoff).float()
+ )
+ mask = mask * (1 - torch.eye(pred_coords.shape[1], device=pred_coords.device))
+ mask = mask * (coords_mask.unsqueeze(-1) * coords_mask.unsqueeze(-2))
+
+ # Compute distances between all pairs of atoms
+ pred_dists = torch.cdist(pred_coords, pred_coords)
+ dist_diff = torch.abs(true_dists - pred_dists)
+
+ # Compute epsilon values
+ eps = (
+ (
+ (
+ F.sigmoid(0.5 - dist_diff)
+ + F.sigmoid(1.0 - dist_diff)
+ + F.sigmoid(2.0 - dist_diff)
+ + F.sigmoid(4.0 - dist_diff)
+ )
+ / 4.0
+ )
+ .view(multiplicity, B // multiplicity, N, N)
+ .mean(dim=0)
+ )
+
+ # Calculate masked averaging
+ eps = eps.repeat_interleave(multiplicity, 0)
+ num = (eps * mask).sum(dim=(-1, -2))
+ den = mask.sum(dim=(-1, -2)).clamp(min=1)
+ lddt = num / den
+
+ return 1.0 - lddt.mean()
diff --git a/src/boltz/model/loss/distogram.py b/src/boltz/model/loss/distogram.py
new file mode 100644
index 0000000..4d8818e
--- /dev/null
+++ b/src/boltz/model/loss/distogram.py
@@ -0,0 +1,50 @@
+from typing import Dict, Tuple
+
+import torch
+from torch import Tensor
+
+
+def distogram_loss(
+ output: Dict[str, Tensor],
+ feats: Dict[str, Tensor],
+) -> Tuple[Tensor, Tensor]:
+ """Compute the distogram loss.
+
+ Parameters
+ ----------
+ output : Dict[str, Tensor]
+ Output of the model
+ feats : Dict[str, Tensor]
+ Input features
+
+ Returns
+ -------
+ Tensor
+ The globally averaged loss.
+ Tensor
+ Per example loss.
+
+ """
+ # Get predicted distograms
+ pred = output["pdistogram"]
+
+ # Compute target distogram
+ target = feats["disto_target"]
+
+ # Combine target mask and padding mask
+ mask = feats["token_disto_mask"]
+ mask = mask[:, None, :] * mask[:, :, None]
+ mask = mask * (1 - torch.eye(mask.shape[1])[None]).to(pred)
+
+ # Compute the distogram loss
+ errors = -1 * torch.sum(
+ target * torch.nn.functional.log_softmax(pred, dim=-1),
+ dim=-1,
+ )
+ denom = 1e-5 + torch.sum(mask, dim=(-1, -2))
+ mean = errors * mask
+ mean = torch.sum(mean, dim=-1)
+ mean = mean / denom[..., None]
+ batch_loss = torch.sum(mean, dim=-1)
+ global_loss = torch.mean(batch_loss)
+ return global_loss, batch_loss
diff --git a/src/boltz/model/loss/validation.py b/src/boltz/model/loss/validation.py
new file mode 100644
index 0000000..00d1aa7
--- /dev/null
+++ b/src/boltz/model/loss/validation.py
@@ -0,0 +1,1025 @@
+import torch
+
+from boltz.data import const
+from boltz.model.loss.confidence import (
+ compute_frame_pred,
+ express_coordinate_in_frame,
+ lddt_dist,
+)
+from boltz.model.loss.diffusion import weighted_rigid_align
+
+
+def factored_lddt_loss(
+ true_atom_coords,
+ pred_atom_coords,
+ feats,
+ atom_mask,
+ multiplicity=1,
+ cardinality_weighted=False,
+):
+ """Compute the lddt factorized into the different modalities.
+
+ Parameters
+ ----------
+ true_atom_coords : torch.Tensor
+ Ground truth atom coordinates after symmetry correction
+ pred_atom_coords : torch.Tensor
+ Predicted atom coordinates
+ feats : Dict[str, torch.Tensor]
+ Input features
+ atom_mask : torch.Tensor
+ Atom mask
+ multiplicity : int
+ Diffusion batch size, by default 1
+
+ Returns
+ -------
+ Dict[str, torch.Tensor]
+ The lddt for each modality
+ Dict[str, torch.Tensor]
+ The total number of pairs for each modality
+
+ """
+ # extract necessary features
+ atom_type = (
+ torch.bmm(
+ feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float()
+ )
+ .squeeze(-1)
+ .long()
+ )
+ atom_type = atom_type.repeat_interleave(multiplicity, 0)
+
+ ligand_mask = (atom_type == const.chain_type_ids["NONPOLYMER"]).float()
+ dna_mask = (atom_type == const.chain_type_ids["DNA"]).float()
+ rna_mask = (atom_type == const.chain_type_ids["RNA"]).float()
+ protein_mask = (atom_type == const.chain_type_ids["PROTEIN"]).float()
+
+ nucleotide_mask = dna_mask + rna_mask
+
+ true_d = torch.cdist(true_atom_coords, true_atom_coords)
+ pred_d = torch.cdist(pred_atom_coords, pred_atom_coords)
+
+ pair_mask = atom_mask[:, :, None] * atom_mask[:, None, :]
+ pair_mask = (
+ pair_mask
+ * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :]
+ )
+
+ cutoff = 15 + 15 * (
+ 1 - (1 - nucleotide_mask[:, :, None]) * (1 - nucleotide_mask[:, None, :])
+ )
+
+ # compute different lddts
+ dna_protein_mask = pair_mask * (
+ dna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_protein_lddt, dna_protein_total = lddt_dist(
+ pred_d, true_d, dna_protein_mask, cutoff
+ )
+ del dna_protein_mask
+
+ rna_protein_mask = pair_mask * (
+ rna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_protein_lddt, rna_protein_total = lddt_dist(
+ pred_d, true_d, rna_protein_mask, cutoff
+ )
+ del rna_protein_mask
+
+ ligand_protein_mask = pair_mask * (
+ ligand_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * ligand_mask[:, None, :]
+ )
+ ligand_protein_lddt, ligand_protein_total = lddt_dist(
+ pred_d, true_d, ligand_protein_mask, cutoff
+ )
+ del ligand_protein_mask
+
+ dna_ligand_mask = pair_mask * (
+ dna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_ligand_lddt, dna_ligand_total = lddt_dist(
+ pred_d, true_d, dna_ligand_mask, cutoff
+ )
+ del dna_ligand_mask
+
+ rna_ligand_mask = pair_mask * (
+ rna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_ligand_lddt, rna_ligand_total = lddt_dist(
+ pred_d, true_d, rna_ligand_mask, cutoff
+ )
+ del rna_ligand_mask
+
+ intra_dna_mask = pair_mask * (dna_mask[:, :, None] * dna_mask[:, None, :])
+ intra_dna_lddt, intra_dna_total = lddt_dist(pred_d, true_d, intra_dna_mask, cutoff)
+ del intra_dna_mask
+
+ intra_rna_mask = pair_mask * (rna_mask[:, :, None] * rna_mask[:, None, :])
+ intra_rna_lddt, intra_rna_total = lddt_dist(pred_d, true_d, intra_rna_mask, cutoff)
+ del intra_rna_mask
+
+ chain_id = feats["asym_id"]
+ atom_chain_id = (
+ torch.bmm(feats["atom_to_token"].float(), chain_id.unsqueeze(-1).float())
+ .squeeze(-1)
+ .long()
+ )
+ atom_chain_id = atom_chain_id.repeat_interleave(multiplicity, 0)
+ same_chain_mask = (atom_chain_id[:, :, None] == atom_chain_id[:, None, :]).float()
+
+ intra_ligand_mask = (
+ pair_mask
+ * same_chain_mask
+ * (ligand_mask[:, :, None] * ligand_mask[:, None, :])
+ )
+ intra_ligand_lddt, intra_ligand_total = lddt_dist(
+ pred_d, true_d, intra_ligand_mask, cutoff
+ )
+ del intra_ligand_mask
+
+ intra_protein_mask = (
+ pair_mask
+ * same_chain_mask
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ intra_protein_lddt, intra_protein_total = lddt_dist(
+ pred_d, true_d, intra_protein_mask, cutoff
+ )
+ del intra_protein_mask
+
+ protein_protein_mask = (
+ pair_mask
+ * (1 - same_chain_mask)
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ protein_protein_lddt, protein_protein_total = lddt_dist(
+ pred_d, true_d, protein_protein_mask, cutoff
+ )
+ del protein_protein_mask
+
+ lddt_dict = {
+ "dna_protein": dna_protein_lddt,
+ "rna_protein": rna_protein_lddt,
+ "ligand_protein": ligand_protein_lddt,
+ "dna_ligand": dna_ligand_lddt,
+ "rna_ligand": rna_ligand_lddt,
+ "intra_ligand": intra_ligand_lddt,
+ "intra_dna": intra_dna_lddt,
+ "intra_rna": intra_rna_lddt,
+ "intra_protein": intra_protein_lddt,
+ "protein_protein": protein_protein_lddt,
+ }
+
+ total_dict = {
+ "dna_protein": dna_protein_total,
+ "rna_protein": rna_protein_total,
+ "ligand_protein": ligand_protein_total,
+ "dna_ligand": dna_ligand_total,
+ "rna_ligand": rna_ligand_total,
+ "intra_ligand": intra_ligand_total,
+ "intra_dna": intra_dna_total,
+ "intra_rna": intra_rna_total,
+ "intra_protein": intra_protein_total,
+ "protein_protein": protein_protein_total,
+ }
+ if not cardinality_weighted:
+ for key in total_dict:
+ total_dict[key] = (total_dict[key] > 0.0).float()
+
+ return lddt_dict, total_dict
+
+
+def factored_token_lddt_dist_loss(true_d, pred_d, feats, cardinality_weighted=False):
+ """Compute the distogram lddt factorized into the different modalities.
+
+ Parameters
+ ----------
+ true_d : torch.Tensor
+ Ground truth atom distogram
+ pred_d : torch.Tensor
+ Predicted atom distogram
+ feats : Dict[str, torch.Tensor]
+ Input features
+
+ Returns
+ -------
+ Tensor
+ The lddt for each modality
+ Tensor
+ The total number of pairs for each modality
+
+ """
+ # extract necessary features
+ token_type = feats["mol_type"]
+
+ ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+ dna_mask = (token_type == const.chain_type_ids["DNA"]).float()
+ rna_mask = (token_type == const.chain_type_ids["RNA"]).float()
+ protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float()
+ nucleotide_mask = dna_mask + rna_mask
+
+ token_mask = feats["token_disto_mask"]
+ token_mask = token_mask[:, :, None] * token_mask[:, None, :]
+ token_mask = token_mask * (1 - torch.eye(token_mask.shape[1])[None]).to(token_mask)
+
+ cutoff = 15 + 15 * (
+ 1 - (1 - nucleotide_mask[:, :, None]) * (1 - nucleotide_mask[:, None, :])
+ )
+
+ # compute different lddts
+ dna_protein_mask = token_mask * (
+ dna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_protein_lddt, dna_protein_total = lddt_dist(
+ pred_d, true_d, dna_protein_mask, cutoff
+ )
+
+ rna_protein_mask = token_mask * (
+ rna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_protein_lddt, rna_protein_total = lddt_dist(
+ pred_d, true_d, rna_protein_mask, cutoff
+ )
+
+ ligand_protein_mask = token_mask * (
+ ligand_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * ligand_mask[:, None, :]
+ )
+ ligand_protein_lddt, ligand_protein_total = lddt_dist(
+ pred_d, true_d, ligand_protein_mask, cutoff
+ )
+
+ dna_ligand_mask = token_mask * (
+ dna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_ligand_lddt, dna_ligand_total = lddt_dist(
+ pred_d, true_d, dna_ligand_mask, cutoff
+ )
+
+ rna_ligand_mask = token_mask * (
+ rna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_ligand_lddt, rna_ligand_total = lddt_dist(
+ pred_d, true_d, rna_ligand_mask, cutoff
+ )
+
+ chain_id = feats["asym_id"]
+ same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()
+ intra_ligand_mask = (
+ token_mask
+ * same_chain_mask
+ * (ligand_mask[:, :, None] * ligand_mask[:, None, :])
+ )
+ intra_ligand_lddt, intra_ligand_total = lddt_dist(
+ pred_d, true_d, intra_ligand_mask, cutoff
+ )
+
+ intra_dna_mask = token_mask * (dna_mask[:, :, None] * dna_mask[:, None, :])
+ intra_dna_lddt, intra_dna_total = lddt_dist(pred_d, true_d, intra_dna_mask, cutoff)
+
+ intra_rna_mask = token_mask * (rna_mask[:, :, None] * rna_mask[:, None, :])
+ intra_rna_lddt, intra_rna_total = lddt_dist(pred_d, true_d, intra_rna_mask, cutoff)
+
+ chain_id = feats["asym_id"]
+ same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()
+
+ intra_protein_mask = (
+ token_mask
+ * same_chain_mask
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ intra_protein_lddt, intra_protein_total = lddt_dist(
+ pred_d, true_d, intra_protein_mask, cutoff
+ )
+
+ protein_protein_mask = (
+ token_mask
+ * (1 - same_chain_mask)
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ protein_protein_lddt, protein_protein_total = lddt_dist(
+ pred_d, true_d, protein_protein_mask, cutoff
+ )
+
+ lddt_dict = {
+ "dna_protein": dna_protein_lddt,
+ "rna_protein": rna_protein_lddt,
+ "ligand_protein": ligand_protein_lddt,
+ "dna_ligand": dna_ligand_lddt,
+ "rna_ligand": rna_ligand_lddt,
+ "intra_ligand": intra_ligand_lddt,
+ "intra_dna": intra_dna_lddt,
+ "intra_rna": intra_rna_lddt,
+ "intra_protein": intra_protein_lddt,
+ "protein_protein": protein_protein_lddt,
+ }
+
+ total_dict = {
+ "dna_protein": dna_protein_total,
+ "rna_protein": rna_protein_total,
+ "ligand_protein": ligand_protein_total,
+ "dna_ligand": dna_ligand_total,
+ "rna_ligand": rna_ligand_total,
+ "intra_ligand": intra_ligand_total,
+ "intra_dna": intra_dna_total,
+ "intra_rna": intra_rna_total,
+ "intra_protein": intra_protein_total,
+ "protein_protein": protein_protein_total,
+ }
+
+ if not cardinality_weighted:
+ for key in total_dict:
+ total_dict[key] = (total_dict[key] > 0.0).float()
+
+ return lddt_dict, total_dict
+
+
+def compute_plddt_mae(
+ pred_atom_coords,
+ feats,
+ true_atom_coords,
+ pred_lddt,
+ true_coords_resolved_mask,
+ multiplicity=1,
+):
+ """Compute the plddt mean absolute error.
+
+ Parameters
+ ----------
+ pred_atom_coords : torch.Tensor
+ Predicted atom coordinates
+ feats : torch.Tensor
+ Input features
+ true_atom_coords : torch.Tensor
+ Ground truth atom coordinates
+ pred_lddt : torch.Tensor
+ Predicted lddt
+ true_coords_resolved_mask : torch.Tensor
+ Resolved atom mask
+ multiplicity : int
+ Diffusion batch size, by default 1
+
+ Returns
+ -------
+ Tensor
+ The mae for each modality
+ Tensor
+ The total number of pairs for each modality
+
+ """
+ # extract necessary features
+ atom_mask = true_coords_resolved_mask
+ R_set_to_rep_atom = feats["r_set_to_rep_atom"]
+ R_set_to_rep_atom = R_set_to_rep_atom.repeat_interleave(multiplicity, 0).float()
+
+ token_type = feats["mol_type"]
+ token_type = token_type.repeat_interleave(multiplicity, 0)
+ is_nucleotide_token = (token_type == const.chain_type_ids["DNA"]).float() + (
+ token_type == const.chain_type_ids["RNA"]
+ ).float()
+
+ B = true_atom_coords.shape[0]
+
+ atom_to_token = feats["atom_to_token"].float()
+ atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
+
+ token_to_rep_atom = feats["token_to_rep_atom"].float()
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
+
+ true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
+ pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)
+
+ # compute true lddt
+ true_d = torch.cdist(
+ true_token_coords,
+ torch.bmm(R_set_to_rep_atom, true_atom_coords),
+ )
+ pred_d = torch.cdist(
+ pred_token_coords,
+ torch.bmm(R_set_to_rep_atom, pred_atom_coords),
+ )
+
+ pair_mask = atom_mask.unsqueeze(-1) * atom_mask.unsqueeze(-2)
+ pair_mask = (
+ pair_mask
+ * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :]
+ )
+ pair_mask = torch.einsum("bnm,bkm->bnk", pair_mask, R_set_to_rep_atom)
+
+ pair_mask = torch.bmm(token_to_rep_atom, pair_mask)
+ atom_mask = torch.bmm(token_to_rep_atom, atom_mask.unsqueeze(-1).float()).squeeze(
+ -1
+ )
+ is_nucleotide_R_element = torch.bmm(
+ R_set_to_rep_atom, torch.bmm(atom_to_token, is_nucleotide_token.unsqueeze(-1))
+ ).squeeze(-1)
+ cutoff = 15 + 15 * is_nucleotide_R_element.reshape(B, 1, -1).repeat(
+ 1, true_d.shape[1], 1
+ )
+
+ target_lddt, mask_no_match = lddt_dist(
+ pred_d, true_d, pair_mask, cutoff, per_atom=True
+ )
+
+ protein_mask = (
+ (token_type == const.chain_type_ids["PROTEIN"]).float()
+ * atom_mask
+ * mask_no_match
+ )
+ ligand_mask = (
+ (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+ * atom_mask
+ * mask_no_match
+ )
+ dna_mask = (
+ (token_type == const.chain_type_ids["DNA"]).float() * atom_mask * mask_no_match
+ )
+ rna_mask = (
+ (token_type == const.chain_type_ids["RNA"]).float() * atom_mask * mask_no_match
+ )
+
+ protein_mae = torch.sum(torch.abs(target_lddt - pred_lddt) * protein_mask) / (
+ torch.sum(protein_mask) + 1e-5
+ )
+ protein_total = torch.sum(protein_mask)
+ ligand_mae = torch.sum(torch.abs(target_lddt - pred_lddt) * ligand_mask) / (
+ torch.sum(ligand_mask) + 1e-5
+ )
+ ligand_total = torch.sum(ligand_mask)
+ dna_mae = torch.sum(torch.abs(target_lddt - pred_lddt) * dna_mask) / (
+ torch.sum(dna_mask) + 1e-5
+ )
+ dna_total = torch.sum(dna_mask)
+ rna_mae = torch.sum(torch.abs(target_lddt - pred_lddt) * rna_mask) / (
+ torch.sum(rna_mask) + 1e-5
+ )
+ rna_total = torch.sum(rna_mask)
+
+ mae_plddt_dict = {
+ "protein": protein_mae,
+ "ligand": ligand_mae,
+ "dna": dna_mae,
+ "rna": rna_mae,
+ }
+ total_dict = {
+ "protein": protein_total,
+ "ligand": ligand_total,
+ "dna": dna_total,
+ "rna": rna_total,
+ }
+
+ return mae_plddt_dict, total_dict
+
+
+def compute_pde_mae(
+ pred_atom_coords,
+ feats,
+ true_atom_coords,
+ pred_pde,
+ true_coords_resolved_mask,
+ multiplicity=1,
+):
+ """Compute the plddt mean absolute error.
+
+ Parameters
+ ----------
+ pred_atom_coords : torch.Tensor
+ Predicted atom coordinates
+ feats : torch.Tensor
+ Input features
+ true_atom_coords : torch.Tensor
+ Ground truth atom coordinates
+ pred_pde : torch.Tensor
+ Predicted pde
+ true_coords_resolved_mask : torch.Tensor
+ Resolved atom mask
+ multiplicity : int
+ Diffusion batch size, by default 1
+
+ Returns
+ -------
+ Tensor
+ The mae for each modality
+ Tensor
+ The total number of pairs for each modality
+
+ """
+ # extract necessary features
+ token_to_rep_atom = feats["token_to_rep_atom"].float()
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
+
+ token_mask = torch.bmm(
+ token_to_rep_atom, true_coords_resolved_mask.unsqueeze(-1).float()
+ ).squeeze(-1)
+
+ token_type = feats["mol_type"]
+ token_type = token_type.repeat_interleave(multiplicity, 0)
+
+ true_token_coords = torch.bmm(token_to_rep_atom, true_atom_coords)
+ pred_token_coords = torch.bmm(token_to_rep_atom, pred_atom_coords)
+
+ # compute true pde
+ true_d = torch.cdist(true_token_coords, true_token_coords)
+ pred_d = torch.cdist(pred_token_coords, pred_token_coords)
+ target_pde = (
+ torch.clamp(
+ torch.floor(torch.abs(true_d - pred_d) * 64 / 32).long(), max=63
+ ).float()
+ * 0.5
+ + 0.25
+ )
+
+ pair_mask = token_mask.unsqueeze(-1) * token_mask.unsqueeze(-2)
+ pair_mask = (
+ pair_mask
+ * (1 - torch.eye(pair_mask.shape[1], device=pair_mask.device))[None, :, :]
+ )
+
+ protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float()
+ ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+ dna_mask = (token_type == const.chain_type_ids["DNA"]).float()
+ rna_mask = (token_type == const.chain_type_ids["RNA"]).float()
+
+ # compute different pdes
+ dna_protein_mask = pair_mask * (
+ dna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_protein_mae = torch.sum(torch.abs(target_pde - pred_pde) * dna_protein_mask) / (
+ torch.sum(dna_protein_mask) + 1e-5
+ )
+ dna_protein_total = torch.sum(dna_protein_mask)
+
+ rna_protein_mask = pair_mask * (
+ rna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_protein_mae = torch.sum(torch.abs(target_pde - pred_pde) * rna_protein_mask) / (
+ torch.sum(rna_protein_mask) + 1e-5
+ )
+ rna_protein_total = torch.sum(rna_protein_mask)
+
+ ligand_protein_mask = pair_mask * (
+ ligand_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * ligand_mask[:, None, :]
+ )
+ ligand_protein_mae = torch.sum(
+ torch.abs(target_pde - pred_pde) * ligand_protein_mask
+ ) / (torch.sum(ligand_protein_mask) + 1e-5)
+ ligand_protein_total = torch.sum(ligand_protein_mask)
+
+ dna_ligand_mask = pair_mask * (
+ dna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_ligand_mae = torch.sum(torch.abs(target_pde - pred_pde) * dna_ligand_mask) / (
+ torch.sum(dna_ligand_mask) + 1e-5
+ )
+ dna_ligand_total = torch.sum(dna_ligand_mask)
+
+ rna_ligand_mask = pair_mask * (
+ rna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_ligand_mae = torch.sum(torch.abs(target_pde - pred_pde) * rna_ligand_mask) / (
+ torch.sum(rna_ligand_mask) + 1e-5
+ )
+ rna_ligand_total = torch.sum(rna_ligand_mask)
+
+ intra_ligand_mask = pair_mask * (ligand_mask[:, :, None] * ligand_mask[:, None, :])
+ intra_ligand_mae = torch.sum(
+ torch.abs(target_pde - pred_pde) * intra_ligand_mask
+ ) / (torch.sum(intra_ligand_mask) + 1e-5)
+ intra_ligand_total = torch.sum(intra_ligand_mask)
+
+ intra_dna_mask = pair_mask * (dna_mask[:, :, None] * dna_mask[:, None, :])
+ intra_dna_mae = torch.sum(torch.abs(target_pde - pred_pde) * intra_dna_mask) / (
+ torch.sum(intra_dna_mask) + 1e-5
+ )
+ intra_dna_total = torch.sum(intra_dna_mask)
+
+ intra_rna_mask = pair_mask * (rna_mask[:, :, None] * rna_mask[:, None, :])
+ intra_rna_mae = torch.sum(torch.abs(target_pde - pred_pde) * intra_rna_mask) / (
+ torch.sum(intra_rna_mask) + 1e-5
+ )
+ intra_rna_total = torch.sum(intra_rna_mask)
+
+ chain_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
+ same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()
+
+ intra_protein_mask = (
+ pair_mask
+ * same_chain_mask
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ intra_protein_mae = torch.sum(
+ torch.abs(target_pde - pred_pde) * intra_protein_mask
+ ) / (torch.sum(intra_protein_mask) + 1e-5)
+ intra_protein_total = torch.sum(intra_protein_mask)
+
+ protein_protein_mask = (
+ pair_mask
+ * (1 - same_chain_mask)
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ protein_protein_mae = torch.sum(
+ torch.abs(target_pde - pred_pde) * protein_protein_mask
+ ) / (torch.sum(protein_protein_mask) + 1e-5)
+ protein_protein_total = torch.sum(protein_protein_mask)
+
+ mae_pde_dict = {
+ "dna_protein": dna_protein_mae,
+ "rna_protein": rna_protein_mae,
+ "ligand_protein": ligand_protein_mae,
+ "dna_ligand": dna_ligand_mae,
+ "rna_ligand": rna_ligand_mae,
+ "intra_ligand": intra_ligand_mae,
+ "intra_dna": intra_dna_mae,
+ "intra_rna": intra_rna_mae,
+ "intra_protein": intra_protein_mae,
+ "protein_protein": protein_protein_mae,
+ }
+ total_pde_dict = {
+ "dna_protein": dna_protein_total,
+ "rna_protein": rna_protein_total,
+ "ligand_protein": ligand_protein_total,
+ "dna_ligand": dna_ligand_total,
+ "rna_ligand": rna_ligand_total,
+ "intra_ligand": intra_ligand_total,
+ "intra_dna": intra_dna_total,
+ "intra_rna": intra_rna_total,
+ "intra_protein": intra_protein_total,
+ "protein_protein": protein_protein_total,
+ }
+
+ return mae_pde_dict, total_pde_dict
+
+
+def compute_pae_mae(
+ pred_atom_coords,
+ feats,
+ true_atom_coords,
+ pred_pae,
+ true_coords_resolved_mask,
+ multiplicity=1,
+):
+ """Compute the pae mean absolute error.
+
+ Parameters
+ ----------
+ pred_atom_coords : torch.Tensor
+ Predicted atom coordinates
+ feats : torch.Tensor
+ Input features
+ true_atom_coords : torch.Tensor
+ Ground truth atom coordinates
+ pred_pae : torch.Tensor
+ Predicted pae
+ true_coords_resolved_mask : torch.Tensor
+ Resolved atom mask
+ multiplicity : int
+ Diffusion batch size, by default 1
+
+ Returns
+ -------
+ Tensor
+ The mae for each modality
+ Tensor
+ The total number of pairs for each modality
+
+ """
+ # Retrieve frames and resolved masks
+ frames_idx_original = feats["frames_idx"]
+ mask_frame_true = feats["frame_resolved_mask"]
+
+ # Adjust the frames for nonpolymers after symmetry correction!
+ # NOTE: frames of polymers do not change under symmetry!
+ frames_idx_true, mask_collinear_true = compute_frame_pred(
+ true_atom_coords,
+ frames_idx_original,
+ feats,
+ multiplicity,
+ resolved_mask=true_coords_resolved_mask,
+ )
+
+ frame_true_atom_a, frame_true_atom_b, frame_true_atom_c = (
+ frames_idx_true[:, :, :, 0],
+ frames_idx_true[:, :, :, 1],
+ frames_idx_true[:, :, :, 2],
+ )
+ # Compute token coords in true frames
+ B, N, _ = true_atom_coords.shape
+ true_atom_coords = true_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
+ true_coords_transformed = express_coordinate_in_frame(
+ true_atom_coords, frame_true_atom_a, frame_true_atom_b, frame_true_atom_c
+ )
+
+ # Compute pred frames and mask
+ frames_idx_pred, mask_collinear_pred = compute_frame_pred(
+ pred_atom_coords, frames_idx_original, feats, multiplicity
+ )
+ frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c = (
+ frames_idx_pred[:, :, :, 0],
+ frames_idx_pred[:, :, :, 1],
+ frames_idx_pred[:, :, :, 2],
+ )
+ # Compute token coords in pred frames
+ B, N, _ = pred_atom_coords.shape
+ pred_atom_coords = pred_atom_coords.reshape(B // multiplicity, multiplicity, -1, 3)
+ pred_coords_transformed = express_coordinate_in_frame(
+ pred_atom_coords, frame_pred_atom_a, frame_pred_atom_b, frame_pred_atom_c
+ )
+
+ target_pae_continuous = torch.sqrt(
+ ((true_coords_transformed - pred_coords_transformed) ** 2).sum(-1) + 1e-8
+ )
+ target_pae = (
+ torch.clamp(torch.floor(target_pae_continuous * 64 / 32).long(), max=63).float()
+ * 0.5
+ + 0.25
+ )
+
+ # Compute mask for the pae loss
+ b_true_resolved_mask = true_coords_resolved_mask[
+ torch.arange(B // multiplicity)[:, None, None].to(
+ pred_coords_transformed.device
+ ),
+ frame_true_atom_b,
+ ]
+
+ pair_mask = (
+ mask_frame_true[:, None, :, None] # if true frame is invalid
+ * mask_collinear_true[:, :, :, None] # if true frame is invalid
+ * mask_collinear_pred[:, :, :, None] # if pred frame is invalid
+ * b_true_resolved_mask[:, :, None, :] # If atom j is not resolved
+ * feats["token_pad_mask"][:, None, :, None]
+ * feats["token_pad_mask"][:, None, None, :]
+ )
+
+ token_type = feats["mol_type"]
+ token_type = token_type.repeat_interleave(multiplicity, 0)
+
+ protein_mask = (token_type == const.chain_type_ids["PROTEIN"]).float()
+ ligand_mask = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+ dna_mask = (token_type == const.chain_type_ids["DNA"]).float()
+ rna_mask = (token_type == const.chain_type_ids["RNA"]).float()
+
+ # compute different paes
+ dna_protein_mask = pair_mask * (
+ dna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_protein_mae = torch.sum(torch.abs(target_pae - pred_pae) * dna_protein_mask) / (
+ torch.sum(dna_protein_mask) + 1e-5
+ )
+ dna_protein_total = torch.sum(dna_protein_mask)
+
+ rna_protein_mask = pair_mask * (
+ rna_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_protein_mae = torch.sum(torch.abs(target_pae - pred_pae) * rna_protein_mask) / (
+ torch.sum(rna_protein_mask) + 1e-5
+ )
+ rna_protein_total = torch.sum(rna_protein_mask)
+
+ ligand_protein_mask = pair_mask * (
+ ligand_mask[:, :, None] * protein_mask[:, None, :]
+ + protein_mask[:, :, None] * ligand_mask[:, None, :]
+ )
+ ligand_protein_mae = torch.sum(
+ torch.abs(target_pae - pred_pae) * ligand_protein_mask
+ ) / (torch.sum(ligand_protein_mask) + 1e-5)
+ ligand_protein_total = torch.sum(ligand_protein_mask)
+
+ dna_ligand_mask = pair_mask * (
+ dna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * dna_mask[:, None, :]
+ )
+ dna_ligand_mae = torch.sum(torch.abs(target_pae - pred_pae) * dna_ligand_mask) / (
+ torch.sum(dna_ligand_mask) + 1e-5
+ )
+ dna_ligand_total = torch.sum(dna_ligand_mask)
+
+ rna_ligand_mask = pair_mask * (
+ rna_mask[:, :, None] * ligand_mask[:, None, :]
+ + ligand_mask[:, :, None] * rna_mask[:, None, :]
+ )
+ rna_ligand_mae = torch.sum(torch.abs(target_pae - pred_pae) * rna_ligand_mask) / (
+ torch.sum(rna_ligand_mask) + 1e-5
+ )
+ rna_ligand_total = torch.sum(rna_ligand_mask)
+
+ intra_ligand_mask = pair_mask * (ligand_mask[:, :, None] * ligand_mask[:, None, :])
+ intra_ligand_mae = torch.sum(
+ torch.abs(target_pae - pred_pae) * intra_ligand_mask
+ ) / (torch.sum(intra_ligand_mask) + 1e-5)
+ intra_ligand_total = torch.sum(intra_ligand_mask)
+
+ intra_dna_mask = pair_mask * (dna_mask[:, :, None] * dna_mask[:, None, :])
+ intra_dna_mae = torch.sum(torch.abs(target_pae - pred_pae) * intra_dna_mask) / (
+ torch.sum(intra_dna_mask) + 1e-5
+ )
+ intra_dna_total = torch.sum(intra_dna_mask)
+
+ intra_rna_mask = pair_mask * (rna_mask[:, :, None] * rna_mask[:, None, :])
+ intra_rna_mae = torch.sum(torch.abs(target_pae - pred_pae) * intra_rna_mask) / (
+ torch.sum(intra_rna_mask) + 1e-5
+ )
+ intra_rna_total = torch.sum(intra_rna_mask)
+
+ chain_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
+ same_chain_mask = (chain_id[:, :, None] == chain_id[:, None, :]).float()
+
+ intra_protein_mask = (
+ pair_mask
+ * same_chain_mask
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ intra_protein_mae = torch.sum(
+ torch.abs(target_pae - pred_pae) * intra_protein_mask
+ ) / (torch.sum(intra_protein_mask) + 1e-5)
+ intra_protein_total = torch.sum(intra_protein_mask)
+
+ protein_protein_mask = (
+ pair_mask
+ * (1 - same_chain_mask)
+ * (protein_mask[:, :, None] * protein_mask[:, None, :])
+ )
+ protein_protein_mae = torch.sum(
+ torch.abs(target_pae - pred_pae) * protein_protein_mask
+ ) / (torch.sum(protein_protein_mask) + 1e-5)
+ protein_protein_total = torch.sum(protein_protein_mask)
+
+ mae_pae_dict = {
+ "dna_protein": dna_protein_mae,
+ "rna_protein": rna_protein_mae,
+ "ligand_protein": ligand_protein_mae,
+ "dna_ligand": dna_ligand_mae,
+ "rna_ligand": rna_ligand_mae,
+ "intra_ligand": intra_ligand_mae,
+ "intra_dna": intra_dna_mae,
+ "intra_rna": intra_rna_mae,
+ "intra_protein": intra_protein_mae,
+ "protein_protein": protein_protein_mae,
+ }
+ total_pae_dict = {
+ "dna_protein": dna_protein_total,
+ "rna_protein": rna_protein_total,
+ "ligand_protein": ligand_protein_total,
+ "dna_ligand": dna_ligand_total,
+ "rna_ligand": rna_ligand_total,
+ "intra_ligand": intra_ligand_total,
+ "intra_dna": intra_dna_total,
+ "intra_rna": intra_rna_total,
+ "intra_protein": intra_protein_total,
+ "protein_protein": protein_protein_total,
+ }
+
+ return mae_pae_dict, total_pae_dict
+
+
+def weighted_minimum_rmsd(
+ pred_atom_coords,
+ feats,
+ multiplicity=1,
+ nucleotide_weight=5.0,
+ ligand_weight=10.0,
+):
+ """Compute rmsd of the aligned atom coordinates.
+
+ Parameters
+ ----------
+ pred_atom_coords : torch.Tensor
+ Predicted atom coordinates
+ feats : torch.Tensor
+ Input features
+ multiplicity : int
+ Diffusion batch size, by default 1
+
+ Returns
+ -------
+ Tensor
+ The rmsds
+ Tensor
+ The best rmsd
+
+ """
+ atom_coords = feats["coords"]
+ atom_coords = atom_coords.repeat_interleave(multiplicity, 0)
+ atom_coords = atom_coords[:, 0]
+
+ atom_mask = feats["atom_resolved_mask"]
+ atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+ align_weights = atom_coords.new_ones(atom_coords.shape[:2])
+ atom_type = (
+ torch.bmm(
+ feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float()
+ )
+ .squeeze(-1)
+ .long()
+ )
+ atom_type = atom_type.repeat_interleave(multiplicity, 0)
+
+ align_weights = align_weights * (
+ 1
+ + nucleotide_weight
+ * (
+ torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
+ + torch.eq(atom_type, const.chain_type_ids["RNA"]).float()
+ )
+ + ligand_weight
+ * torch.eq(atom_type, const.chain_type_ids["NONPOLYMER"]).float()
+ )
+
+ with torch.no_grad():
+ atom_coords_aligned_ground_truth = weighted_rigid_align(
+ atom_coords, pred_atom_coords, align_weights, mask=atom_mask
+ )
+
+ # weighted MSE loss of denoised atom positions
+ mse_loss = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1)
+ rmsd = torch.sqrt(
+ torch.sum(mse_loss * align_weights * atom_mask, dim=-1)
+ / torch.sum(align_weights * atom_mask, dim=-1)
+ )
+ best_rmsd = torch.min(rmsd.reshape(-1, multiplicity), dim=1).values
+
+ return rmsd, best_rmsd
+
+
+def weighted_minimum_rmsd_single(
+ pred_atom_coords,
+ atom_coords,
+ atom_mask,
+ atom_to_token,
+ mol_type,
+ nucleotide_weight=5.0,
+ ligand_weight=10.0,
+):
+ """Compute rmsd of the aligned atom coordinates.
+
+ Parameters
+ ----------
+ pred_atom_coords : torch.Tensor
+ Predicted atom coordinates
+ atom_coords: torch.Tensor
+ Ground truth atom coordinates
+ atom_mask : torch.Tensor
+ Resolved atom mask
+ atom_to_token : torch.Tensor
+ Atom to token mapping
+ mol_type : torch.Tensor
+ Atom type
+
+ Returns
+ -------
+ Tensor
+ The rmsd
+ Tensor
+ The aligned coordinates
+ Tensor
+ The aligned weights
+
+ """
+ align_weights = atom_coords.new_ones(atom_coords.shape[:2])
+ atom_type = (
+ torch.bmm(atom_to_token.float(), mol_type.unsqueeze(-1).float())
+ .squeeze(-1)
+ .long()
+ )
+
+ align_weights = align_weights * (
+ 1
+ + nucleotide_weight
+ * (
+ torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
+ + torch.eq(atom_type, const.chain_type_ids["RNA"]).float()
+ )
+ + ligand_weight
+ * torch.eq(atom_type, const.chain_type_ids["NONPOLYMER"]).float()
+ )
+
+ with torch.no_grad():
+ atom_coords_aligned_ground_truth = weighted_rigid_align(
+ atom_coords, pred_atom_coords, align_weights, mask=atom_mask
+ )
+
+ # weighted MSE loss of denoised atom positions
+ mse_loss = ((pred_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(dim=-1)
+ rmsd = torch.sqrt(
+ torch.sum(mse_loss * align_weights * atom_mask, dim=-1)
+ / torch.sum(align_weights * atom_mask, dim=-1)
+ )
+ return rmsd, atom_coords_aligned_ground_truth, align_weights
diff --git a/src/boltz/model/model.py b/src/boltz/model/model.py
new file mode 100644
index 0000000..b140100
--- /dev/null
+++ b/src/boltz/model/model.py
@@ -0,0 +1,1221 @@
+import gc
+import random
+from typing import Any, Optional, Dict
+
+import torch
+import torch._dynamo
+from pytorch_lightning import LightningModule
+from torch import Tensor, nn
+from torchmetrics import MeanMetric
+
+import boltz.model.layers.initialize as init
+from boltz.data import const
+from boltz.data.feature.symmetry import (
+ minimum_lddt_symmetry_coords,
+ minimum_symmetry_coords,
+)
+from boltz.model.loss.confidence import confidence_loss
+from boltz.model.loss.distogram import distogram_loss
+from boltz.model.loss.validation import (
+ compute_pae_mae,
+ compute_pde_mae,
+ compute_plddt_mae,
+ factored_lddt_loss,
+ factored_token_lddt_dist_loss,
+ weighted_minimum_rmsd,
+)
+from boltz.model.modules.confidence import ConfidenceModule
+from boltz.model.modules.diffusion import AtomDiffusion
+from boltz.model.modules.encoders import RelativePositionEncoder
+from boltz.model.modules.trunk import (
+ DistogramModule,
+ InputEmbedder,
+ MSAModule,
+ PairformerModule,
+)
+from boltz.model.modules.utils import ExponentialMovingAverage
+from boltz.model.optim.scheduler import AlphaFoldLRScheduler
+
+
+class Boltz1(LightningModule):
+ def __init__( # noqa: PLR0915, C901, PLR0912
+ self,
+ atom_s: int,
+ atom_z: int,
+ token_s: int,
+ token_z: int,
+ num_bins: int,
+ training_args: dict[str, Any],
+ validation_args: dict[str, Any],
+ embedder_args: dict[str, Any],
+ msa_args: dict[str, Any],
+ pairformer_args: dict[str, Any],
+ score_model_args: dict[str, Any],
+ diffusion_process_args: dict[str, Any],
+ diffusion_loss_args: dict[str, Any],
+ confidence_model_args: dict[str, Any],
+ atom_feature_dim: int = 128,
+ confidence_prediction: bool = False,
+ confidence_imitate_trunk: bool = False,
+ alpha_pae: float = 0.0,
+ structure_prediction_training: bool = True,
+ atoms_per_window_queries: int = 32,
+ atoms_per_window_keys: int = 128,
+ compile_pairformer: bool = False,
+ compile_structure: bool = False,
+ compile_confidence: bool = False,
+ nucleotide_rmsd_weight: float = 5.0,
+ ligand_rmsd_weight: float = 10.0,
+ no_msa: bool = False,
+ no_atom_encoder: bool = False,
+ ema: bool = False,
+ ema_decay: float = 0.999,
+ min_dist: float = 2.0,
+ max_dist: float = 22.0,
+ predict_args: Optional[dict[str, Any]] = None,
+ ) -> None:
+ super().__init__()
+
+ self.save_hyperparameters()
+
+ self.lddt = nn.ModuleDict()
+ self.disto_lddt = nn.ModuleDict()
+ self.complex_lddt = nn.ModuleDict()
+ if confidence_prediction:
+ self.top1_lddt = nn.ModuleDict()
+ self.iplddt_top1_lddt = nn.ModuleDict()
+ self.ipde_top1_lddt = nn.ModuleDict()
+ self.pde_top1_lddt = nn.ModuleDict()
+ self.ptm_top1_lddt = nn.ModuleDict()
+ self.iptm_top1_lddt = nn.ModuleDict()
+ self.ligand_iptm_top1_lddt = nn.ModuleDict()
+ self.protein_iptm_top1_lddt = nn.ModuleDict()
+ self.avg_lddt = nn.ModuleDict()
+ self.plddt_mae = nn.ModuleDict()
+ self.pde_mae = nn.ModuleDict()
+ self.pae_mae = nn.ModuleDict()
+ for m in const.out_types + ["pocket_ligand_protein"]:
+ self.lddt[m] = MeanMetric()
+ self.disto_lddt[m] = MeanMetric()
+ self.complex_lddt[m] = MeanMetric()
+ if confidence_prediction:
+ self.top1_lddt[m] = MeanMetric()
+ self.iplddt_top1_lddt[m] = MeanMetric()
+ self.ipde_top1_lddt[m] = MeanMetric()
+ self.pde_top1_lddt[m] = MeanMetric()
+ self.ptm_top1_lddt[m] = MeanMetric()
+ self.iptm_top1_lddt[m] = MeanMetric()
+ self.ligand_iptm_top1_lddt[m] = MeanMetric()
+ self.protein_iptm_top1_lddt[m] = MeanMetric()
+ self.avg_lddt[m] = MeanMetric()
+ self.pde_mae[m] = MeanMetric()
+ self.pae_mae[m] = MeanMetric()
+ for m in const.out_single_types:
+ if confidence_prediction:
+ self.plddt_mae[m] = MeanMetric()
+ self.rmsd = MeanMetric()
+ self.best_rmsd = MeanMetric()
+
+ self.train_confidence_loss_logger = MeanMetric()
+ self.train_confidence_loss_dict_logger = nn.ModuleDict()
+ for m in [
+ "plddt_loss",
+ "resolved_loss",
+ "pde_loss",
+ "pae_loss",
+ "rel_plddt_loss",
+ "rel_pde_loss",
+ "rel_pae_loss",
+ ]:
+ self.train_confidence_loss_dict_logger[m] = MeanMetric()
+
+ self.ema = None
+ self.use_ema = ema
+ self.ema_decay = ema_decay
+
+ self.training_args = training_args
+ self.validation_args = validation_args
+ self.diffusion_loss_args = diffusion_loss_args
+ self.predict_args = predict_args
+
+ self.nucleotide_rmsd_weight = nucleotide_rmsd_weight
+ self.ligand_rmsd_weight = ligand_rmsd_weight
+
+ self.num_bins = num_bins
+ self.min_dist = min_dist
+ self.max_dist = max_dist
+ self.is_pairformer_compiled = False
+
+ # Input projections
+ s_input_dim = (
+ token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+ )
+ self.s_init = nn.Linear(s_input_dim, token_s, bias=False)
+ self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
+ self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)
+
+ # Input embeddings
+ full_embedder_args = {
+ "atom_s": atom_s,
+ "atom_z": atom_z,
+ "token_s": token_s,
+ "token_z": token_z,
+ "atoms_per_window_queries": atoms_per_window_queries,
+ "atoms_per_window_keys": atoms_per_window_keys,
+ "atom_feature_dim": atom_feature_dim,
+ "no_atom_encoder": no_atom_encoder,
+ **embedder_args,
+ }
+ self.input_embedder = InputEmbedder(**full_embedder_args)
+ self.rel_pos = RelativePositionEncoder(token_z)
+ self.token_bonds = nn.Linear(1, token_z, bias=False)
+
+ # Normalization layers
+ self.s_norm = nn.LayerNorm(token_s)
+ self.z_norm = nn.LayerNorm(token_z)
+
+ # Recycling projections
+ self.s_recycle = nn.Linear(token_s, token_s, bias=False)
+ self.z_recycle = nn.Linear(token_z, token_z, bias=False)
+ init.gating_init_(self.s_recycle.weight)
+ init.gating_init_(self.z_recycle.weight)
+
+ # Pairwise stack
+ self.no_msa = no_msa
+ if not no_msa:
+ self.msa_module = MSAModule(
+ token_z=token_z,
+ s_input_dim=s_input_dim,
+ **msa_args,
+ )
+ self.pairformer_module = PairformerModule(token_s, token_z, **pairformer_args)
+ if compile_pairformer:
+ # Big models hit the default cache limit (8)
+ self.is_pairformer_compiled = True
+ torch._dynamo.config.cache_size_limit = 512
+ torch._dynamo.config.accumulated_cache_size_limit = 512
+ self.pairformer_module = torch.compile(
+ self.pairformer_module,
+ dynamic=False,
+ fullgraph=False,
+ )
+
+ # Output modules
+ self.structure_module = AtomDiffusion(
+ score_model_args={
+ "token_z": token_z,
+ "token_s": token_s,
+ "atom_z": atom_z,
+ "atom_s": atom_s,
+ "atoms_per_window_queries": atoms_per_window_queries,
+ "atoms_per_window_keys": atoms_per_window_keys,
+ "atom_feature_dim": atom_feature_dim,
+ **score_model_args,
+ },
+ compile_score=compile_structure,
+ accumulate_token_repr="use_s_diffusion" in confidence_model_args
+ and confidence_model_args["use_s_diffusion"],
+ **diffusion_process_args,
+ )
+ self.distogram_module = DistogramModule(token_z, num_bins)
+ self.confidence_prediction = confidence_prediction
+ self.alpha_pae = alpha_pae
+
+ self.structure_prediction_training = structure_prediction_training
+ self.confidence_imitate_trunk = confidence_imitate_trunk
+ if self.confidence_prediction:
+ if self.confidence_imitate_trunk:
+ self.confidence_module = ConfidenceModule(
+ token_s,
+ token_z,
+ confidence_prediction=confidence_prediction,
+ compute_pae=alpha_pae > 0,
+ imitate_trunk=True,
+ pairformer_args=pairformer_args,
+ full_embedder_args=full_embedder_args,
+ msa_args=msa_args,
+ **confidence_model_args,
+ )
+ else:
+ self.confidence_module = ConfidenceModule(
+ token_s,
+ token_z,
+ confidence_prediction=confidence_prediction,
+ compute_pae=alpha_pae > 0,
+ **confidence_model_args,
+ )
+ if compile_confidence:
+ self.confidence_module = torch.compile(
+ self.confidence_module, dynamic=False, fullgraph=False
+ )
+
+ # Remove grad from weights they are not trained for ddp
+ if not structure_prediction_training:
+ for name, param in self.named_parameters():
+ if name.split(".")[0] != "confidence_module":
+ param.requires_grad = False
+
+ def forward(
+ self,
+ feats: dict[str, Tensor],
+ recycling_steps: int = 0,
+ num_sampling_steps: Optional[int] = None,
+ multiplicity_diffusion_train: int = 1,
+ diffusion_samples: int = 1,
+ ) -> dict[str, Tensor]:
+ dict_out = {}
+
+ # Compute input embeddings
+ with torch.set_grad_enabled(
+ self.training and self.structure_prediction_training
+ ):
+ s_inputs = self.input_embedder(feats)
+
+ # Initialize the sequence and pairwise embeddings
+ s_init = self.s_init(s_inputs)
+ z_init = (
+ self.z_init_1(s_inputs)[:, :, None]
+ + self.z_init_2(s_inputs)[:, None, :]
+ )
+ relative_position_encoding = self.rel_pos(feats)
+ z_init = z_init + relative_position_encoding
+ z_init = z_init + self.token_bonds(feats["token_bonds"].float())
+
+ # Perform rounds of the pairwise stack
+ s = torch.zeros_like(s_init)
+ z = torch.zeros_like(z_init)
+
+ # Compute pairwise mask
+ mask = feats["token_pad_mask"].float()
+ pair_mask = mask[:, :, None] * mask[:, None, :]
+
+ for i in range(recycling_steps + 1):
+ with torch.set_grad_enabled(self.training and (i == recycling_steps)):
+ # Fixes an issue with unused parameters in autocast
+ if (
+ self.training
+ and (i == recycling_steps)
+ and torch.is_autocast_enabled()
+ ):
+ torch.clear_autocast_cache()
+
+ # Apply recycling
+ s = s_init + self.s_recycle(self.s_norm(s))
+ z = z_init + self.z_recycle(self.z_norm(z))
+
+ # Compute pairwise stack
+ if not self.no_msa:
+ z = z + self.msa_module(z, s_inputs, feats)
+
+ # Revert to uncompiled version for validation
+ if self.is_pairformer_compiled and not self.training:
+ pairformer_module = self.pairformer_module._orig_mod # noqa: SLF001
+ else:
+ pairformer_module = self.pairformer_module
+
+ s, z = pairformer_module(s, z, mask=mask, pair_mask=pair_mask)
+
+ pdistogram = self.distogram_module(z)
+ dict_out = {"pdistogram": pdistogram}
+
+ # Compute structure module
+ if self.training and self.structure_prediction_training:
+ dict_out.update(
+ self.structure_module(
+ s_trunk=s,
+ z_trunk=z,
+ s_inputs=s_inputs,
+ feats=feats,
+ relative_position_encoding=relative_position_encoding,
+ multiplicity=multiplicity_diffusion_train,
+ )
+ )
+
+ if (not self.training) or self.confidence_prediction:
+ dict_out.update(
+ self.structure_module.sample(
+ s_trunk=s,
+ z_trunk=z,
+ s_inputs=s_inputs,
+ feats=feats,
+ relative_position_encoding=relative_position_encoding,
+ num_sampling_steps=num_sampling_steps,
+ atom_mask=feats["atom_pad_mask"],
+ multiplicity=diffusion_samples,
+ train_accumulate_token_repr=self.training,
+ )
+ )
+
+ if self.confidence_prediction:
+ dict_out.update(
+ self.confidence_module(
+ s_inputs=s_inputs.detach(),
+ s=s.detach(),
+ z=z.detach(),
+ s_diffusion=(
+ dict_out["diff_token_repr"]
+ if self.confidence_module.use_s_diffusion
+ else None
+ ),
+ x_pred=dict_out["sample_atom_coords"].detach(),
+ feats=feats,
+ pred_distogram_logits=dict_out["pdistogram"].detach(),
+ multiplicity=diffusion_samples,
+ )
+ )
+ if self.confidence_prediction and self.confidence_module.use_s_diffusion:
+ dict_out.pop("diff_token_repr", None)
+ return dict_out
+
+ def get_true_coordinates(
+ self,
+ batch,
+ out,
+ diffusion_samples,
+ symmetry_correction,
+ lddt_minimization=True,
+ ):
+ if symmetry_correction:
+ min_coords_routine = (
+ minimum_lddt_symmetry_coords
+ if lddt_minimization
+ else minimum_symmetry_coords
+ )
+ true_coords = []
+ true_coords_resolved_mask = []
+ rmsds, best_rmsds = [], []
+ for idx in range(batch["token_index"].shape[0]):
+ best_rmsd = float("inf")
+ for rep in range(diffusion_samples):
+ i = idx * diffusion_samples + rep
+ best_true_coords, rmsd, best_true_coords_resolved_mask = (
+ min_coords_routine(
+ coords=out["sample_atom_coords"][i : i + 1],
+ feats=batch,
+ index_batch=idx,
+ nucleotide_weight=self.nucleotide_rmsd_weight,
+ ligand_weight=self.ligand_rmsd_weight,
+ )
+ )
+ rmsds.append(rmsd)
+ true_coords.append(best_true_coords)
+ true_coords_resolved_mask.append(best_true_coords_resolved_mask)
+ if rmsd < best_rmsd:
+ best_rmsd = rmsd
+ best_rmsds.append(best_rmsd)
+ true_coords = torch.cat(true_coords, dim=0)
+ true_coords_resolved_mask = torch.cat(true_coords_resolved_mask, dim=0)
+ else:
+ true_coords = (
+ batch["coords"].squeeze(1).repeat_interleave(diffusion_samples, 0)
+ )
+
+ true_coords_resolved_mask = batch["atom_resolved_mask"].repeat_interleave(
+ diffusion_samples, 0
+ )
+ rmsds, best_rmsds = weighted_minimum_rmsd(
+ out["sample_atom_coords"],
+ batch,
+ multiplicity=diffusion_samples,
+ nucleotide_weight=self.nucleotide_rmsd_weight,
+ ligand_weight=self.ligand_rmsd_weight,
+ )
+
+ return true_coords, rmsds, best_rmsds, true_coords_resolved_mask
+
+ def training_step(self, batch: dict[str, Tensor], batch_idx: int) -> Tensor:
+ # Sample recycling steps
+ recycling_steps = random.randint(0, self.training_args.recycling_steps)
+
+ # Compute the forward pass
+ out = self(
+ feats=batch,
+ recycling_steps=recycling_steps,
+ num_sampling_steps=self.training_args.sampling_steps,
+ multiplicity_diffusion_train=self.training_args.diffusion_multiplicity,
+ diffusion_samples=self.training_args.diffusion_samples,
+ )
+
+ # Compute losses
+ if self.structure_prediction_training:
+ disto_loss, _ = distogram_loss(
+ out,
+ batch,
+ )
+ diffusion_loss_dict = self.structure_module.compute_loss(
+ batch,
+ out,
+ multiplicity=self.training_args.diffusion_multiplicity,
+ **self.diffusion_loss_args,
+ )
+ else:
+ disto_loss = 0.0
+ diffusion_loss_dict = {"loss": 0.0, "loss_breakdown": {}}
+
+ if self.confidence_prediction:
+ # confidence model symmetry correction
+ true_coords, _, _, true_coords_resolved_mask = self.get_true_coordinates(
+ batch,
+ out,
+ diffusion_samples=self.training_args.diffusion_samples,
+ symmetry_correction=self.training_args.symmetry_correction,
+ )
+
+ confidence_loss_dict = confidence_loss(
+ out,
+ batch,
+ true_coords,
+ true_coords_resolved_mask,
+ alpha_pae=self.alpha_pae,
+ multiplicity=self.training_args.diffusion_samples,
+ )
+ else:
+ confidence_loss_dict = {
+ "loss": torch.tensor(0.0).to(batch["token_index"].device),
+ "loss_breakdown": {},
+ }
+
+ # Aggregate losses
+ loss = (
+ self.training_args.confidence_loss_weight * confidence_loss_dict["loss"]
+ + self.training_args.diffusion_loss_weight * diffusion_loss_dict["loss"]
+ + self.training_args.distogram_loss_weight * disto_loss
+ )
+ # Log losses
+ self.log("train/distogram_loss", disto_loss)
+ self.log("train/diffusion_loss", diffusion_loss_dict["loss"])
+ for k, v in diffusion_loss_dict["loss_breakdown"].items():
+ self.log(f"train/{k}", v)
+
+ if self.confidence_prediction:
+ self.train_confidence_loss_logger.update(
+ confidence_loss_dict["loss"].detach()
+ )
+
+ for k in self.train_confidence_loss_dict_logger.keys():
+ self.train_confidence_loss_dict_logger[k].update(
+ confidence_loss_dict["loss_breakdown"][k].detach()
+ if torch.is_tensor(confidence_loss_dict["loss_breakdown"][k])
+ else confidence_loss_dict["loss_breakdown"][k]
+ )
+ self.log("train/loss", loss)
+ self.training_log()
+ return loss
+
+ def training_log(self):
+ self.log("train/grad_norm", self.gradient_norm(self), prog_bar=False)
+ self.log("train/param_norm", self.parameter_norm(self), prog_bar=False)
+
+ lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+ self.log("lr", lr, prog_bar=False)
+
+ self.log(
+ "train/grad_norm_msa_module",
+ self.gradient_norm(self.msa_module),
+ prog_bar=False,
+ )
+ self.log(
+ "train/param_norm_msa_module",
+ self.parameter_norm(self.msa_module),
+ prog_bar=False,
+ )
+
+ self.log(
+ "train/grad_norm_pairformer_module",
+ self.gradient_norm(self.pairformer_module),
+ prog_bar=False,
+ )
+ self.log(
+ "train/param_norm_pairformer_module",
+ self.parameter_norm(self.pairformer_module),
+ prog_bar=False,
+ )
+
+ self.log(
+ "train/grad_norm_structure_module",
+ self.gradient_norm(self.structure_module),
+ prog_bar=False,
+ )
+ self.log(
+ "train/param_norm_structure_module",
+ self.parameter_norm(self.structure_module),
+ prog_bar=False,
+ )
+
+ if self.confidence_prediction:
+ self.log(
+ "train/grad_norm_confidence_module",
+ self.gradient_norm(self.confidence_module),
+ prog_bar=False,
+ )
+ self.log(
+ "train/param_norm_confidence_module",
+ self.parameter_norm(self.confidence_module),
+ prog_bar=False,
+ )
+
+ def on_train_epoch_end(self):
+ self.log(
+ "train/confidence_loss",
+ self.train_confidence_loss_logger,
+ prog_bar=False,
+ on_step=False,
+ on_epoch=True,
+ )
+ for k, v in self.train_confidence_loss_dict_logger.items():
+ self.log(f"train/{k}", v, prog_bar=False, on_step=False, on_epoch=True)
+
+ def gradient_norm(self, module) -> float:
+ # Only compute over parameters that are being trained
+ parameters = filter(lambda p: p.requires_grad, module.parameters())
+ parameters = filter(lambda p: p.grad is not None, parameters)
+ norm = torch.tensor([p.grad.norm(p=2) ** 2 for p in parameters]).sum().sqrt()
+ return norm
+
+ def parameter_norm(self, module) -> float:
+ # Only compute over parameters that are being trained
+ parameters = filter(lambda p: p.requires_grad, module.parameters())
+ norm = torch.tensor([p.norm(p=2) ** 2 for p in parameters]).sum().sqrt()
+ return norm
+
+ def validation_step(self, batch: dict[str, Tensor], batch_idx: int):
+ # Compute the forward pass
+ n_samples = self.validation_args.diffusion_samples
+ try:
+ out = self(
+ batch,
+ recycling_steps=self.validation_args.recycling_steps,
+ num_sampling_steps=self.validation_args.sampling_steps,
+ diffusion_samples=n_samples,
+ )
+
+ except RuntimeError as e: # catch out of memory exceptions
+ if "out of memory" in str(e):
+ print("| WARNING: ran out of memory, skipping batch")
+ torch.cuda.empty_cache()
+ gc.collect()
+ return
+ else:
+ raise e
+
+ try:
+ # Compute distogram LDDT
+ boundaries = torch.linspace(2, 22.0, 63)
+ lower = torch.tensor([1.0])
+ upper = torch.tensor([22.0 + 5.0])
+ exp_boundaries = torch.cat((lower, boundaries, upper))
+ mid_points = ((exp_boundaries[:-1] + exp_boundaries[1:]) / 2).to(
+ out["pdistogram"]
+ )
+
+ # Compute predicted dists
+ preds = out["pdistogram"]
+ pred_softmax = torch.softmax(preds, dim=-1)
+ pred_softmax = pred_softmax.argmax(dim=-1)
+ pred_softmax = torch.nn.functional.one_hot(
+ pred_softmax, num_classes=preds.shape[-1]
+ )
+ pred_dist = (pred_softmax * mid_points).sum(dim=-1)
+ true_center = batch["disto_center"]
+ true_dists = torch.cdist(true_center, true_center)
+
+ # Compute lddt's
+ batch["token_disto_mask"] = batch["token_disto_mask"]
+ disto_lddt_dict, disto_total_dict = factored_token_lddt_dist_loss(
+ feats=batch,
+ true_d=true_dists,
+ pred_d=pred_dist,
+ )
+
+ true_coords, rmsds, best_rmsds, true_coords_resolved_mask = (
+ self.get_true_coordinates(
+ batch=batch,
+ out=out,
+ diffusion_samples=n_samples,
+ symmetry_correction=self.validation_args.symmetry_correction,
+ )
+ )
+
+ all_lddt_dict, all_total_dict = factored_lddt_loss(
+ feats=batch,
+ atom_mask=true_coords_resolved_mask,
+ true_atom_coords=true_coords,
+ pred_atom_coords=out["sample_atom_coords"],
+ multiplicity=n_samples,
+ )
+ except RuntimeError as e: # catch out of memory exceptions
+ if "out of memory" in str(e):
+ print("| WARNING: ran out of memory, skipping batch")
+ torch.cuda.empty_cache()
+ gc.collect()
+ return
+ else:
+ raise e
+ # if the multiplicity used is > 1 then we take the best lddt of the different samples
+ # AF3 combines this with the confidence based filtering
+ best_lddt_dict, best_total_dict = {}, {}
+ best_complex_lddt_dict, best_complex_total_dict = {}, {}
+ B = true_coords.shape[0] // n_samples
+ if n_samples > 1:
+ # NOTE: we can change the way we aggregate the lddt
+ complex_total = 0
+ complex_lddt = 0
+ for key in all_lddt_dict.keys():
+ complex_lddt += all_lddt_dict[key] * all_total_dict[key]
+ complex_total += all_total_dict[key]
+ complex_lddt /= complex_total + 1e-7
+ best_complex_idx = complex_lddt.reshape(-1, n_samples).argmax(dim=1)
+ for key in all_lddt_dict:
+ best_idx = all_lddt_dict[key].reshape(-1, n_samples).argmax(dim=1)
+ best_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), best_idx
+ ]
+ best_total_dict[key] = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), best_idx
+ ]
+ best_complex_lddt_dict[key] = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), best_complex_idx
+ ]
+ best_complex_total_dict[key] = all_total_dict[key].reshape(
+ -1, n_samples
+ )[torch.arange(B), best_complex_idx]
+ else:
+ best_lddt_dict = all_lddt_dict
+ best_total_dict = all_total_dict
+ best_complex_lddt_dict = all_lddt_dict
+ best_complex_total_dict = all_total_dict
+
+ # Filtering based on confidence
+ if self.confidence_prediction and n_samples > 1:
+ # note: for now we don't have pae predictions so have to use pLDDT instead of pTM
+ # also, while AF3 differentiates the best prediction per confidence type we are currently not doing it
+ # consider this in the future as well as weighing the different pLLDT types before aggregation
+ mae_plddt_dict, total_mae_plddt_dict = compute_plddt_mae(
+ pred_atom_coords=out["sample_atom_coords"],
+ feats=batch,
+ true_atom_coords=true_coords,
+ pred_lddt=out["plddt"],
+ true_coords_resolved_mask=true_coords_resolved_mask,
+ multiplicity=n_samples,
+ )
+ mae_pde_dict, total_mae_pde_dict = compute_pde_mae(
+ pred_atom_coords=out["sample_atom_coords"],
+ feats=batch,
+ true_atom_coords=true_coords,
+ pred_pde=out["pde"],
+ true_coords_resolved_mask=true_coords_resolved_mask,
+ multiplicity=n_samples,
+ )
+ mae_pae_dict, total_mae_pae_dict = compute_pae_mae(
+ pred_atom_coords=out["sample_atom_coords"],
+ feats=batch,
+ true_atom_coords=true_coords,
+ pred_pae=out["pae"],
+ true_coords_resolved_mask=true_coords_resolved_mask,
+ multiplicity=n_samples,
+ )
+
+ plddt = out["complex_plddt"].reshape(-1, n_samples)
+ top1_idx = plddt.argmax(dim=1)
+ iplddt = out["complex_iplddt"].reshape(-1, n_samples)
+ iplddt_top1_idx = iplddt.argmax(dim=1)
+ pde = out["complex_pde"].reshape(-1, n_samples)
+ pde_top1_idx = pde.argmin(dim=1)
+ ipde = out["complex_ipde"].reshape(-1, n_samples)
+ ipde_top1_idx = ipde.argmin(dim=1)
+ ptm = out["ptm"].reshape(-1, n_samples)
+ ptm_top1_idx = ptm.argmax(dim=1)
+ iptm = out["iptm"].reshape(-1, n_samples)
+ iptm_top1_idx = iptm.argmax(dim=1)
+ ligand_iptm = out["ligand_iptm"].reshape(-1, n_samples)
+ ligand_iptm_top1_idx = ligand_iptm.argmax(dim=1)
+ protein_iptm = out["protein_iptm"].reshape(-1, n_samples)
+ protein_iptm_top1_idx = protein_iptm.argmax(dim=1)
+
+ for key in all_lddt_dict:
+ top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), top1_idx
+ ]
+ top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), top1_idx
+ ]
+ iplddt_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), iplddt_top1_idx
+ ]
+ iplddt_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), iplddt_top1_idx
+ ]
+ pde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), pde_top1_idx
+ ]
+ pde_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), pde_top1_idx
+ ]
+ ipde_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), ipde_top1_idx
+ ]
+ ipde_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), ipde_top1_idx
+ ]
+ ptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), ptm_top1_idx
+ ]
+ ptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), ptm_top1_idx
+ ]
+ iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), iptm_top1_idx
+ ]
+ iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), iptm_top1_idx
+ ]
+ ligand_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), ligand_iptm_top1_idx
+ ]
+ ligand_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), ligand_iptm_top1_idx
+ ]
+ protein_iptm_top1_lddt = all_lddt_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), protein_iptm_top1_idx
+ ]
+ protein_iptm_top1_total = all_total_dict[key].reshape(-1, n_samples)[
+ torch.arange(B), protein_iptm_top1_idx
+ ]
+
+ self.top1_lddt[key].update(top1_lddt, top1_total)
+ self.iplddt_top1_lddt[key].update(iplddt_top1_lddt, iplddt_top1_total)
+ self.pde_top1_lddt[key].update(pde_top1_lddt, pde_top1_total)
+ self.ipde_top1_lddt[key].update(ipde_top1_lddt, ipde_top1_total)
+ self.ptm_top1_lddt[key].update(ptm_top1_lddt, ptm_top1_total)
+ self.iptm_top1_lddt[key].update(iptm_top1_lddt, iptm_top1_total)
+ self.ligand_iptm_top1_lddt[key].update(
+ ligand_iptm_top1_lddt, ligand_iptm_top1_total
+ )
+ self.protein_iptm_top1_lddt[key].update(
+ protein_iptm_top1_lddt, protein_iptm_top1_total
+ )
+
+ self.avg_lddt[key].update(all_lddt_dict[key], all_total_dict[key])
+ self.pde_mae[key].update(mae_pde_dict[key], total_mae_pde_dict[key])
+ self.pae_mae[key].update(mae_pae_dict[key], total_mae_pae_dict[key])
+
+ for key in mae_plddt_dict:
+ self.plddt_mae[key].update(
+ mae_plddt_dict[key], total_mae_plddt_dict[key]
+ )
+
+ for m in const.out_types:
+ if m == "ligand_protein":
+ if torch.any(
+ batch["pocket_feature"][
+ :, :, const.pocket_contact_info["POCKET"]
+ ].bool()
+ ):
+ self.lddt["pocket_ligand_protein"].update(
+ best_lddt_dict[m], best_total_dict[m]
+ )
+ self.disto_lddt["pocket_ligand_protein"].update(
+ disto_lddt_dict[m], disto_total_dict[m]
+ )
+ self.complex_lddt["pocket_ligand_protein"].update(
+ best_complex_lddt_dict[m], best_complex_total_dict[m]
+ )
+ else:
+ self.lddt["ligand_protein"].update(
+ best_lddt_dict[m], best_total_dict[m]
+ )
+ self.disto_lddt["ligand_protein"].update(
+ disto_lddt_dict[m], disto_total_dict[m]
+ )
+ self.complex_lddt["ligand_protein"].update(
+ best_complex_lddt_dict[m], best_complex_total_dict[m]
+ )
+ else:
+ self.lddt[m].update(best_lddt_dict[m], best_total_dict[m])
+ self.disto_lddt[m].update(disto_lddt_dict[m], disto_total_dict[m])
+ self.complex_lddt[m].update(
+ best_complex_lddt_dict[m], best_complex_total_dict[m]
+ )
+ self.rmsd.update(rmsds)
+ self.best_rmsd.update(best_rmsds)
+
+ def on_validation_epoch_end(self):
+ avg_lddt = {}
+ avg_disto_lddt = {}
+ avg_complex_lddt = {}
+ if self.confidence_prediction:
+ avg_top1_lddt = {}
+ avg_iplddt_top1_lddt = {}
+ avg_pde_top1_lddt = {}
+ avg_ipde_top1_lddt = {}
+ avg_ptm_top1_lddt = {}
+ avg_iptm_top1_lddt = {}
+ avg_ligand_iptm_top1_lddt = {}
+ avg_protein_iptm_top1_lddt = {}
+
+ avg_avg_lddt = {}
+ avg_mae_plddt = {}
+ avg_mae_pde = {}
+ avg_mae_pae = {}
+
+ for m in const.out_types + ["pocket_ligand_protein"]:
+ avg_lddt[m] = self.lddt[m].compute()
+ avg_lddt[m] = 0.0 if torch.isnan(avg_lddt[m]) else avg_lddt[m].item()
+ self.lddt[m].reset()
+ self.log(f"val/lddt_{m}", avg_lddt[m], prog_bar=False, sync_dist=True)
+
+ avg_disto_lddt[m] = self.disto_lddt[m].compute()
+ avg_disto_lddt[m] = (
+ 0.0 if torch.isnan(avg_disto_lddt[m]) else avg_disto_lddt[m].item()
+ )
+ self.disto_lddt[m].reset()
+ self.log(
+ f"val/disto_lddt_{m}", avg_disto_lddt[m], prog_bar=False, sync_dist=True
+ )
+ avg_complex_lddt[m] = self.complex_lddt[m].compute()
+ avg_complex_lddt[m] = (
+ 0.0 if torch.isnan(avg_complex_lddt[m]) else avg_complex_lddt[m].item()
+ )
+ self.complex_lddt[m].reset()
+ self.log(
+ f"val/complex_lddt_{m}",
+ avg_complex_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ if self.confidence_prediction:
+ avg_top1_lddt[m] = self.top1_lddt[m].compute()
+ avg_top1_lddt[m] = (
+ 0.0 if torch.isnan(avg_top1_lddt[m]) else avg_top1_lddt[m].item()
+ )
+ self.top1_lddt[m].reset()
+ self.log(
+ f"val/top1_lddt_{m}",
+ avg_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ avg_iplddt_top1_lddt[m] = self.iplddt_top1_lddt[m].compute()
+ avg_iplddt_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_iplddt_top1_lddt[m])
+ else avg_iplddt_top1_lddt[m].item()
+ )
+ self.iplddt_top1_lddt[m].reset()
+ self.log(
+ f"val/iplddt_top1_lddt_{m}",
+ avg_iplddt_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ avg_pde_top1_lddt[m] = self.pde_top1_lddt[m].compute()
+ avg_pde_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_pde_top1_lddt[m])
+ else avg_pde_top1_lddt[m].item()
+ )
+ self.pde_top1_lddt[m].reset()
+ self.log(
+ f"val/pde_top1_lddt_{m}",
+ avg_pde_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ avg_ipde_top1_lddt[m] = self.ipde_top1_lddt[m].compute()
+ avg_ipde_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_ipde_top1_lddt[m])
+ else avg_ipde_top1_lddt[m].item()
+ )
+ self.ipde_top1_lddt[m].reset()
+ self.log(
+ f"val/ipde_top1_lddt_{m}",
+ avg_ipde_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ avg_ptm_top1_lddt[m] = self.ptm_top1_lddt[m].compute()
+ avg_ptm_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_ptm_top1_lddt[m])
+ else avg_ptm_top1_lddt[m].item()
+ )
+ self.ptm_top1_lddt[m].reset()
+ self.log(
+ f"val/ptm_top1_lddt_{m}",
+ avg_ptm_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ avg_iptm_top1_lddt[m] = self.iptm_top1_lddt[m].compute()
+ avg_iptm_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_iptm_top1_lddt[m])
+ else avg_iptm_top1_lddt[m].item()
+ )
+ self.iptm_top1_lddt[m].reset()
+ self.log(
+ f"val/iptm_top1_lddt_{m}",
+ avg_iptm_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+
+ avg_ligand_iptm_top1_lddt[m] = self.ligand_iptm_top1_lddt[m].compute()
+ avg_ligand_iptm_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_ligand_iptm_top1_lddt[m])
+ else avg_ligand_iptm_top1_lddt[m].item()
+ )
+ self.ligand_iptm_top1_lddt[m].reset()
+ self.log(
+ f"val/ligand_iptm_top1_lddt_{m}",
+ avg_ligand_iptm_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+
+ avg_protein_iptm_top1_lddt[m] = self.protein_iptm_top1_lddt[m].compute()
+ avg_protein_iptm_top1_lddt[m] = (
+ 0.0
+ if torch.isnan(avg_protein_iptm_top1_lddt[m])
+ else avg_protein_iptm_top1_lddt[m].item()
+ )
+ self.protein_iptm_top1_lddt[m].reset()
+ self.log(
+ f"val/protein_iptm_top1_lddt_{m}",
+ avg_protein_iptm_top1_lddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+
+ avg_avg_lddt[m] = self.avg_lddt[m].compute()
+ avg_avg_lddt[m] = (
+ 0.0 if torch.isnan(avg_avg_lddt[m]) else avg_avg_lddt[m].item()
+ )
+ self.avg_lddt[m].reset()
+ self.log(
+ f"val/avg_lddt_{m}", avg_avg_lddt[m], prog_bar=False, sync_dist=True
+ )
+ avg_mae_pde[m] = self.pde_mae[m].compute().item()
+ self.pde_mae[m].reset()
+ self.log(
+ f"val/MAE_pde_{m}",
+ avg_mae_pde[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+ avg_mae_pae[m] = self.pae_mae[m].compute().item()
+ self.pae_mae[m].reset()
+ self.log(
+ f"val/MAE_pae_{m}",
+ avg_mae_pae[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+
+ for m in const.out_single_types:
+ if self.confidence_prediction:
+ avg_mae_plddt[m] = self.plddt_mae[m].compute().item()
+ self.plddt_mae[m].reset()
+ self.log(
+ f"val/MAE_plddt_{m}",
+ avg_mae_plddt[m],
+ prog_bar=False,
+ sync_dist=True,
+ )
+
+ overall_disto_lddt = sum(
+ avg_disto_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log("val/disto_lddt", overall_disto_lddt, prog_bar=True, sync_dist=True)
+
+ overall_lddt = sum(
+ avg_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log("val/lddt", overall_lddt, prog_bar=True, sync_dist=True)
+
+ overall_complex_lddt = sum(
+ avg_complex_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log(
+ "val/complex_lddt", overall_complex_lddt, prog_bar=True, sync_dist=True
+ )
+
+ if self.confidence_prediction:
+ overall_top1_lddt = sum(
+ avg_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log("val/top1_lddt", overall_top1_lddt, prog_bar=True, sync_dist=True)
+
+ overall_iplddt_top1_lddt = sum(
+ avg_iplddt_top1_lddt[m] * w
+ for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log(
+ "val/iplddt_top1_lddt",
+ overall_iplddt_top1_lddt,
+ prog_bar=True,
+ sync_dist=True,
+ )
+
+ overall_pde_top1_lddt = sum(
+ avg_pde_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log(
+ "val/pde_top1_lddt",
+ overall_pde_top1_lddt,
+ prog_bar=True,
+ sync_dist=True,
+ )
+
+ overall_ipde_top1_lddt = sum(
+ avg_ipde_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log(
+ "val/ipde_top1_lddt",
+ overall_ipde_top1_lddt,
+ prog_bar=True,
+ sync_dist=True,
+ )
+
+ overall_ptm_top1_lddt = sum(
+ avg_ptm_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log(
+ "val/ptm_top1_lddt",
+ overall_ptm_top1_lddt,
+ prog_bar=True,
+ sync_dist=True,
+ )
+
+ overall_iptm_top1_lddt = sum(
+ avg_iptm_top1_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log(
+ "val/iptm_top1_lddt",
+ overall_iptm_top1_lddt,
+ prog_bar=True,
+ sync_dist=True,
+ )
+
+ overall_avg_lddt = sum(
+ avg_avg_lddt[m] * w for (m, w) in const.out_types_weights.items()
+ ) / sum(const.out_types_weights.values())
+ self.log("val/avg_lddt", overall_avg_lddt, prog_bar=True, sync_dist=True)
+
+ self.log("val/rmsd", self.rmsd.compute(), prog_bar=True, sync_dist=True)
+ self.rmsd.reset()
+
+ self.log(
+ "val/best_rmsd", self.best_rmsd.compute(), prog_bar=True, sync_dist=True
+ )
+ self.best_rmsd.reset()
+
+ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
+ try:
+ out = self(
+ batch,
+ recycling_steps=self.predict_args["recycling_steps"],
+ num_sampling_steps=self.predict_args["sampling_steps"],
+ diffusion_samples=self.predict_args["diffusion_samples"],
+ )
+ pred_dict = {"exception": False}
+ pred_dict["masks"] = batch["atom_pad_mask"]
+ pred_dict["coords"] = out["sample_atom_coords"]
+ if self.confidence_prediction:
+ pred_dict["confidence"] = out["iptm"]
+
+ return pred_dict
+
+ except RuntimeError as e: # catch out of memory exceptions
+ if "out of memory" in str(e):
+ print("| WARNING: ran out of memory, skipping batch")
+ torch.cuda.empty_cache()
+ gc.collect()
+ return {"exception": True}
+ else:
+ raise {"exception": True}
+
+ def configure_optimizers(self):
+ """Configure the optimizer."""
+
+ if self.structure_prediction_training:
+ parameters = [p for p in self.parameters() if p.requires_grad]
+ else:
+ parameters = [
+ p for p in self.confidence_module.parameters() if p.requires_grad
+ ]
+
+ optimizer = torch.optim.Adam(
+ parameters,
+ betas=(self.training_args.adam_beta_1, self.training_args.adam_beta_2),
+ eps=self.training_args.adam_eps,
+ lr=self.training_args.base_lr,
+ )
+ if self.training_args.lr_scheduler == "af3":
+ scheduler = AlphaFoldLRScheduler(
+ optimizer,
+ base_lr=self.training_args.base_lr,
+ max_lr=self.training_args.max_lr,
+ warmup_no_steps=self.training_args.lr_warmup_no_steps,
+ start_decay_after_n_steps=self.training_args.lr_start_decay_after_n_steps,
+ decay_every_n_steps=self.training_args.lr_decay_every_n_steps,
+ decay_factor=self.training_args.lr_decay_factor,
+ )
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
+
+ return optimizer
+
+ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+ if self.use_ema:
+ checkpoint["ema"] = self.ema.state_dict()
+
+ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
+ if self.use_ema and self.ema is None:
+ self.ema = ExponentialMovingAverage(
+ parameters=self.parameters(), decay=self.ema_decay
+ )
+ if self.use_ema:
+ if self.ema.compatible(checkpoint["ema"]["shadow_params"]):
+ self.ema.load_state_dict(checkpoint["ema"], device=torch.device("cpu"))
+ else:
+ print("EMA not compatible with checkpoint, skipping...")
+ elif "ema" in checkpoint:
+ self.load_state_dict(checkpoint["ema"]["shadow_params"], strict=False)
+
+ def on_train_start(self):
+ if self.use_ema and self.ema is None:
+ self.ema = ExponentialMovingAverage(
+ parameters=self.parameters(), decay=self.ema_decay
+ )
+ elif self.use_ema:
+ self.ema.to(self.device)
+
+ def on_train_epoch_start(self) -> None:
+ if self.use_ema:
+ self.ema.restore(self.parameters())
+
+ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int) -> None:
+ # Updates EMA parameters after optimizer.step()
+ if self.use_ema:
+ self.ema.update(self.parameters())
+
+ def prepare_eval(self) -> None:
+ if self.use_ema and self.ema is None:
+ self.ema = ExponentialMovingAverage(
+ parameters=self.parameters(), decay=self.ema_decay
+ )
+
+ if self.use_ema:
+ self.ema.store(self.parameters())
+ self.ema.copy_to(self.parameters())
+
+ def on_validation_start(self):
+ self.prepare_eval()
+
+ def on_predict_start(self) -> None:
+ self.prepare_eval()
+
+ def on_test_start(self) -> None:
+ self.prepare_eval()
diff --git a/src/boltz/model/modules/__init__.py b/src/boltz/model/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/model/modules/confidence.py b/src/boltz/model/modules/confidence.py
new file mode 100755
index 0000000..e2f7172
--- /dev/null
+++ b/src/boltz/model/modules/confidence.py
@@ -0,0 +1,442 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from boltz.data import const
+import boltz.model.layers.initialize as init
+from boltz.model.modules.confidence_utils import (
+ compute_aggregated_metric,
+ compute_ptms,
+)
+from boltz.model.modules.encoders import RelativePositionEncoder
+from boltz.model.modules.trunk import (
+ InputEmbedder,
+ MSAModule,
+ PairformerModule,
+)
+from boltz.model.modules.utils import LinearNoBias
+
+
+class ConfidenceModule(nn.Module):
+ """Confidence module."""
+
+ def __init__(
+ self,
+ token_s,
+ token_z,
+ pairformer_args: dict,
+ num_dist_bins=64,
+ max_dist=22,
+ add_s_to_z_prod=False,
+ add_s_input_to_s=False,
+ use_s_diffusion=False,
+ add_z_input_to_z=False,
+ confidence_args: dict = None,
+ compute_pae: bool = False,
+ imitate_trunk=False,
+ full_embedder_args: dict = None,
+ msa_args: dict = None,
+ compile_pairformer=False,
+ ):
+ """Initialize the confidence module.
+
+ Parameters
+ ----------
+ token_s : int
+ The single representation dimension.
+ token_z : int
+ The pair representation dimension.
+ pairformer_args : int
+ The pairformer arguments.
+ num_dist_bins : int, optional
+ The number of distance bins, by default 64.
+ max_dist : int, optional
+ The maximum distance, by default 22.
+ add_s_to_z_prod : bool, optional
+ Whether to add s to z product, by default False.
+ add_s_input_to_s : bool, optional
+ Whether to add s input to s, by default False.
+ use_s_diffusion : bool, optional
+ Whether to use s diffusion, by default False.
+ add_z_input_to_z : bool, optional
+ Whether to add z input to z, by default False.
+ confidence_args : dict, optional
+ The confidence arguments, by default None.
+ compute_pae : bool, optional
+ Whether to compute pae, by default False.
+ imitate_trunk : bool, optional
+ Whether to imitate trunk, by default False.
+ full_embedder_args : dict, optional
+ The full embedder arguments, by default None.
+ msa_args : dict, optional
+ The msa arguments, by default None.
+ compile_pairformer : bool, optional
+ Whether to compile pairformer, by default False.
+
+ """
+
+ super().__init__()
+ self.max_num_atoms_per_token = 23
+ self.no_update_s = pairformer_args.get("no_update_s", False)
+ boundaries = torch.linspace(2, max_dist, num_dist_bins - 1)
+ self.register_buffer("boundaries", boundaries)
+ self.dist_bin_pairwise_embed = nn.Embedding(num_dist_bins, token_z)
+ init.gating_init_(self.dist_bin_pairwise_embed.weight)
+ s_input_dim = (
+ token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+ )
+
+ self.use_s_diffusion = use_s_diffusion
+ if use_s_diffusion:
+ self.s_diffusion_norm = nn.LayerNorm(2 * token_s)
+ self.s_diffusion_to_s = LinearNoBias(2 * token_s, token_s)
+ init.gating_init_(self.s_diffusion_to_s.weight)
+
+ self.s_to_z = LinearNoBias(s_input_dim, token_z)
+ self.s_to_z_transpose = LinearNoBias(s_input_dim, token_z)
+ init.gating_init_(self.s_to_z.weight)
+ init.gating_init_(self.s_to_z_transpose.weight)
+
+ self.add_s_to_z_prod = add_s_to_z_prod
+ if add_s_to_z_prod:
+ self.s_to_z_prod_in1 = LinearNoBias(s_input_dim, token_z)
+ self.s_to_z_prod_in2 = LinearNoBias(s_input_dim, token_z)
+ self.s_to_z_prod_out = LinearNoBias(token_z, token_z)
+ init.gating_init_(self.s_to_z_prod_out.weight)
+
+ self.imitate_trunk = imitate_trunk
+ if self.imitate_trunk:
+ s_input_dim = (
+ token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+ )
+ self.s_init = nn.Linear(s_input_dim, token_s, bias=False)
+ self.z_init_1 = nn.Linear(s_input_dim, token_z, bias=False)
+ self.z_init_2 = nn.Linear(s_input_dim, token_z, bias=False)
+
+ # Input embeddings
+ self.input_embedder = InputEmbedder(**full_embedder_args)
+ self.rel_pos = RelativePositionEncoder(token_z)
+ self.token_bonds = nn.Linear(1, token_z, bias=False)
+
+ # Normalization layers
+ self.s_norm = nn.LayerNorm(token_s)
+ self.z_norm = nn.LayerNorm(token_z)
+
+ # Recycling projections
+ self.s_recycle = nn.Linear(token_s, token_s, bias=False)
+ self.z_recycle = nn.Linear(token_z, token_z, bias=False)
+ init.gating_init_(self.s_recycle.weight)
+ init.gating_init_(self.z_recycle.weight)
+
+ # Pairwise stack
+ self.msa_module = MSAModule(
+ token_z=token_z,
+ s_input_dim=s_input_dim,
+ **msa_args,
+ )
+ self.pairformer_module = PairformerModule(
+ token_s,
+ token_z,
+ **pairformer_args,
+ )
+ if compile_pairformer:
+ # Big models hit the default cache limit (8)
+ self.is_pairformer_compiled = True
+ torch._dynamo.config.cache_size_limit = 512
+ torch._dynamo.config.accumulated_cache_size_limit = 512
+ self.pairformer_module = torch.compile(
+ self.pairformer_module,
+ dynamic=False,
+ fullgraph=False,
+ )
+
+ self.final_s_norm = nn.LayerNorm(token_s)
+ self.final_z_norm = nn.LayerNorm(token_z)
+ else:
+ self.s_inputs_norm = nn.LayerNorm(s_input_dim)
+ if not self.no_update_s:
+ self.s_norm = nn.LayerNorm(token_s)
+ self.z_norm = nn.LayerNorm(token_z)
+
+ self.add_s_input_to_s = add_s_input_to_s
+ if add_s_input_to_s:
+ self.s_input_to_s = LinearNoBias(s_input_dim, token_s)
+ init.gating_init_(self.s_input_to_s.weight)
+
+ self.add_z_input_to_z = add_z_input_to_z
+ if add_z_input_to_z:
+ self.rel_pos = RelativePositionEncoder(token_z)
+ self.token_bonds = nn.Linear(1, token_z, bias=False)
+
+ self.pairformer_stack = PairformerModule(
+ token_s,
+ token_z,
+ **pairformer_args,
+ )
+
+ self.confidence_heads = ConfidenceHeads(
+ token_s,
+ token_z,
+ compute_pae=compute_pae,
+ **confidence_args,
+ )
+
+ def forward(
+ self,
+ s_inputs,
+ s,
+ z,
+ x_pred,
+ feats,
+ pred_distogram_logits,
+ multiplicity=1,
+ s_diffusion=None,
+ ):
+ if self.imitate_trunk:
+ s_inputs = self.input_embedder(feats)
+
+ # Initialize the sequence and pairwise embeddings
+ s_init = self.s_init(s_inputs)
+ z_init = (
+ self.z_init_1(s_inputs)[:, :, None]
+ + self.z_init_2(s_inputs)[:, None, :]
+ )
+ relative_position_encoding = self.rel_pos(feats)
+ z_init = z_init + relative_position_encoding
+ z_init = z_init + self.token_bonds(feats["token_bonds"].float())
+
+ # Apply recycling
+ s = s_init + self.s_recycle(self.s_norm(s))
+ z = z_init + self.z_recycle(self.z_norm(z))
+
+ else:
+ s_inputs = self.s_inputs_norm(s_inputs).repeat_interleave(multiplicity, 0)
+ if not self.no_update_s:
+ s = self.s_norm(s)
+
+ if self.add_s_input_to_s:
+ s = s + self.s_input_to_s(s_inputs)
+
+ z = self.z_norm(z)
+
+ if self.add_z_input_to_z:
+ relative_position_encoding = self.rel_pos(feats)
+ z = z + relative_position_encoding
+ z = z + self.token_bonds(feats["token_bonds"].float())
+
+ s = s.repeat_interleave(multiplicity, 0)
+
+ if self.use_s_diffusion:
+ assert s_diffusion is not None
+ s_diffusion = self.s_diffusion_norm(s_diffusion)
+ s = s + self.s_diffusion_to_s(s_diffusion)
+
+ z = z.repeat_interleave(multiplicity, 0)
+ z = (
+ z
+ + self.s_to_z(s_inputs)[:, :, None, :]
+ + self.s_to_z_transpose(s_inputs)[:, None, :, :]
+ )
+
+ if self.add_s_to_z_prod:
+ z = z + self.s_to_z_prod_out(
+ self.s_to_z_prod_in1(s_inputs)[:, :, None, :]
+ * self.s_to_z_prod_in2(s_inputs)[:, None, :, :]
+ )
+
+ token_to_rep_atom = feats["token_to_rep_atom"]
+ token_to_rep_atom = token_to_rep_atom.repeat_interleave(multiplicity, 0)
+ if len(x_pred.shape) == 4:
+ B, mult, N, _ = x_pred.shape
+ x_pred = x_pred.reshape(B * mult, N, -1)
+ x_pred_repr = torch.bmm(token_to_rep_atom.float(), x_pred)
+ d = torch.cdist(x_pred_repr, x_pred_repr)
+
+ distogram = (d.unsqueeze(-1) > self.boundaries).sum(dim=-1).long()
+ distogram = self.dist_bin_pairwise_embed(distogram)
+
+ z = z + distogram
+
+ mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+ pair_mask = mask[:, :, None] * mask[:, None, :]
+
+ if self.imitate_trunk:
+ z = z + self.msa_module(z, s_inputs, feats)
+
+ s, z = self.pairformer_module(s, z, mask=mask, pair_mask=pair_mask)
+
+ s, z = self.final_s_norm(s), self.final_z_norm(z)
+
+ else:
+ s_t, z_t = self.pairformer_stack(s, z, mask=mask, pair_mask=pair_mask)
+
+ # AF3 has residual connections, we remove them
+ s = s_t
+ z = z_t
+
+ out_dict = {}
+
+ # confidence heads
+ out_dict.update(
+ self.confidence_heads(
+ s=s,
+ z=z,
+ x_pred=x_pred,
+ d=d,
+ feats=feats,
+ multiplicity=multiplicity,
+ pred_distogram_logits=pred_distogram_logits,
+ )
+ )
+
+ return out_dict
+
+
+class ConfidenceHeads(nn.Module):
+ """Confidence heads."""
+
+ def __init__(
+ self,
+ token_s,
+ token_z,
+ num_plddt_bins=50,
+ num_pde_bins=64,
+ num_pae_bins=64,
+ compute_pae: bool = False,
+ ):
+ """Initialize the confidence head.
+
+ Parameters
+ ----------
+ token_s : int
+ The single representation dimension.
+ token_z : int
+ The pair representation dimension.
+ num_plddt_bins : int
+ The number of plddt bins, by default 50.
+ num_pde_bins : int
+ The number of pde bins, by default 64.
+ num_pae_bins : int
+ The number of pae bins, by default 64.
+ compute_pae : bool
+ Whether to compute pae, by default False
+ """
+
+ super().__init__()
+ self.max_num_atoms_per_token = 23
+ self.to_pde_logits = LinearNoBias(token_z, num_pde_bins)
+ self.to_plddt_logits = LinearNoBias(token_s, num_plddt_bins)
+ self.to_resolved_logits = LinearNoBias(token_s, 2)
+ self.compute_pae = compute_pae
+ if self.compute_pae:
+ self.to_pae_logits = LinearNoBias(token_z, num_pae_bins)
+
+ def forward(
+ self,
+ s,
+ z,
+ x_pred,
+ d,
+ feats,
+ pred_distogram_logits,
+ multiplicity=1,
+ ):
+ # Compute the pLDDT, PDE, PAE, and resolved logits
+ plddt_logits = self.to_plddt_logits(s)
+ pde_logits = self.to_pde_logits(z + z.transpose(1, 2))
+ resolved_logits = self.to_resolved_logits(s)
+ if self.compute_pae:
+ pae_logits = self.to_pae_logits(z)
+
+ # Weights used to compute the interface pLDDT and PDE
+ ligand_weight = 20
+ non_interface_weight = 1
+ interface_weight = 10
+
+ # Retrieve relevant features
+ token_type = feats["mol_type"]
+ token_type = token_type.repeat_interleave(multiplicity, 0)
+ is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+
+ # Compute the aggregated pLDDT and iPLDDT
+ plddt = compute_aggregated_metric(plddt_logits)
+ token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+ complex_plddt = (plddt * token_pad_mask).sum(dim=-1) / token_pad_mask.sum(
+ dim=-1
+ )
+
+ is_contact = (d < 8).float()
+ is_different_chain = (
+ feats["asym_id"].unsqueeze(-1) != feats["asym_id"].unsqueeze(-2)
+ ).float()
+ is_different_chain = is_different_chain.repeat_interleave(multiplicity, 0)
+ token_interface_mask = torch.max(
+ is_contact * is_different_chain * (1 - is_ligand_token).unsqueeze(-1),
+ dim=-1,
+ ).values
+ token_non_interface_mask = (1 - token_interface_mask) * (1 - is_ligand_token)
+ iplddt_weight = (
+ is_ligand_token * ligand_weight
+ + token_interface_mask * interface_weight
+ + token_non_interface_mask * non_interface_weight
+ )
+ complex_iplddt = (plddt * token_pad_mask * iplddt_weight).sum(
+ dim=-1
+ ) / torch.sum(token_pad_mask * iplddt_weight, dim=-1)
+
+ # Compute the aggregated PDE and iPDE
+ pde = compute_aggregated_metric(pde_logits, end=32)
+ pred_distogram_prob = nn.functional.softmax(
+ pred_distogram_logits, dim=-1
+ ).repeat_interleave(multiplicity, 0)
+ contacts = torch.zeros((1, 1, 1, 64), dtype=pred_distogram_prob.dtype).to(
+ pred_distogram_prob.device
+ )
+ contacts[:, :, :, :20] = 1.0
+ prob_contact = (pred_distogram_prob * contacts).sum(-1)
+ token_pad_mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+ token_pad_pair_mask = (
+ token_pad_mask.unsqueeze(-1)
+ * token_pad_mask.unsqueeze(-2)
+ * (
+ 1
+ - torch.eye(
+ token_pad_mask.shape[1], device=token_pad_mask.device
+ ).unsqueeze(0)
+ )
+ )
+ token_pair_mask = token_pad_pair_mask * prob_contact
+ complex_pde = (pde * token_pair_mask).sum(dim=(1, 2)) / token_pair_mask.sum(
+ dim=(1, 2)
+ )
+ asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
+ token_interface_pair_mask = token_pair_mask * (
+ asym_id.unsqueeze(-1) != asym_id.unsqueeze(-2)
+ )
+ complex_ipde = (pde * token_interface_pair_mask).sum(dim=(1, 2)) / (
+ token_interface_pair_mask.sum(dim=(1, 2)) + 1e-5
+ )
+
+ out_dict = dict(
+ pde_logits=pde_logits,
+ plddt_logits=plddt_logits,
+ resolved_logits=resolved_logits,
+ pde=pde,
+ plddt=plddt,
+ complex_plddt=complex_plddt,
+ complex_iplddt=complex_iplddt,
+ complex_pde=complex_pde,
+ complex_ipde=complex_ipde,
+ )
+ if self.compute_pae:
+ out_dict["pae_logits"] = pae_logits
+ out_dict["pae"] = compute_aggregated_metric(pae_logits, end=32)
+ ptm, iptm, ligand_iptm, protein_iptm = compute_ptms(
+ pae_logits, x_pred, feats, multiplicity
+ )
+ out_dict["ptm"] = ptm
+ out_dict["iptm"] = iptm
+ out_dict["ligand_iptm"] = ligand_iptm
+ out_dict["protein_iptm"] = protein_iptm
+
+ return out_dict
diff --git a/src/boltz/model/modules/confidence_utils.py b/src/boltz/model/modules/confidence_utils.py
new file mode 100644
index 0000000..7b77be2
--- /dev/null
+++ b/src/boltz/model/modules/confidence_utils.py
@@ -0,0 +1,160 @@
+import torch
+from torch import nn
+
+from boltz.data import const
+from boltz.model.loss.confidence import compute_frame_pred
+
+
+def compute_aggregated_metric(logits, end=1.0):
+ """Compute the metric from the logits.
+
+ Parameters
+ ----------
+ logits : torch.Tensor
+ The logits of the metric
+ end : float
+ Max value of the metric, by default 1.0
+
+ Returns
+ -------
+ Tensor
+ The metric value
+
+ """
+ num_bins = logits.shape[-1]
+ bin_width = end / num_bins
+ bounds = torch.arange(
+ start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
+ )
+ probs = nn.functional.softmax(logits, dim=-1)
+ plddt = torch.sum(
+ probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
+ dim=-1,
+ )
+ return plddt
+
+
+def tm_function(d, Nres):
+ """Compute the rescaling function for pTM.
+
+ Parameters
+ ----------
+ d : torch.Tensor
+ The input
+ Nres : torch.Tensor
+ The number of residues
+
+ Returns
+ -------
+ Tensor
+ Output of the function
+
+ """
+ d0 = 1.24 * (torch.clip(Nres, min=19) - 15) ** (1 / 3) - 1.8
+ return 1 / (1 + (d / d0) ** 2)
+
+
+def compute_ptms(logits, x_preds, feats, multiplicity):
+ """Compute pTM and ipTM scores.
+
+ Parameters
+ ----------
+ logits : torch.Tensor
+ pae logits
+ x_preds : torch.Tensor
+ The predicted coordinates
+ feats : Dict[str, torch.Tensor]
+ The input features
+ multiplicity : int
+ The batch size of the diffusion roll-out
+
+ Returns
+ -------
+ Tensor
+ pTM score
+ Tensor
+ ipTM score
+ Tensor
+ ligand ipTM score
+ Tensor
+ protein ipTM score
+
+ """
+ # Compute mask for collinear and overlapping tokens
+ _, mask_collinear_pred = compute_frame_pred(
+ x_preds, feats["frames_idx"], feats, multiplicity, inference=True
+ )
+ mask_pad = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+ maski = mask_collinear_pred.reshape(-1, mask_collinear_pred.shape[-1])
+ pair_mask_ptm = maski[:, :, None] * mask_pad[:, None, :] * mask_pad[:, :, None]
+ asym_id = feats["asym_id"].repeat_interleave(multiplicity, 0)
+ pair_mask_iptm = (
+ maski[:, :, None]
+ * (asym_id[:, None, :] != asym_id[:, :, None])
+ * mask_pad[:, None, :]
+ * mask_pad[:, :, None]
+ )
+
+ # Extract pae values
+ num_bins = logits.shape[-1]
+ bin_width = 32.0 / num_bins
+ end = 32.0
+ pae_value = torch.arange(
+ start=0.5 * bin_width, end=end, step=bin_width, device=logits.device
+ ).unsqueeze(0)
+ N_res = mask_pad.sum(dim=-1, keepdim=True)
+
+ # compute pTM and ipTM
+ tm_value = tm_function(pae_value, N_res).unsqueeze(1).unsqueeze(2)
+ probs = nn.functional.softmax(logits, dim=-1)
+ tm_expected_value = torch.sum(
+ probs * tm_value,
+ dim=-1,
+ ) # shape (B, N, N)
+ ptm = torch.max(
+ torch.sum(tm_expected_value * pair_mask_ptm, dim=-1)
+ / (torch.sum(pair_mask_ptm, dim=-1) + 1e-5),
+ dim=1,
+ ).values
+ iptm = torch.max(
+ torch.sum(tm_expected_value * pair_mask_iptm, dim=-1)
+ / (torch.sum(pair_mask_iptm, dim=-1) + 1e-5),
+ dim=1,
+ ).values
+
+ # compute ligand and protein ipTM
+ token_type = feats["mol_type"]
+ token_type = token_type.repeat_interleave(multiplicity, 0)
+ is_ligand_token = (token_type == const.chain_type_ids["NONPOLYMER"]).float()
+ is_protein_token = (token_type == const.chain_type_ids["PROTEIN"]).float()
+
+ ligand_iptm_mask = (
+ maski[:, :, None]
+ * (asym_id[:, None, :] != asym_id[:, :, None])
+ * mask_pad[:, None, :]
+ * mask_pad[:, :, None]
+ * (
+ (is_ligand_token[:, :, None] * is_protein_token[:, None, :])
+ + (is_protein_token[:, :, None] * is_ligand_token[:, None, :])
+ )
+ )
+ protein_ipmt_mask = (
+ maski[:, :, None]
+ * (asym_id[:, None, :] != asym_id[:, :, None])
+ * mask_pad[:, None, :]
+ * mask_pad[:, :, None]
+ * (is_protein_token[:, :, None] * is_protein_token[:, None, :])
+ )
+
+ ligand_iptm = torch.max(
+ torch.sum(tm_expected_value * ligand_iptm_mask, dim=-1)
+ / (torch.sum(ligand_iptm_mask, dim=-1) + 1e-5),
+ dim=1,
+ ).values
+ protein_iptm = torch.max(
+ torch.sum(tm_expected_value * protein_ipmt_mask, dim=-1)
+ / (torch.sum(protein_ipmt_mask, dim=-1) + 1e-5),
+ dim=1,
+ ).values
+
+ return ptm, iptm, ligand_iptm, protein_iptm
diff --git a/src/boltz/model/modules/diffusion.py b/src/boltz/model/modules/diffusion.py
new file mode 100644
index 0000000..55b7b3e
--- /dev/null
+++ b/src/boltz/model/modules/diffusion.py
@@ -0,0 +1,695 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+from __future__ import annotations
+
+from math import sqrt
+import random
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch.nn import Module
+import torch.nn.functional as F
+
+from boltz.data import const
+import boltz.model.layers.initialize as init
+from boltz.model.loss.diffusion import (
+ smooth_lddt_loss,
+ weighted_rigid_align,
+)
+from boltz.model.modules.encoders import (
+ AtomAttentionDecoder,
+ AtomAttentionEncoder,
+ FourierEmbedding,
+ PairwiseConditioning,
+ SingleConditioning,
+)
+from boltz.model.modules.transformers import (
+ ConditionedTransitionBlock,
+ DiffusionTransformer,
+)
+from boltz.model.modules.utils import (
+ LinearNoBias,
+ center_random_augmentation,
+ default,
+ log,
+)
+
+
+class DiffusionModule(Module):
+ """Diffusion module"""
+
+ def __init__(
+ self,
+ token_s: int,
+ token_z: int,
+ atom_s: int,
+ atom_z: int,
+ atoms_per_window_queries: int = 32,
+ atoms_per_window_keys: int = 128,
+ sigma_data: int = 16,
+ dim_fourier: int = 256,
+ atom_encoder_depth: int = 3,
+ atom_encoder_heads: int = 4,
+ token_transformer_depth: int = 24,
+ token_transformer_heads: int = 8,
+ atom_decoder_depth: int = 3,
+ atom_decoder_heads: int = 4,
+ atom_feature_dim: int = 128,
+ conditioning_transition_layers: int = 2,
+ activation_checkpointing: bool = False,
+ offload_to_cpu: bool = False,
+ **kwargs,
+ ) -> None:
+ """Initialize the diffusion module.
+
+ Parameters
+ ----------
+ token_s : int
+ The single representation dimension.
+ token_z : int
+ The pair representation dimension.
+ atom_s : int
+ The atom single representation dimension.
+ atom_z : int
+ The atom pair representation dimension.
+ atoms_per_window_queries : int, optional
+ The number of atoms per window for queries, by default 32.
+ atoms_per_window_keys : int, optional
+ The number of atoms per window for keys, by default 128.
+ sigma_data : int, optional
+ The standard deviation of the data distribution, by default 16.
+ dim_fourier : int, optional
+ The dimension of the fourier embedding, by default 256.
+ atom_encoder_depth : int, optional
+ The depth of the atom encoder, by default 3.
+ atom_encoder_heads : int, optional
+ The number of heads in the atom encoder, by default 4.
+ token_transformer_depth : int, optional
+ The depth of the token transformer, by default 24.
+ token_transformer_heads : int, optional
+ The number of heads in the token transformer, by default 8.
+ atom_decoder_depth : int, optional
+ The depth of the atom decoder, by default 3.
+ atom_decoder_heads : int, optional
+ The number of heads in the atom decoder, by default 4.
+ atom_feature_dim : int, optional
+ The atom feature dimension, by default 128.
+ conditioning_transition_layers : int, optional
+ The number of transition layers for conditioning, by default 2.
+ activation_checkpointing : bool, optional
+ Whether to use activation checkpointing, by default False.
+ offload_to_cpu : bool, optional
+ Whether to offload the activations to CPU, by default False.
+
+ """
+
+ super().__init__()
+
+ self.atoms_per_window_queries = atoms_per_window_queries
+ self.atoms_per_window_keys = atoms_per_window_keys
+ self.sigma_data = sigma_data
+
+ self.single_conditioner = SingleConditioning(
+ sigma_data=sigma_data,
+ token_s=token_s,
+ dim_fourier=dim_fourier,
+ num_transitions=conditioning_transition_layers,
+ )
+ self.pairwise_conditioner = PairwiseConditioning(
+ token_z=token_z,
+ dim_token_rel_pos_feats=token_z,
+ num_transitions=conditioning_transition_layers,
+ )
+
+ self.atom_attention_encoder = AtomAttentionEncoder(
+ atom_s=atom_s,
+ atom_z=atom_z,
+ token_s=token_s,
+ token_z=token_z,
+ atoms_per_window_queries=atoms_per_window_queries,
+ atoms_per_window_keys=atoms_per_window_keys,
+ atom_feature_dim=atom_feature_dim,
+ atom_encoder_depth=atom_encoder_depth,
+ atom_encoder_heads=atom_encoder_heads,
+ structure_prediction=True,
+ activation_checkpointing=activation_checkpointing,
+ )
+
+ self.s_to_a_linear = nn.Sequential(
+ nn.LayerNorm(2 * token_s), LinearNoBias(2 * token_s, 2 * token_s)
+ )
+ init.final_init_(self.s_to_a_linear[1].weight)
+
+ self.token_transformer = DiffusionTransformer(
+ dim=2 * token_s,
+ dim_single_cond=2 * token_s,
+ dim_pairwise=token_z,
+ depth=token_transformer_depth,
+ heads=token_transformer_heads,
+ activation_checkpointing=activation_checkpointing,
+ offload_to_cpu=offload_to_cpu,
+ )
+
+ self.a_norm = nn.LayerNorm(2 * token_s)
+
+ self.atom_attention_decoder = AtomAttentionDecoder(
+ atom_s=atom_s,
+ atom_z=atom_z,
+ token_s=token_s,
+ attn_window_queries=atoms_per_window_queries,
+ attn_window_keys=atoms_per_window_keys,
+ atom_decoder_depth=atom_decoder_depth,
+ atom_decoder_heads=atom_decoder_heads,
+ activation_checkpointing=activation_checkpointing,
+ )
+
+ def forward(
+ self,
+ s_inputs,
+ s_trunk,
+ z_trunk,
+ r_noisy,
+ times,
+ relative_position_encoding,
+ feats,
+ multiplicity=1,
+ model_cache=None,
+ ):
+ s, normed_fourier = self.single_conditioner(
+ times=times,
+ s_trunk=s_trunk.repeat_interleave(multiplicity, 0),
+ s_inputs=s_inputs.repeat_interleave(multiplicity, 0),
+ )
+
+ if model_cache is None or len(model_cache) == 0:
+ z = self.pairwise_conditioner(
+ z_trunk=z_trunk, token_rel_pos_feats=relative_position_encoding
+ )
+ else:
+ z = None
+
+ # Compute Atom Attention Encoder and aggregation to coarse-grained tokens
+ a, q_skip, c_skip, p_skip, to_keys = self.atom_attention_encoder(
+ feats=feats,
+ s_trunk=s_trunk,
+ z=z,
+ r=r_noisy,
+ multiplicity=multiplicity,
+ model_cache=model_cache,
+ )
+
+ # Full self-attention on token level
+ a = a + self.s_to_a_linear(s)
+
+ mask = feats["token_pad_mask"].repeat_interleave(multiplicity, 0)
+ a = self.token_transformer(
+ a,
+ mask=mask.float(),
+ s=s,
+ z=z, # note z is not expanded with multiplicity until after bias is computed
+ multiplicity=multiplicity,
+ model_cache=model_cache,
+ )
+ a = self.a_norm(a)
+
+ # Broadcast token activations to atoms and run Sequence-local Atom Attention
+ r_update = self.atom_attention_decoder(
+ a=a,
+ q=q_skip,
+ c=c_skip,
+ p=p_skip,
+ feats=feats,
+ multiplicity=multiplicity,
+ to_keys=to_keys,
+ model_cache=model_cache,
+ )
+
+ return {"r_update": r_update, "token_a": a}
+
+
+class OutTokenFeatUpdate(Module):
+ """Output token feature update"""
+
+ def __init__(
+ self,
+ sigma_data: float,
+ token_s=384,
+ dim_fourier=256,
+ ):
+ """Initialize the Output token feature update for confidence model.
+
+ Parameters
+ ----------
+ sigma_data : float
+ The standard deviation of the data distribution.
+ token_s : int, optional
+ The token dimension, by default 384.
+ dim_fourier : int, optional
+ The dimension of the fourier embedding, by default 256.
+
+ """
+
+ super().__init__()
+ self.sigma_data = sigma_data
+
+ self.norm_next = nn.LayerNorm(2 * token_s)
+ self.fourier_embed = FourierEmbedding(dim_fourier)
+ self.norm_fourier = nn.LayerNorm(dim_fourier)
+ self.transition_block = ConditionedTransitionBlock(
+ 2 * token_s, 2 * token_s + dim_fourier
+ )
+
+ def forward(
+ self,
+ times,
+ acc_a,
+ next_a,
+ ):
+ next_a = self.norm_next(next_a)
+ fourier_embed = self.fourier_embed(times)
+ normed_fourier = (
+ self.norm_fourier(fourier_embed)
+ .unsqueeze(1)
+ .expand(-1, next_a.shape[1], -1)
+ )
+ cond_a = torch.cat((acc_a, normed_fourier), dim=-1)
+
+ acc_a = acc_a + self.transition_block(next_a, cond_a)
+
+ return acc_a
+
+
+class AtomDiffusion(Module):
+ """Atom diffusion module"""
+
+ def __init__(
+ self,
+ score_model_args,
+ num_sampling_steps=5,
+ sigma_min=0.0004,
+ sigma_max=160.0,
+ sigma_data=16.0,
+ rho=7,
+ P_mean=-1.2,
+ P_std=1.5,
+ gamma_0=0.8,
+ gamma_min=1.0,
+ noise_scale=1.003,
+ step_scale=1.5,
+ coordinate_augmentation=True,
+ compile_score=False,
+ alignment_reverse_diff=False,
+ synchronize_sigmas=False,
+ use_inference_model_cache=False,
+ accumulate_token_repr=False,
+ **kwargs,
+ ):
+ """Initialize the atom diffusion module.
+
+ Parameters
+ ----------
+ score_model_args : dict
+ The arguments for the score model.
+ num_sampling_steps : int, optional
+ The number of sampling steps, by default 5.
+ sigma_min : float, optional
+ The minimum sigma value, by default 0.0004.
+ sigma_max : float, optional
+ The maximum sigma value, by default 160.0.
+ sigma_data : float, optional
+ The standard deviation of the data distribution, by default 16.0.
+ rho : int, optional
+ The rho value, by default 7.
+ P_mean : float, optional
+ The mean value of P, by default -1.2.
+ P_std : float, optional
+ The standard deviation of P, by default 1.5.
+ gamma_0 : float, optional
+ The gamma value, by default 0.8.
+ gamma_min : float, optional
+ The minimum gamma value, by default 1.0.
+ noise_scale : float, optional
+ The noise scale, by default 1.003.
+ step_scale : float, optional
+ The step scale, by default 1.5.
+ coordinate_augmentation : bool, optional
+ Whether to use coordinate augmentation, by default True.
+ compile_score : bool, optional
+ Whether to compile the score model, by default False.
+ alignment_reverse_diff : bool, optional
+ Whether to use alignment reverse diff, by default False.
+ synchronize_sigmas : bool, optional
+ Whether to synchronize the sigmas, by default False.
+ use_inference_model_cache : bool, optional
+ Whether to use the inference model cache, by default False.
+ accumulate_token_repr : bool, optional
+ Whether to accumulate the token representation, by default False.
+
+ """
+ super().__init__()
+ self.score_model = DiffusionModule(
+ **score_model_args,
+ )
+ if compile_score:
+ self.score_model = torch.compile(
+ self.score_model, dynamic=False, fullgraph=False
+ )
+
+ # parameters
+ self.sigma_min = sigma_min
+ self.sigma_max = sigma_max
+ self.sigma_data = sigma_data
+ self.rho = rho
+ self.P_mean = P_mean
+ self.P_std = P_std
+ self.num_sampling_steps = num_sampling_steps
+ self.gamma_0 = gamma_0
+ self.gamma_min = gamma_min
+ self.noise_scale = noise_scale
+ self.step_scale = step_scale
+ self.coordinate_augmentation = coordinate_augmentation
+ self.alignment_reverse_diff = alignment_reverse_diff
+ self.synchronize_sigmas = synchronize_sigmas
+ self.use_inference_model_cache = use_inference_model_cache
+
+ self.accumulate_token_repr = accumulate_token_repr
+ self.token_s = score_model_args["token_s"]
+ if self.accumulate_token_repr:
+ self.out_token_feat_update = OutTokenFeatUpdate(
+ sigma_data=sigma_data,
+ token_s=score_model_args["token_s"],
+ dim_fourier=score_model_args["dim_fourier"],
+ )
+
+ self.register_buffer("zero", torch.tensor(0.0), persistent=False)
+
+ @property
+ def device(self):
+ return next(self.score_model.parameters()).device
+
+ def c_skip(self, sigma):
+ return (self.sigma_data**2) / (sigma**2 + self.sigma_data**2)
+
+ def c_out(self, sigma):
+ return sigma * self.sigma_data / torch.sqrt(self.sigma_data**2 + sigma**2)
+
+ def c_in(self, sigma):
+ return 1 / torch.sqrt(sigma**2 + self.sigma_data**2)
+
+ def c_noise(self, sigma):
+ return log(sigma / self.sigma_data) * 0.25
+
+ def preconditioned_network_forward(
+ self,
+ noised_atom_coords,
+ sigma,
+ network_condition_kwargs: dict,
+ training: bool = True,
+ ):
+ batch, device = noised_atom_coords.shape[0], noised_atom_coords.device
+
+ if isinstance(sigma, float):
+ sigma = torch.full((batch,), sigma, device=device)
+
+ padded_sigma = rearrange(sigma, "b -> b 1 1")
+
+ net_out = self.score_model(
+ r_noisy=self.c_in(padded_sigma) * noised_atom_coords,
+ times=self.c_noise(sigma),
+ **network_condition_kwargs,
+ )
+
+ denoised_coords = (
+ self.c_skip(padded_sigma) * noised_atom_coords
+ + self.c_out(padded_sigma) * net_out["r_update"]
+ )
+ return denoised_coords, net_out["token_a"]
+
+ def sample_schedule(self, num_sampling_steps=None):
+ num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
+ inv_rho = 1 / self.rho
+
+ steps = torch.arange(
+ num_sampling_steps, device=self.device, dtype=torch.float32
+ )
+ sigmas = (
+ self.sigma_max**inv_rho
+ + steps
+ / (num_sampling_steps - 1)
+ * (self.sigma_min**inv_rho - self.sigma_max**inv_rho)
+ ) ** self.rho
+
+ sigmas = sigmas * self.sigma_data
+
+ sigmas = F.pad(sigmas, (0, 1), value=0.0) # last step is sigma value of 0.
+ return sigmas
+
+ def sample(
+ self,
+ atom_mask,
+ num_sampling_steps=None,
+ multiplicity=1,
+ train_accumulate_token_repr=False,
+ **network_condition_kwargs,
+ ):
+ num_sampling_steps = default(num_sampling_steps, self.num_sampling_steps)
+ atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+ shape = (*atom_mask.shape, 3)
+
+ # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma
+ sigmas = self.sample_schedule(num_sampling_steps)
+ gammas = torch.where(sigmas > self.gamma_min, self.gamma_0, 0.0)
+ sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[1:]))
+
+ # atom position is noise at the beginning
+ init_sigma = sigmas[0]
+ atom_coords = init_sigma * torch.randn(shape, device=self.device)
+ atom_coords_denoised = None
+ model_cache = {} if self.use_inference_model_cache else None
+
+ token_repr = None
+ token_a = None
+
+ # gradually denoise
+ for sigma_tm, sigma_t, gamma in sigmas_and_gammas:
+ atom_coords, atom_coords_denoised = center_random_augmentation(
+ atom_coords,
+ atom_mask,
+ augmentation=True,
+ return_second_coords=True,
+ second_coords=atom_coords_denoised,
+ )
+
+ sigma_tm, sigma_t, gamma = sigma_tm.item(), sigma_t.item(), gamma.item()
+
+ t_hat = sigma_tm * (1 + gamma)
+ eps = (
+ self.noise_scale
+ * sqrt(t_hat**2 - sigma_tm**2)
+ * torch.randn(shape, device=self.device)
+ )
+ atom_coords_noisy = atom_coords + eps
+
+ with torch.no_grad():
+ atom_coords_denoised, token_a = self.preconditioned_network_forward(
+ atom_coords_noisy,
+ t_hat,
+ training=False,
+ network_condition_kwargs=dict(
+ multiplicity=multiplicity,
+ model_cache=model_cache,
+ **network_condition_kwargs,
+ ),
+ )
+
+ if self.accumulate_token_repr:
+ if token_repr is None:
+ token_repr = torch.zeros_like(token_a)
+
+ with torch.set_grad_enabled(train_accumulate_token_repr):
+ sigma = torch.full(
+ (atom_coords_denoised.shape[0],),
+ t_hat,
+ device=atom_coords_denoised.device,
+ )
+ token_repr = self.out_token_feat_update(
+ times=self.c_noise(sigma), acc_a=token_repr, next_a=token_a
+ )
+
+ if self.alignment_reverse_diff:
+ with torch.autocast("cuda", enabled=False):
+ atom_coords_noisy = weighted_rigid_align(
+ atom_coords_noisy.float(),
+ atom_coords_denoised.float(),
+ atom_mask.float(),
+ atom_mask.float(),
+ )
+
+ atom_coords_noisy = atom_coords_noisy.to(atom_coords_denoised)
+
+ denoised_over_sigma = (atom_coords_noisy - atom_coords_denoised) / t_hat
+ atom_coords_next = (
+ atom_coords_noisy
+ + self.step_scale * (sigma_t - t_hat) * denoised_over_sigma
+ )
+
+ atom_coords = atom_coords_next
+
+ return dict(sample_atom_coords=atom_coords, diff_token_repr=token_repr)
+
+ def loss_weight(self, sigma):
+ return (sigma**2 + self.sigma_data**2) / ((sigma * self.sigma_data) ** 2)
+
+ def noise_distribution(self, batch_size):
+ return (
+ self.sigma_data
+ * (
+ self.P_mean
+ + self.P_std * torch.randn((batch_size,), device=self.device)
+ ).exp()
+ )
+
+ def forward(
+ self,
+ s_inputs,
+ s_trunk,
+ z_trunk,
+ relative_position_encoding,
+ feats,
+ multiplicity=1,
+ ):
+ # training diffusion step
+ batch_size = feats["coords"].shape[0]
+
+ if self.synchronize_sigmas:
+ sigmas = self.noise_distribution(batch_size).repeat_interleave(
+ multiplicity, 0
+ )
+ else:
+ sigmas = self.noise_distribution(batch_size * multiplicity)
+ padded_sigmas = rearrange(sigmas, "b -> b 1 1")
+
+ atom_coords = feats["coords"]
+ B, N, L = atom_coords.shape[0:3]
+ atom_coords = atom_coords.reshape(B * N, L, 3)
+ atom_coords = atom_coords.repeat_interleave(multiplicity // N, 0)
+ feats["coords"] = atom_coords
+
+ atom_mask = feats["atom_pad_mask"]
+ atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+ atom_coords = center_random_augmentation(
+ atom_coords, atom_mask, augmentation=self.coordinate_augmentation
+ )
+
+ noise = torch.randn_like(atom_coords)
+ noised_atom_coords = atom_coords + padded_sigmas * noise
+
+ denoised_atom_coords, _ = self.preconditioned_network_forward(
+ noised_atom_coords,
+ sigmas,
+ training=True,
+ network_condition_kwargs=dict(
+ s_inputs=s_inputs,
+ s_trunk=s_trunk,
+ z_trunk=z_trunk,
+ relative_position_encoding=relative_position_encoding,
+ feats=feats,
+ multiplicity=multiplicity,
+ ),
+ )
+
+ return dict(
+ noised_atom_coords=noised_atom_coords,
+ denoised_atom_coords=denoised_atom_coords,
+ sigmas=sigmas,
+ aligned_true_atom_coords=atom_coords,
+ )
+
+ def compute_loss(
+ self,
+ feats,
+ out_dict,
+ add_smooth_lddt_loss=True,
+ nucleotide_loss_weight=5.0,
+ ligand_loss_weight=10.0,
+ multiplicity=1,
+ ):
+ denoised_atom_coords = out_dict["denoised_atom_coords"]
+ noised_atom_coords = out_dict["noised_atom_coords"]
+ sigmas = out_dict["sigmas"]
+
+ resolved_atom_mask = feats["atom_resolved_mask"]
+ resolved_atom_mask = resolved_atom_mask.repeat_interleave(multiplicity, 0)
+
+ align_weights = noised_atom_coords.new_ones(noised_atom_coords.shape[:2])
+ atom_type = (
+ torch.bmm(
+ feats["atom_to_token"].float(), feats["mol_type"].unsqueeze(-1).float()
+ )
+ .squeeze(-1)
+ .long()
+ )
+ atom_type_mult = atom_type.repeat_interleave(multiplicity, 0)
+
+ align_weights = align_weights * (
+ 1
+ + nucleotide_loss_weight
+ * (
+ torch.eq(atom_type_mult, const.chain_type_ids["DNA"]).float()
+ + torch.eq(atom_type_mult, const.chain_type_ids["RNA"]).float()
+ )
+ + ligand_loss_weight
+ * torch.eq(atom_type_mult, const.chain_type_ids["NONPOLYMER"]).float()
+ )
+
+ with torch.no_grad(), torch.autocast("cuda", enabled=False):
+ atom_coords = out_dict["aligned_true_atom_coords"]
+ atom_coords_aligned_ground_truth = weighted_rigid_align(
+ atom_coords.detach().float(),
+ denoised_atom_coords.detach().float(),
+ align_weights.detach().float(),
+ mask=resolved_atom_mask.detach().float(),
+ )
+
+ # Cast back
+ atom_coords_aligned_ground_truth = atom_coords_aligned_ground_truth.to(
+ denoised_atom_coords
+ )
+
+ # weighted MSE loss of denoised atom positions
+ mse_loss = ((denoised_atom_coords - atom_coords_aligned_ground_truth) ** 2).sum(
+ dim=-1
+ )
+ mse_loss = torch.sum(
+ mse_loss * align_weights * resolved_atom_mask, dim=-1
+ ) / torch.sum(3 * align_weights * resolved_atom_mask, dim=-1)
+
+ # weight by sigma factor
+ loss_weights = self.loss_weight(sigmas)
+ mse_loss = (mse_loss * loss_weights).mean()
+
+ total_loss = mse_loss
+
+ # proposed auxiliary smooth lddt loss
+ lddt_loss = self.zero
+ if add_smooth_lddt_loss:
+ lddt_loss = smooth_lddt_loss(
+ denoised_atom_coords,
+ feats["coords"],
+ torch.eq(atom_type, const.chain_type_ids["DNA"]).float()
+ + torch.eq(atom_type, const.chain_type_ids["RNA"]).float(),
+ coords_mask=feats["atom_resolved_mask"],
+ multiplicity=multiplicity,
+ )
+
+ total_loss = total_loss + lddt_loss
+
+ loss_breakdown = dict(
+ mse_loss=mse_loss,
+ smooth_lddt_loss=lddt_loss,
+ )
+
+ return dict(loss=total_loss, loss_breakdown=loss_breakdown)
diff --git a/src/boltz/model/modules/encoders.py b/src/boltz/model/modules/encoders.py
new file mode 100644
index 0000000..2452fc9
--- /dev/null
+++ b/src/boltz/model/modules/encoders.py
@@ -0,0 +1,631 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+from functools import partial
+from math import pi
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch.nn import Module, ModuleList
+from torch.nn.functional import one_hot
+
+from boltz.data import const
+import boltz.model.layers.initialize as init
+from boltz.model.layers.transition import Transition
+from boltz.model.modules.transformers import AtomTransformer
+from boltz.model.modules.utils import LinearNoBias
+
+
+class FourierEmbedding(Module):
+ """Fourier embedding layer."""
+
+ def __init__(self, dim):
+ """Initialize the Fourier Embeddings.
+
+ Parameters
+ ----------
+ dim : int
+ The dimension of the embeddings.
+
+ """
+
+ super().__init__()
+ self.proj = nn.Linear(1, dim)
+ torch.nn.init.normal_(self.proj.weight, mean=0, std=1)
+ torch.nn.init.normal_(self.proj.bias, mean=0, std=1)
+ self.proj.requires_grad_(False)
+
+ def forward(
+ self,
+ times,
+ ):
+ times = rearrange(times, "b -> b 1")
+ rand_proj = self.proj(times)
+ return torch.cos(2 * pi * rand_proj)
+
+
+class RelativePositionEncoder(Module):
+ """Relative position encoder."""
+
+ def __init__(self, token_z, r_max=32, s_max=2):
+ """Initialize the relative position encoder.
+
+ Parameters
+ ----------
+ token_z : int
+ The pair representation dimension.
+ r_max : int, optional
+ The maximum index distance, by default 32.
+ s_max : int, optional
+ The maximum chain distance, by default 2.
+
+ """
+ super().__init__()
+ self.r_max = r_max
+ self.s_max = s_max
+ self.linear_layer = LinearNoBias(4 * (r_max + 1) + 2 * (s_max + 1) + 1, token_z)
+
+ def forward(self, feats):
+ b_same_chain = torch.eq(
+ feats["asym_id"][:, :, None], feats["asym_id"][:, None, :]
+ )
+ b_same_residue = torch.eq(
+ feats["residue_index"][:, :, None], feats["residue_index"][:, None, :]
+ )
+ b_same_entity = torch.eq(
+ feats["entity_id"][:, :, None], feats["entity_id"][:, None, :]
+ )
+
+ d_residue = torch.clip(
+ feats["residue_index"][:, :, None]
+ - feats["residue_index"][:, None, :]
+ + self.r_max,
+ 0,
+ 2 * self.r_max,
+ )
+ d_residue = torch.where(
+ b_same_chain, d_residue, torch.zeros_like(d_residue) + 2 * self.r_max + 1
+ )
+ a_rel_pos = one_hot(d_residue, 2 * self.r_max + 2)
+
+ d_token = torch.clip(
+ feats["token_index"][:, :, None]
+ - feats["token_index"][:, None, :]
+ + self.r_max,
+ 0,
+ 2 * self.r_max,
+ )
+ d_token = torch.where(
+ b_same_chain & b_same_residue,
+ d_token,
+ torch.zeros_like(d_token) + 2 * self.r_max + 1,
+ )
+ a_rel_token = one_hot(d_token, 2 * self.r_max + 2)
+
+ d_chain = torch.clip(
+ feats["sym_id"][:, :, None] - feats["sym_id"][:, None, :] + self.s_max,
+ 0,
+ 2 * self.s_max,
+ )
+ d_chain = torch.where(
+ b_same_chain, torch.zeros_like(d_chain) + 2 * self.s_max + 1, d_chain
+ )
+ a_rel_chain = one_hot(d_chain, 2 * self.s_max + 2)
+
+ p = self.linear_layer(
+ torch.cat(
+ [
+ a_rel_pos.float(),
+ a_rel_token.float(),
+ b_same_entity.unsqueeze(-1).float(),
+ a_rel_chain.float(),
+ ],
+ dim=-1,
+ )
+ )
+ return p
+
+
+class SingleConditioning(Module):
+ """Single conditioning layer."""
+
+ def __init__(
+ self,
+ sigma_data: float,
+ token_s=384,
+ dim_fourier=256,
+ num_transitions=2,
+ transition_expansion_factor=2,
+ eps=1e-20,
+ ):
+ """Initialize the single conditioning layer.
+
+ Parameters
+ ----------
+ sigma_data : float
+ The data sigma.
+ token_s : int, optional
+ The single representation dimension, by default 384.
+ dim_fourier : int, optional
+ The fourier embeddings dimension, by default 256.
+ num_transitions : int, optional
+ The number of transitions layers, by default 2.
+ transition_expansion_factor : int, optional
+ The transition expansion factor, by default 2.
+ eps : float, optional
+ The epsilon value, by default 1e-20.
+
+ """
+ super().__init__()
+ self.eps = eps
+ self.sigma_data = sigma_data
+
+ input_dim = (
+ 2 * token_s + 2 * const.num_tokens + 1 + len(const.pocket_contact_info)
+ )
+ self.norm_single = nn.LayerNorm(input_dim)
+ self.single_embed = nn.Linear(input_dim, 2 * token_s)
+ self.fourier_embed = FourierEmbedding(dim_fourier)
+ self.norm_fourier = nn.LayerNorm(dim_fourier)
+ self.fourier_to_single = LinearNoBias(dim_fourier, 2 * token_s)
+
+ transitions = ModuleList([])
+ for _ in range(num_transitions):
+ transition = Transition(
+ dim=2 * token_s, hidden=transition_expansion_factor * 2 * token_s
+ )
+ transitions.append(transition)
+
+ self.transitions = transitions
+
+ def forward(
+ self,
+ *,
+ times,
+ s_trunk,
+ s_inputs,
+ ):
+ s = torch.cat((s_trunk, s_inputs), dim=-1)
+ s = self.single_embed(self.norm_single(s))
+ fourier_embed = self.fourier_embed(times)
+ normed_fourier = self.norm_fourier(fourier_embed)
+ fourier_to_single = self.fourier_to_single(normed_fourier)
+
+ s = rearrange(fourier_to_single, "b d -> b 1 d") + s
+
+ for transition in self.transitions:
+ s = transition(s) + s
+
+ return s, normed_fourier
+
+
+class PairwiseConditioning(Module):
+ """Pairwise conditioning layer."""
+
+ def __init__(
+ self,
+ token_z,
+ dim_token_rel_pos_feats,
+ num_transitions=2,
+ transition_expansion_factor=2,
+ ):
+ """Initialize the pairwise conditioning layer.
+
+ Parameters
+ ----------
+ token_z : int
+ The pair representation dimension.
+ dim_token_rel_pos_feats : int
+ The token relative position features dimension.
+ num_transitions : int, optional
+ The number of transitions layers, by default 2.
+ transition_expansion_factor : int, optional
+ The transition expansion factor, by default 2.
+
+ """
+ super().__init__()
+
+ self.dim_pairwise_init_proj = nn.Sequential(
+ nn.LayerNorm(token_z + dim_token_rel_pos_feats),
+ LinearNoBias(token_z + dim_token_rel_pos_feats, token_z),
+ )
+
+ transitions = ModuleList([])
+ for _ in range(num_transitions):
+ transition = Transition(
+ dim=token_z, hidden=transition_expansion_factor * token_z
+ )
+ transitions.append(transition)
+
+ self.transitions = transitions
+
+ def forward(
+ self,
+ z_trunk,
+ token_rel_pos_feats,
+ ):
+ z = torch.cat((z_trunk, token_rel_pos_feats), dim=-1)
+ z = self.dim_pairwise_init_proj(z)
+
+ for transition in self.transitions:
+ z = transition(z) + z
+
+ return z
+
+
+def get_indexing_matrix(K, W, H, device):
+ assert W % 2 == 0
+ assert H % (W // 2) == 0
+
+ h = H // (W // 2)
+ assert h % 2 == 0
+
+ arange = torch.arange(2 * K, device=device)
+ index = ((arange.unsqueeze(0) - arange.unsqueeze(1)) + h // 2).clamp(
+ min=0, max=h + 1
+ )
+ index = index.view(K, 2, 2 * K)[:, 0, :]
+ onehot = one_hot(index, num_classes=h + 2)[..., 1:-1].transpose(1, 0)
+ return onehot.reshape(2 * K, h * K).float()
+
+
+def single_to_keys(single, indexing_matrix, W, H):
+ B, N, D = single.shape
+ K = N // W
+ single = single.view(B, 2 * K, W // 2, D)
+ return torch.einsum("b j i d, j k -> b k i d", single, indexing_matrix).reshape(
+ B, K, H, D
+ )
+
+
+class AtomAttentionEncoder(Module):
+ """Atom attention encoder."""
+
+ def __init__(
+ self,
+ atom_s,
+ atom_z,
+ token_s,
+ token_z,
+ atoms_per_window_queries,
+ atoms_per_window_keys,
+ atom_feature_dim,
+ atom_encoder_depth=3,
+ atom_encoder_heads=4,
+ structure_prediction=True,
+ activation_checkpointing=False,
+ ):
+ """Initialize the atom attention encoder.
+
+ Parameters
+ ----------
+ atom_s : int
+ The atom single representation dimension.
+ atom_z : int
+ The atom pair representation dimension.
+ token_s : int
+ The single representation dimension.
+ token_z : int
+ The pair representation dimension.
+ atoms_per_window_queries : int
+ The number of atoms per window for queries.
+ atoms_per_window_keys : int
+ The number of atoms per window for keys.
+ atom_feature_dim : int
+ The atom feature dimension.
+ atom_encoder_depth : int, optional
+ The number of transformer layers, by default 3.
+ atom_encoder_heads : int, optional
+ The number of transformer heads, by default 4.
+ structure_prediction : bool, optional
+ Whether it is used in the diffusion module, by default True.
+ activation_checkpointing : bool, optional
+ Whether to use activation checkpointing, by default False.
+
+ """
+ super().__init__()
+
+ self.embed_atom_features = LinearNoBias(atom_feature_dim, atom_s)
+ self.embed_atompair_ref_pos = LinearNoBias(3, atom_z)
+ self.embed_atompair_ref_dist = LinearNoBias(1, atom_z)
+ self.embed_atompair_mask = LinearNoBias(1, atom_z)
+ self.atoms_per_window_queries = atoms_per_window_queries
+ self.atoms_per_window_keys = atoms_per_window_keys
+
+ self.structure_prediction = structure_prediction
+ if structure_prediction:
+ self.s_to_c_trans = nn.Sequential(
+ nn.LayerNorm(token_s), LinearNoBias(token_s, atom_s)
+ )
+ init.final_init_(self.s_to_c_trans[1].weight)
+
+ self.z_to_p_trans = nn.Sequential(
+ nn.LayerNorm(token_z), LinearNoBias(token_z, atom_z)
+ )
+ init.final_init_(self.z_to_p_trans[1].weight)
+
+ self.r_to_q_trans = LinearNoBias(10, atom_s)
+ init.final_init_(self.r_to_q_trans.weight)
+
+ self.c_to_p_trans_k = nn.Sequential(
+ nn.ReLU(),
+ LinearNoBias(atom_s, atom_z),
+ )
+ init.final_init_(self.c_to_p_trans_k[1].weight)
+
+ self.c_to_p_trans_q = nn.Sequential(
+ nn.ReLU(),
+ LinearNoBias(atom_s, atom_z),
+ )
+ init.final_init_(self.c_to_p_trans_q[1].weight)
+
+ self.p_mlp = nn.Sequential(
+ nn.ReLU(),
+ LinearNoBias(atom_z, atom_z),
+ nn.ReLU(),
+ LinearNoBias(atom_z, atom_z),
+ nn.ReLU(),
+ LinearNoBias(atom_z, atom_z),
+ )
+ init.final_init_(self.p_mlp[5].weight)
+
+ self.atom_encoder = AtomTransformer(
+ dim=atom_s,
+ dim_single_cond=atom_s,
+ dim_pairwise=atom_z,
+ attn_window_queries=atoms_per_window_queries,
+ attn_window_keys=atoms_per_window_keys,
+ depth=atom_encoder_depth,
+ heads=atom_encoder_heads,
+ activation_checkpointing=activation_checkpointing,
+ )
+
+ self.atom_to_token_trans = nn.Sequential(
+ LinearNoBias(atom_s, 2 * token_s if structure_prediction else token_s),
+ nn.ReLU(),
+ )
+
+ def forward(
+ self,
+ feats,
+ s_trunk=None,
+ z=None,
+ r=None,
+ multiplicity=1,
+ model_cache=None,
+ ):
+ B, N, _ = feats["ref_pos"].shape
+ atom_mask = feats["atom_pad_mask"].bool()
+
+ layer_cache = None
+ if model_cache is not None:
+ cache_prefix = "atomencoder"
+ if cache_prefix not in model_cache:
+ model_cache[cache_prefix] = {}
+ layer_cache = model_cache[cache_prefix]
+
+ if model_cache is None or len(layer_cache) == 0:
+ # either model is not using the cache or it is the first time running it
+
+ atom_ref_pos = feats["ref_pos"]
+ atom_uid = feats["ref_space_uid"]
+ atom_feats = torch.cat(
+ [
+ atom_ref_pos,
+ feats["ref_charge"].unsqueeze(-1),
+ feats["atom_pad_mask"].unsqueeze(-1),
+ feats["ref_element"],
+ feats["ref_atom_name_chars"].reshape(B, N, 4 * 64),
+ ],
+ dim=-1,
+ )
+
+ c = self.embed_atom_features(atom_feats)
+
+ # NOTE: we are already creating the windows to make it more efficient
+ W, H = self.atoms_per_window_queries, self.atoms_per_window_keys
+ B, N = c.shape[:2]
+ K = N // W
+ keys_indexing_matrix = get_indexing_matrix(K, W, H, c.device)
+ to_keys = partial(
+ single_to_keys, indexing_matrix=keys_indexing_matrix, W=W, H=H
+ )
+
+ atom_ref_pos_queries = atom_ref_pos.view(B, K, W, 1, 3)
+ atom_ref_pos_keys = to_keys(atom_ref_pos).view(B, K, 1, H, 3)
+
+ d = atom_ref_pos_keys - atom_ref_pos_queries
+ d_norm = torch.sum(d * d, dim=-1, keepdim=True)
+ d_norm = 1 / (1 + d_norm)
+
+ atom_mask_queries = atom_mask.view(B, K, W, 1)
+ atom_mask_keys = (
+ to_keys(atom_mask.unsqueeze(-1).float()).view(B, K, 1, H).bool()
+ )
+ atom_uid_queries = atom_uid.view(B, K, W, 1)
+ atom_uid_keys = (
+ to_keys(atom_uid.unsqueeze(-1).float()).view(B, K, 1, H).long()
+ )
+ v = (
+ (
+ atom_mask_queries
+ & atom_mask_keys
+ & (atom_uid_queries == atom_uid_keys)
+ )
+ .float()
+ .unsqueeze(-1)
+ )
+
+ p = self.embed_atompair_ref_pos(d) * v
+ p = p + self.embed_atompair_ref_dist(d_norm) * v
+ p = p + self.embed_atompair_mask(v) * v
+
+ q = c
+
+ if self.structure_prediction:
+ # run only in structure model not in initial encoding
+ atom_to_token = feats["atom_to_token"].float()
+
+ s_to_c = self.s_to_c_trans(s_trunk)
+ s_to_c = torch.bmm(atom_to_token, s_to_c)
+ c = c + s_to_c
+
+ atom_to_token_queries = atom_to_token.view(
+ B, K, W, atom_to_token.shape[-1]
+ )
+ atom_to_token_keys = to_keys(atom_to_token)
+ z_to_p = self.z_to_p_trans(z)
+ z_to_p = torch.einsum(
+ "bijd,bwki,bwlj->bwkld",
+ z_to_p,
+ atom_to_token_queries,
+ atom_to_token_keys,
+ )
+ p = p + z_to_p
+
+ p = p + self.c_to_p_trans_q(c.view(B, K, W, 1, c.shape[-1]))
+ p = p + self.c_to_p_trans_k(to_keys(c).view(B, K, 1, H, c.shape[-1]))
+ p = p + self.p_mlp(p)
+
+ if model_cache is not None:
+ layer_cache["q"] = q
+ layer_cache["c"] = c
+ layer_cache["p"] = p
+ layer_cache["to_keys"] = to_keys
+
+ else:
+ q = layer_cache["q"]
+ c = layer_cache["c"]
+ p = layer_cache["p"]
+ to_keys = layer_cache["to_keys"]
+
+ if self.structure_prediction:
+ # only here the multiplicity kicks in because we use the different positions r
+ q = q.repeat_interleave(multiplicity, 0)
+ r_input = torch.cat(
+ [r, torch.zeros((B * multiplicity, N, 7)).to(r)],
+ dim=-1,
+ )
+ r_to_q = self.r_to_q_trans(r_input)
+ q = q + r_to_q
+
+ c = c.repeat_interleave(multiplicity, 0)
+ atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+ q = self.atom_encoder(
+ q=q,
+ mask=atom_mask,
+ c=c,
+ p=p,
+ multiplicity=multiplicity,
+ to_keys=to_keys,
+ model_cache=layer_cache,
+ )
+
+ q_to_a = self.atom_to_token_trans(q)
+ atom_to_token = feats["atom_to_token"].float()
+ atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
+ atom_to_token_mean = atom_to_token / (
+ atom_to_token.sum(dim=1, keepdim=True) + 1e-6
+ )
+ a = torch.bmm(atom_to_token_mean.transpose(1, 2), q_to_a)
+
+ return a, q, c, p, to_keys
+
+
+class AtomAttentionDecoder(Module):
+ """Atom attention decoder."""
+
+ def __init__(
+ self,
+ atom_s,
+ atom_z,
+ token_s,
+ attn_window_queries,
+ attn_window_keys,
+ atom_decoder_depth=3,
+ atom_decoder_heads=4,
+ activation_checkpointing=False,
+ ):
+ """Initialize the atom attention decoder.
+
+ Parameters
+ ----------
+ atom_s : int
+ The atom single representation dimension.
+ atom_z : int
+ The atom pair representation dimension.
+ token_s : int
+ The single representation dimension.
+ attn_window_queries : int
+ The number of atoms per window for queries.
+ attn_window_keys : int
+ The number of atoms per window for keys.
+ atom_decoder_depth : int, optional
+ The number of transformer layers, by default 3.
+ atom_decoder_heads : int, optional
+ The number of transformer heads, by default 4.
+ activation_checkpointing : bool, optional
+ Whether to use activation checkpointing, by default False.
+
+ """
+ super().__init__()
+
+ self.a_to_q_trans = LinearNoBias(2 * token_s, atom_s)
+ init.final_init_(self.a_to_q_trans.weight)
+
+ self.atom_decoder = AtomTransformer(
+ dim=atom_s,
+ dim_single_cond=atom_s,
+ dim_pairwise=atom_z,
+ attn_window_queries=attn_window_queries,
+ attn_window_keys=attn_window_keys,
+ depth=atom_decoder_depth,
+ heads=atom_decoder_heads,
+ activation_checkpointing=activation_checkpointing,
+ )
+
+ self.atom_feat_to_atom_pos_update = nn.Sequential(
+ nn.LayerNorm(atom_s), LinearNoBias(atom_s, 3)
+ )
+ init.final_init_(self.atom_feat_to_atom_pos_update[1].weight)
+
+ def forward(
+ self,
+ a,
+ q,
+ c,
+ p,
+ feats,
+ to_keys,
+ multiplicity=1,
+ model_cache=None,
+ ):
+ atom_mask = feats["atom_pad_mask"]
+ atom_mask = atom_mask.repeat_interleave(multiplicity, 0)
+
+ atom_to_token = feats["atom_to_token"].float()
+ atom_to_token = atom_to_token.repeat_interleave(multiplicity, 0)
+
+ a_to_q = self.a_to_q_trans(a)
+ a_to_q = torch.bmm(atom_to_token, a_to_q)
+ q = q + a_to_q
+
+ layer_cache = None
+ if model_cache is not None:
+ cache_prefix = "atomdecoder"
+ if cache_prefix not in model_cache:
+ model_cache[cache_prefix] = {}
+ layer_cache = model_cache[cache_prefix]
+
+ q = self.atom_decoder(
+ q=q,
+ mask=atom_mask,
+ c=c,
+ p=p,
+ multiplicity=multiplicity,
+ to_keys=to_keys,
+ model_cache=layer_cache,
+ )
+
+ r_update = self.atom_feat_to_atom_pos_update(q)
+ return r_update
diff --git a/src/boltz/model/modules/transformers.py b/src/boltz/model/modules/transformers.py
new file mode 100644
index 0000000..b1e1fba
--- /dev/null
+++ b/src/boltz/model/modules/transformers.py
@@ -0,0 +1,322 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+from torch import nn, sigmoid
+from torch.nn import (
+ LayerNorm,
+ Linear,
+ Module,
+ ModuleList,
+ Sequential,
+)
+
+from boltz.model.layers.attention import AttentionPairBias
+from boltz.model.modules.utils import LinearNoBias, SwiGLU, default
+
+
+class AdaLN(Module):
+ """Adaptive Layer Normalization"""
+
+ def __init__(self, dim, dim_single_cond):
+ """Initialize the adaptive layer normalization.
+
+ Parameters
+ ----------
+ dim : int
+ The input dimension.
+ dim_single_cond : int
+ The single condition dimension.
+
+ """
+ super().__init__()
+ self.a_norm = LayerNorm(dim, elementwise_affine=False, bias=False)
+ self.s_norm = LayerNorm(dim_single_cond, bias=False)
+ self.s_scale = Linear(dim_single_cond, dim)
+ self.s_bias = LinearNoBias(dim_single_cond, dim)
+
+ def forward(self, a, s):
+ a = self.a_norm(a)
+ s = self.s_norm(s)
+ a = sigmoid(self.s_scale(s)) * a + self.s_bias(s)
+ return a
+
+
+class ConditionedTransitionBlock(Module):
+ """Conditioned Transition Block"""
+
+ def __init__(self, dim_single, dim_single_cond, expansion_factor=2):
+ """Initialize the conditioned transition block.
+
+ Parameters
+ ----------
+ dim_single : int
+ The single dimension.
+ dim_single_cond : int
+ The single condition dimension.
+ expansion_factor : int, optional
+ The expansion factor, by default 2
+
+ """
+ super().__init__()
+
+ self.adaln = AdaLN(dim_single, dim_single_cond)
+
+ dim_inner = int(dim_single * expansion_factor)
+ self.swish_gate = Sequential(
+ LinearNoBias(dim_single, dim_inner * 2),
+ SwiGLU(),
+ )
+ self.a_to_b = LinearNoBias(dim_single, dim_inner)
+ self.b_to_a = LinearNoBias(dim_inner, dim_single)
+
+ output_projection_linear = Linear(dim_single_cond, dim_single)
+ nn.init.zeros_(output_projection_linear.weight)
+ nn.init.constant_(output_projection_linear.bias, -2.0)
+
+ self.output_projection = nn.Sequential(output_projection_linear, nn.Sigmoid())
+
+ def forward(
+ self,
+ a,
+ s,
+ ):
+ a = self.adaln(a, s)
+ b = self.swish_gate(a) * self.a_to_b(a)
+ a = self.output_projection(s) * self.b_to_a(b)
+
+ return a
+
+
+class DiffusionTransformer(Module):
+ """Diffusion Transformer"""
+
+ def __init__(
+ self,
+ depth,
+ heads,
+ dim=384,
+ dim_single_cond=None,
+ dim_pairwise=128,
+ activation_checkpointing=False,
+ offload_to_cpu=False,
+ ):
+ """Initialize the diffusion transformer.
+
+ Parameters
+ ----------
+ depth : int
+ The depth.
+ heads : int
+ The number of heads.
+ dim : int, optional
+ The dimension, by default 384
+ dim_single_cond : int, optional
+ The single condition dimension, by default None
+ dim_pairwise : int, optional
+ The pairwise dimension, by default 128
+ activation_checkpointing : bool, optional
+ Whether to use activation checkpointing, by default False
+ offload_to_cpu : bool, optional
+ Whether to offload to CPU, by default False
+
+ """
+ super().__init__()
+ self.activation_checkpointing = activation_checkpointing
+ dim_single_cond = default(dim_single_cond, dim)
+
+ self.layers = ModuleList()
+ for _ in range(depth):
+ if activation_checkpointing:
+ self.layers.append(
+ checkpoint_wrapper(
+ DiffusionTransformerLayer(
+ heads,
+ dim,
+ dim_single_cond,
+ dim_pairwise,
+ ),
+ offload_to_cpu=offload_to_cpu,
+ )
+ )
+ else:
+ self.layers.append(
+ DiffusionTransformerLayer(
+ heads,
+ dim,
+ dim_single_cond,
+ dim_pairwise,
+ )
+ )
+
+ def forward(
+ self,
+ a,
+ s,
+ z,
+ mask=None,
+ to_keys=None,
+ multiplicity=1,
+ model_cache=None,
+ ):
+ for i, layer in enumerate(self.layers):
+ layer_cache = None
+ if model_cache is not None:
+ prefix_cache = "layer_" + str(i)
+ if prefix_cache not in model_cache:
+ model_cache[prefix_cache] = {}
+ layer_cache = model_cache[prefix_cache]
+ a = layer(
+ a,
+ s,
+ z,
+ mask=mask,
+ to_keys=to_keys,
+ multiplicity=multiplicity,
+ layer_cache=layer_cache,
+ )
+ return a
+
+
+class DiffusionTransformerLayer(Module):
+ """Diffusion Transformer Layer"""
+
+ def __init__(
+ self,
+ heads,
+ dim=384,
+ dim_single_cond=None,
+ dim_pairwise=128,
+ ):
+ """Initialize the diffusion transformer layer.
+
+ Parameters
+ ----------
+ heads : int
+ The number of heads.
+ dim : int, optional
+ The dimension, by default 384
+ dim_single_cond : int, optional
+ The single condition dimension, by default None
+ dim_pairwise : int, optional
+ The pairwise dimension, by default 128
+
+ """
+ super().__init__()
+
+ dim_single_cond = default(dim_single_cond, dim)
+
+ self.adaln = AdaLN(dim, dim_single_cond)
+
+ self.pair_bias_attn = AttentionPairBias(
+ c_s=dim, c_z=dim_pairwise, num_heads=heads, initial_norm=False
+ )
+
+ self.output_projection_linear = Linear(dim_single_cond, dim)
+ nn.init.zeros_(self.output_projection_linear.weight)
+ nn.init.constant_(self.output_projection_linear.bias, -2.0)
+
+ self.output_projection = nn.Sequential(
+ self.output_projection_linear, nn.Sigmoid()
+ )
+ self.transition = ConditionedTransitionBlock(
+ dim_single=dim, dim_single_cond=dim_single_cond
+ )
+
+ def forward(
+ self,
+ a,
+ s,
+ z,
+ mask=None,
+ to_keys=None,
+ multiplicity=1,
+ layer_cache=None,
+ ):
+ b = self.adaln(a, s)
+ b = self.pair_bias_attn(
+ s=b,
+ z=z,
+ mask=mask,
+ multiplicity=multiplicity,
+ to_keys=to_keys,
+ model_cache=layer_cache,
+ )
+ b = self.output_projection(s) * b
+
+ # NOTE: Added residual connection!
+ a = a + b
+ a = a + self.transition(a, s)
+ return a
+
+
+class AtomTransformer(Module):
+ """Atom Transformer"""
+
+ def __init__(
+ self,
+ attn_window_queries=None,
+ attn_window_keys=None,
+ **diffusion_transformer_kwargs,
+ ):
+ """Initialize the atom transformer.
+
+ Parameters
+ ----------
+ attn_window_queries : int, optional
+ The attention window queries, by default None
+ attn_window_keys : int, optional
+ The attention window keys, by default None
+ diffusion_transformer_kwargs : dict
+ The diffusion transformer keyword arguments
+
+ """
+ super().__init__()
+ self.attn_window_queries = attn_window_queries
+ self.attn_window_keys = attn_window_keys
+ self.diffusion_transformer = DiffusionTransformer(
+ **diffusion_transformer_kwargs
+ )
+
+ def forward(
+ self,
+ q,
+ c,
+ p,
+ to_keys=None,
+ mask=None,
+ multiplicity=1,
+ model_cache=None,
+ ):
+ W = self.attn_window_queries
+ H = self.attn_window_keys
+
+ if W is not None:
+ B, N, D = q.shape
+ NW = N // W
+
+ # reshape tokens
+ q = q.view((B * NW, W, -1))
+ c = c.view((B * NW, W, -1))
+ if mask is not None:
+ mask = mask.view(B * NW, W)
+ p = p.view((p.shape[0] * NW, W, H, -1))
+
+ to_keys_new = lambda x: to_keys(x.view(B, NW * W, -1)).view(B * NW, H, -1)
+ else:
+ to_keys_new = None
+
+ # main transformer
+ q = self.diffusion_transformer(
+ a=q,
+ s=c,
+ z=p,
+ mask=mask.float(),
+ multiplicity=multiplicity,
+ to_keys=to_keys_new,
+ model_cache=model_cache,
+ )
+
+ if W is not None:
+ q = q.view((B, NW * W, D))
+
+ return q
diff --git a/src/boltz/model/modules/trunk.py b/src/boltz/model/modules/trunk.py
new file mode 100644
index 0000000..9c5e786
--- /dev/null
+++ b/src/boltz/model/modules/trunk.py
@@ -0,0 +1,661 @@
+from typing import Dict, Tuple
+
+from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper
+import torch
+from torch import Tensor, nn
+
+from boltz.data import const
+from boltz.model.layers.attention import AttentionPairBias
+from boltz.model.layers.dropout import get_dropout_mask
+from boltz.model.layers.outer_product_mean import OuterProductMean
+from boltz.model.layers.pair_averaging import PairWeightedAveraging
+from boltz.model.layers.transition import Transition
+from boltz.model.layers.triangular_attention.attention import (
+ TriangleAttentionEndingNode,
+ TriangleAttentionStartingNode,
+)
+from boltz.model.layers.triangular_mult import (
+ TriangleMultiplicationIncoming,
+ TriangleMultiplicationOutgoing,
+)
+from boltz.model.modules.encoders import AtomAttentionEncoder
+
+
+class InputEmbedder(nn.Module):
+ """Input embedder."""
+
+ def __init__(
+ self,
+ atom_s: int,
+ atom_z: int,
+ token_s: int,
+ token_z: int,
+ atoms_per_window_queries: int,
+ atoms_per_window_keys: int,
+ atom_feature_dim: int,
+ atom_encoder_depth: int,
+ atom_encoder_heads: int,
+ no_atom_encoder: bool = False,
+ ) -> None:
+ """Initialize the input embedder.
+
+ Parameters
+ ----------
+ atom_s : int
+ The atom single representation dimension.
+ atom_z : int
+ The atom pair representation dimension.
+ token_s : int
+ The single token representation dimension.
+ token_z : int
+ The pair token representation dimension.
+ atoms_per_window_queries : int
+ The number of atoms per window for queries.
+ atoms_per_window_keys : int
+ The number of atoms per window for keys.
+ atom_feature_dim : int
+ The atom feature dimension.
+ atom_encoder_depth : int
+ The atom encoder depth.
+ atom_encoder_heads : int
+ The atom encoder heads.
+ no_atom_encoder : bool, optional
+ Whether to use the atom encoder, by default False
+
+ """
+ super().__init__()
+ self.token_s = token_s
+ self.no_atom_encoder = no_atom_encoder
+
+ if not no_atom_encoder:
+ self.atom_attention_encoder = AtomAttentionEncoder(
+ atom_s=atom_s,
+ atom_z=atom_z,
+ token_s=token_s,
+ token_z=token_z,
+ atoms_per_window_queries=atoms_per_window_queries,
+ atoms_per_window_keys=atoms_per_window_keys,
+ atom_feature_dim=atom_feature_dim,
+ atom_encoder_depth=atom_encoder_depth,
+ atom_encoder_heads=atom_encoder_heads,
+ structure_prediction=False,
+ )
+
+ def forward(self, feats: Dict[str, Tensor]) -> Tensor:
+ """Perform the forward pass.
+
+ Parameters
+ ----------
+ feats : Dict[str, Tensor]
+ Input features
+
+ Returns
+ -------
+ Tensor
+ The embedded tokens.
+
+ """
+ # Load relevant features
+ res_type = feats["res_type"]
+ profile = feats["profile"]
+ deletion_mean = feats["deletion_mean"].unsqueeze(-1)
+ pocket_feature = feats["pocket_feature"]
+
+ # Compute input embedding
+ if self.no_atom_encoder:
+ a = torch.zeros(
+ (res_type.shape[0], res_type.shape[1], self.token_s),
+ device=res_type.device,
+ )
+ else:
+ a, _, _, _, _ = self.atom_attention_encoder(feats)
+ s = torch.cat([a, res_type, profile, deletion_mean, pocket_feature], dim=-1)
+ return s
+
+
+class MSAModule(nn.Module):
+ """MSA module."""
+
+ def __init__(
+ self,
+ msa_s: int,
+ token_z: int,
+ s_input_dim: int,
+ msa_blocks: int,
+ msa_dropout: float,
+ z_dropout: float,
+ pairwise_head_width: int = 32,
+ pairwise_num_heads: int = 4,
+ activation_checkpointing: bool = False,
+ use_paired_feature: bool = False,
+ offload_to_cpu: bool = False,
+ chunk_heads_pwa: bool = False,
+ chunk_size_transition_z: int = None,
+ chunk_size_transition_msa: int = None,
+ chunk_size_outer_product: int = None,
+ chunk_size_tri_attn: int = None,
+ **kwargs,
+ ) -> None:
+ """Initialize the MSA module.
+
+ Parameters
+ ----------
+ msa_s : int
+ The MSA embedding size.
+ token_z : int
+ The token pairwise embedding size.
+ s_input_dim : int
+ The input sequence dimension.
+ msa_blocks : int
+ The number of MSA blocks.
+ msa_dropout : float
+ The MSA dropout.
+ z_dropout : float
+ The pairwise dropout.
+ pairwise_head_width : int, optional
+ The pairwise head width, by default 32
+ pairwise_num_heads : int, optional
+ The number of pairwise heads, by default 4
+ activation_checkpointing : bool, optional
+ Whether to use activation checkpointing, by default False
+ use_paired_feature : bool, optional
+ Whether to use the paired feature, by default False
+ offload_to_cpu : bool, optional
+ Whether to offload to CPU, by default False
+ chunk_heads_pwa : bool, optional
+ Chunk heads for PWA, by default False
+ chunk_size_transition_z : int, optional
+ Chunk size for transition Z, by default None
+ chunk_size_transition_msa : int, optional
+ Chunk size for transition MSA, by default None
+ chunk_size_outer_product : int, optional
+ Chunk size for outer product, by default None
+ chunk_size_tri_attn : int, optional
+ Chunk size for triangle attention, by default None
+
+ """
+ super().__init__()
+ self.msa_blocks = msa_blocks
+ self.msa_dropout = msa_dropout
+ self.z_dropout = z_dropout
+ self.use_paired_feature = use_paired_feature
+
+ self.s_proj = nn.Linear(s_input_dim, msa_s, bias=False)
+ self.msa_proj = nn.Linear(
+ const.num_tokens + 2 + int(use_paired_feature),
+ msa_s,
+ bias=False,
+ )
+ self.layers = nn.ModuleList()
+ for i in range(msa_blocks):
+ if activation_checkpointing:
+ self.layers.append(
+ checkpoint_wrapper(
+ MSALayer(
+ msa_s,
+ token_z,
+ msa_dropout,
+ z_dropout,
+ pairwise_head_width,
+ pairwise_num_heads,
+ chunk_heads_pwa=chunk_heads_pwa,
+ chunk_size_transition_z=chunk_size_transition_z,
+ chunk_size_transition_msa=chunk_size_transition_msa,
+ chunk_size_outer_product=chunk_size_outer_product,
+ chunk_size_tri_attn=chunk_size_tri_attn,
+ ),
+ offload_to_cpu=offload_to_cpu,
+ )
+ )
+ else:
+ self.layers.append(
+ MSALayer(
+ msa_s,
+ token_z,
+ msa_dropout,
+ z_dropout,
+ pairwise_head_width,
+ pairwise_num_heads,
+ chunk_heads_pwa=chunk_heads_pwa,
+ chunk_size_transition_z=chunk_size_transition_z,
+ chunk_size_transition_msa=chunk_size_transition_msa,
+ chunk_size_outer_product=chunk_size_outer_product,
+ chunk_size_tri_attn=chunk_size_tri_attn,
+ )
+ )
+
+ def forward(self, z: Tensor, emb: Tensor, feats: Dict[str, Tensor]) -> Tensor:
+ """Perform the forward pass.
+
+ Parameters
+ ----------
+ z : Tensor
+ The pairwise embeddings
+ emb : Tensor
+ The input embeddings
+ feats : Dict[str, Tensor]
+ Input features
+
+ Returns
+ -------
+ Tensor
+ The output pairwise embeddings.
+
+ """
+ # Load relevant features
+ msa = feats["msa"]
+ has_deletion = feats["has_deletion"].unsqueeze(-1)
+ deletion_value = feats["deletion_value"].unsqueeze(-1)
+ is_paired = feats["msa_paired"].unsqueeze(-1)
+ msa_mask = feats["msa_mask"]
+ token_mask = feats["token_pad_mask"].float()
+ token_mask = token_mask[:, :, None] * token_mask[:, None, :]
+
+ # Compute MSA embeddings
+ if self.use_paired_feature:
+ m = torch.cat([msa, has_deletion, deletion_value, is_paired], dim=-1)
+ else:
+ m = torch.cat([msa, has_deletion, deletion_value], dim=-1)
+
+ # Compute input projections
+ m = self.msa_proj(m)
+ m = m + self.s_proj(emb).unsqueeze(1)
+
+ # Perform MSA blocks
+ for i in range(self.msa_blocks):
+ z, m = self.layers[i](z, m, token_mask, msa_mask)
+ return z
+
+
+class MSALayer(nn.Module):
+ """MSA module."""
+
+ def __init__(
+ self,
+ msa_s: int,
+ token_z: int,
+ msa_dropout: float,
+ z_dropout: float,
+ pairwise_head_width: int = 32,
+ pairwise_num_heads: int = 4,
+ chunk_heads_pwa: bool = False,
+ chunk_size_transition_z: int = None,
+ chunk_size_transition_msa: int = None,
+ chunk_size_outer_product: int = None,
+ chunk_size_tri_attn: int = None,
+ ) -> None:
+ """Initialize the MSA module.
+
+ Parameters
+ ----------
+
+ msa_s : int
+ The MSA embedding size.
+ token_z : int
+ The pair representation dimention.
+ msa_dropout : float
+ The MSA dropout.
+ z_dropout : float
+ The pair dropout.
+ pairwise_head_width : int, optional
+ The pairwise head width, by default 32
+ pairwise_num_heads : int, optional
+ The number of pairwise heads, by default 4
+ chunk_heads_pwa : bool, optional
+ Chunk heads for PWA, by default False
+ chunk_size_transition_z : int, optional
+ Chunk size for transition Z, by default None
+ chunk_size_transition_msa : int, optional
+ Chunk size for transition MSA, by default None
+ chunk_size_outer_product : int, optional
+ Chunk size for outer product, by default None
+ chunk_size_tri_attn : int, optional
+ Chunk size for triangle attention, by default None
+
+ """
+ super().__init__()
+ self.msa_dropout = msa_dropout
+ self.z_dropout = z_dropout
+ self.chunk_size_tri_attn = chunk_size_tri_attn
+ self.msa_transition = Transition(
+ dim=msa_s, hidden=msa_s * 4, chunk_size=chunk_size_transition_msa
+ )
+ self.pair_weighted_averaging = PairWeightedAveraging(
+ c_m=msa_s,
+ c_z=token_z,
+ c_h=32,
+ num_heads=8,
+ chunk_heads=chunk_heads_pwa,
+ )
+
+ self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
+ self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
+ self.tri_att_start = TriangleAttentionStartingNode(
+ token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+ )
+ self.tri_att_end = TriangleAttentionEndingNode(
+ token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+ )
+ self.z_transition = Transition(
+ dim=token_z,
+ hidden=token_z * 4,
+ chunk_size=chunk_size_transition_z,
+ )
+ self.outer_product_mean = OuterProductMean(
+ c_in=msa_s,
+ c_hidden=32,
+ c_out=token_z,
+ chunk_size=chunk_size_outer_product,
+ )
+
+ def forward(
+ self, z: Tensor, m: Tensor, token_mask: Tensor, msa_mask: Tensor
+ ) -> Tuple[Tensor, Tensor]:
+ """Perform the forward pass.
+
+ Parameters
+ ----------
+ z : Tensor
+ The pair representation
+ m : Tensor
+ The msa representation
+ token_mask : Tensor
+ The token mask
+ msa_mask : Dict[str, Tensor]
+ The MSA mask
+
+ Returns
+ -------
+ Tensor
+ The output pairwise embeddings.
+ Tensor
+ The output MSA embeddings.
+
+ """
+ # Communication to MSA stack
+ msa_dropout = get_dropout_mask(self.msa_dropout, m, self.training)
+ m = m + msa_dropout * self.pair_weighted_averaging(m, z, token_mask)
+ m = m + self.msa_transition(m)
+
+ # Communication to pairwise stack
+ z = z + self.outer_product_mean(m, msa_mask)
+
+ # Compute pairwise stack
+ dropout = get_dropout_mask(self.z_dropout, z, self.training)
+ z = z + dropout * self.tri_mul_out(z, mask=token_mask)
+
+ dropout = get_dropout_mask(self.z_dropout, z, self.training)
+ z = z + dropout * self.tri_mul_in(z, mask=token_mask)
+
+ dropout = get_dropout_mask(self.z_dropout, z, self.training)
+ z = z + dropout * self.tri_att_start(
+ z,
+ mask=token_mask,
+ chunk_size=self.chunk_size_tri_attn if not self.training else None,
+ )
+
+ dropout = get_dropout_mask(self.z_dropout, z, self.training, columnwise=True)
+ z = z + dropout * self.tri_att_end(
+ z,
+ mask=token_mask,
+ chunk_size=self.chunk_size_tri_attn if not self.training else None,
+ )
+
+ z = z + self.z_transition(z)
+
+ return z, m
+
+
+class PairformerModule(nn.Module):
+ """Pairformer module."""
+
+ def __init__(
+ self,
+ token_s: int,
+ token_z: int,
+ num_blocks: int,
+ num_heads: int = 16,
+ dropout: float = 0.25,
+ pairwise_head_width: int = 32,
+ pairwise_num_heads: int = 4,
+ activation_checkpointing: bool = False,
+ no_update_s: bool = False,
+ no_update_z: bool = False,
+ offload_to_cpu: bool = False,
+ chunk_size_tri_attn: int = None,
+ **kwargs,
+ ) -> None:
+ """Initialize the Pairformer module.
+
+ Parameters
+ ----------
+ token_s : int
+ The token single embedding size.
+ token_z : int
+ The token pairwise embedding size.
+ num_blocks : int
+ The number of blocks.
+ num_heads : int, optional
+ The number of heads, by default 16
+ dropout : float, optional
+ The dropout rate, by default 0.25
+ pairwise_head_width : int, optional
+ The pairwise head width, by default 32
+ pairwise_num_heads : int, optional
+ The number of pairwise heads, by default 4
+ activation_checkpointing : bool, optional
+ Whether to use activation checkpointing, by default False
+ no_update_s : bool, optional
+ Whether to update the single embeddings, by default False
+ no_update_z : bool, optional
+ Whether to update the pairwise embeddings, by default False
+ offload_to_cpu : bool, optional
+ Whether to offload to CPU, by default False
+ chunk_size_tri_attn : int, optional
+ The chunk size for triangle attention, by default None
+
+ """
+ super().__init__()
+ self.token_z = token_z
+ self.num_blocks = num_blocks
+ self.dropout = dropout
+ self.num_heads = num_heads
+
+ self.layers = nn.ModuleList()
+ for i in range(num_blocks):
+ if activation_checkpointing:
+ self.layers.append(
+ checkpoint_wrapper(
+ PairformerLayer(
+ token_s,
+ token_z,
+ num_heads,
+ dropout,
+ pairwise_head_width,
+ pairwise_num_heads,
+ no_update_s,
+ False if i < num_blocks - 1 else no_update_z,
+ chunk_size_tri_attn,
+ ),
+ offload_to_cpu=offload_to_cpu,
+ )
+ )
+ else:
+ self.layers.append(
+ PairformerLayer(
+ token_s,
+ token_z,
+ num_heads,
+ dropout,
+ pairwise_head_width,
+ pairwise_num_heads,
+ no_update_s,
+ False if i < num_blocks - 1 else no_update_z,
+ chunk_size_tri_attn,
+ )
+ )
+
+ def forward(
+ self,
+ s: Tensor,
+ z: Tensor,
+ mask: Tensor,
+ pair_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Perform the forward pass.
+
+ Parameters
+ ----------
+ s : Tensor
+ The sequence embeddings
+ z : Tensor
+ The pairwise embeddings
+ mask : Tensor
+ The token mask
+ pair_mask : Tensor
+ The pairwise mask
+ Returns
+ -------
+ Tensor
+ The updated sequence embeddings.
+ Tensor
+ The updated pairwise embeddings.
+
+ """
+ for layer in self.layers:
+ s, z = layer(s, z, mask, pair_mask)
+ return s, z
+
+
+class PairformerLayer(nn.Module):
+ """Pairformer module."""
+
+ def __init__(
+ self,
+ token_s: int,
+ token_z: int,
+ num_heads: int = 16,
+ dropout: float = 0.25,
+ pairwise_head_width: int = 32,
+ pairwise_num_heads: int = 4,
+ no_update_s: bool = False,
+ no_update_z: bool = False,
+ chunk_size_tri_attn: int = None,
+ ) -> None:
+ """Initialize the Pairformer module.
+
+ Parameters
+ ----------
+ token_s : int
+ The token single embedding size.
+ token_z : int
+ The token pairwise embedding size.
+ num_heads : int, optional
+ The number of heads, by default 16
+ dropout : float, optiona
+ The dropout rate, by default 0.25
+ pairwise_head_width : int, optional
+ The pairwise head width, by default 32
+ pairwise_num_heads : int, optional
+ The number of pairwise heads, by default 4
+ no_update_s : bool, optional
+ Whether to update the single embeddings, by default False
+ no_update_z : bool, optional
+ Whether to update the pairwise embeddings, by default False
+ chunk_size_tri_attn : int, optional
+ The chunk size for triangle attention, by default None
+
+ """
+ super().__init__()
+ self.token_z = token_z
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.no_update_s = no_update_s
+ self.no_update_z = no_update_z
+ self.chunk_size_tri_attn = chunk_size_tri_attn
+ if not self.no_update_s:
+ self.attention = AttentionPairBias(token_s, token_z, num_heads)
+ self.tri_mul_out = TriangleMultiplicationOutgoing(token_z)
+ self.tri_mul_in = TriangleMultiplicationIncoming(token_z)
+ self.tri_att_start = TriangleAttentionStartingNode(
+ token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+ )
+ self.tri_att_end = TriangleAttentionEndingNode(
+ token_z, pairwise_head_width, pairwise_num_heads, inf=1e9
+ )
+ if not self.no_update_s:
+ self.transition_s = Transition(token_s, token_s * 4)
+ self.transition_z = Transition(token_z, token_z * 4)
+
+ def forward(
+ self,
+ s: Tensor,
+ z: Tensor,
+ mask: Tensor,
+ pair_mask: Tensor,
+ ) -> Tuple[Tensor, Tensor]:
+ """Perform the forward pass."""
+ # Compute pairwise stack
+ dropout = get_dropout_mask(self.dropout, z, self.training)
+ z = z + dropout * self.tri_mul_out(z, mask=pair_mask)
+
+ dropout = get_dropout_mask(self.dropout, z, self.training)
+ z = z + dropout * self.tri_mul_in(z, mask=pair_mask)
+
+ dropout = get_dropout_mask(self.dropout, z, self.training)
+ z = z + dropout * self.tri_att_start(
+ z,
+ mask=pair_mask,
+ chunk_size=self.chunk_size_tri_attn if not self.training else None,
+ )
+
+ dropout = get_dropout_mask(self.dropout, z, self.training, columnwise=True)
+ z = z + dropout * self.tri_att_end(
+ z,
+ mask=pair_mask,
+ chunk_size=self.chunk_size_tri_attn if not self.training else None,
+ )
+
+ z = z + self.transition_z(z)
+
+ # Compute sequence stack
+ if not self.no_update_s:
+ s = s + self.attention(s, z, mask)
+ s = s + self.transition_s(s)
+
+ return s, z
+
+
+class DistogramModule(nn.Module):
+ """Distogram Module."""
+
+ def __init__(self, token_z: int, num_bins: int) -> None:
+ """Initialize the distogram module.
+
+ Parameters
+ ----------
+ token_z : int
+ The token pairwise embedding size.
+ num_bins : int
+ The number of bins.
+
+ """
+ super().__init__()
+ self.distogram = nn.Linear(token_z, num_bins)
+
+ def forward(self, z: Tensor) -> Tensor:
+ """Perform the forward pass.
+
+ Parameters
+ ----------
+ z : Tensor
+ The pairwise embeddings
+
+ Returns
+ -------
+ Tensor
+ The predicted distogram.
+
+ """
+ z = z + z.transpose(1, 2)
+ return self.distogram(z)
diff --git a/src/boltz/model/modules/utils.py b/src/boltz/model/modules/utils.py
new file mode 100644
index 0000000..69b9ad3
--- /dev/null
+++ b/src/boltz/model/modules/utils.py
@@ -0,0 +1,307 @@
+# started from code from https://github.com/lucidrains/alphafold3-pytorch, MIT License, Copyright (c) 2024 Phil Wang
+
+from functools import partial
+from typing import Optional
+
+import torch
+from torch.nn import (
+ Module,
+ Linear,
+)
+import torch.nn.functional as F
+from torch.types import Device
+
+LinearNoBias = partial(Linear, bias=False)
+
+
+def exists(v):
+ return v is not None
+
+
+def default(v, d):
+ return v if exists(v) else d
+
+
+def log(t, eps=1e-20):
+ return torch.log(t.clamp(min=eps))
+
+
+class SwiGLU(Module):
+ def forward(
+ self,
+ x,
+ ):
+ x, gates = x.chunk(2, dim=-1)
+ return F.silu(gates) * x
+
+
+def randomly_rotate(coords, return_second_coords=False, second_coords=None):
+ R = random_rotations(len(coords), coords.dtype, coords.device)
+
+ if return_second_coords:
+ return torch.einsum("bmd,bds->bms", coords, R), (
+ torch.einsum("bmd,bds->bms", second_coords, R)
+ if second_coords is not None
+ else None
+ )
+
+ return torch.einsum("bmd,bds->bms", coords, R)
+
+
+def center_random_augmentation(
+ atom_coords,
+ atom_mask,
+ s_trans=1.0,
+ augmentation=True,
+ centering=True,
+ return_second_coords=False,
+ second_coords=None,
+):
+ """Center and randomly augment the input coordinates.
+
+ Parameters
+ ----------
+ atom_coords : Tensor
+ The atom coordinates.
+ atom_mask : Tensor
+ The atom mask.
+ s_trans : float, optional
+ The translation factor, by default 1.0
+ augmentation : bool, optional
+ Whether to add rotational and translational augmentation the input, by default True
+ centering : bool, optional
+ Whether to center the input, by default True
+
+ Returns
+ -------
+ Tensor
+ The augmented atom coordinates.
+
+ """
+ if centering:
+ atom_mean = torch.sum(
+ atom_coords * atom_mask[:, :, None], dim=1, keepdim=True
+ ) / torch.sum(atom_mask[:, :, None], dim=1, keepdim=True)
+ atom_coords = atom_coords - atom_mean
+
+ if second_coords is not None:
+ # apply same transformation also to this input
+ second_coords = second_coords - atom_mean
+
+ if augmentation:
+ atom_coords, second_coords = randomly_rotate(
+ atom_coords, return_second_coords=True, second_coords=second_coords
+ )
+ random_trans = torch.randn_like(atom_coords[:, 0:1, :]) * s_trans
+ atom_coords = atom_coords + random_trans
+
+ if second_coords is not None:
+ second_coords = second_coords + random_trans
+
+ if return_second_coords:
+ return atom_coords, second_coords
+
+ return atom_coords
+
+
+class ExponentialMovingAverage:
+ """from https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py, Apache-2.0 license
+ Maintains (exponential) moving average of a set of parameters."""
+
+ def __init__(self, parameters, decay, use_num_updates=True):
+ """
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the result of
+ `model.parameters()`.
+ decay: The exponential decay.
+ use_num_updates: Whether to use number of updates when computing
+ averages.
+ """
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+ self.decay = decay
+ self.num_updates = 0 if use_num_updates else None
+ self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad]
+ self.collected_params = []
+
+ def update(self, parameters):
+ """
+ Update currently maintained parameters.
+ Call this every time the parameters are updated, such as the result of
+ the `optimizer.step()` call.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object.
+ """
+ decay = self.decay
+ if self.num_updates is not None:
+ self.num_updates += 1
+ decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates))
+ one_minus_decay = 1.0 - decay
+ with torch.no_grad():
+ parameters = [p for p in parameters if p.requires_grad]
+ for s_param, param in zip(self.shadow_params, parameters):
+ s_param.sub_(one_minus_decay * (s_param - param))
+
+ def compatible(self, parameters):
+ if len(self.shadow_params) != len(parameters):
+ print(
+ f"Model has {len(self.shadow_params)} parameter tensors, the incoming ema {len(parameters)}"
+ )
+ return False
+
+ for s_param, param in zip(self.shadow_params, parameters):
+ if param.data.shape != s_param.data.shape:
+ print(
+ f"Model has parameter tensor of shape {s_param.data.shape} , the incoming ema {param.data.shape}"
+ )
+ return False
+ return True
+
+ def copy_to(self, parameters):
+ """
+ Copy current parameters into given collection of parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages.
+ """
+ parameters = [p for p in parameters if p.requires_grad]
+ for s_param, param in zip(self.shadow_params, parameters):
+ if param.requires_grad:
+ param.data.copy_(s_param.data)
+
+ def store(self, parameters):
+ """
+ Save the current parameters for restoring later.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ temporarily stored.
+ """
+ self.collected_params = [param.clone() for param in parameters]
+
+ def restore(self, parameters):
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters.
+ """
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
+
+ def state_dict(self):
+ return dict(
+ decay=self.decay,
+ num_updates=self.num_updates,
+ shadow_params=self.shadow_params,
+ )
+
+ def load_state_dict(self, state_dict, device):
+ self.decay = state_dict["decay"]
+ self.num_updates = state_dict["num_updates"]
+ self.shadow_params = [
+ tensor.to(device) for tensor in state_dict["shadow_params"]
+ ]
+
+ def to(self, device):
+ self.shadow_params = [tensor.to(device) for tensor in self.shadow_params]
+
+
+# the following is copied from Torch3D, BSD License, Copyright (c) Meta Platforms, Inc. and affiliates.
+
+
+def _copysign(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+ """
+ Return a tensor where each element has the absolute value taken from the,
+ corresponding element of a, with sign taken from the corresponding
+ element of b. This is like the standard copysign floating-point operation,
+ but is not careful about negative 0 and NaN.
+
+ Args:
+ a: source tensor.
+ b: tensor whose signs will be used, of the same shape as a.
+
+ Returns:
+ Tensor of the same shape as a with the signs of b.
+ """
+ signs_differ = (a < 0) != (b < 0)
+ return torch.where(signs_differ, -a, a)
+
+
+def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def random_quaternions(
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+ """
+ Generate random quaternions representing rotations,
+ i.e. versors with nonnegative real part.
+
+ Args:
+ n: Number of quaternions in a batch to return.
+ dtype: Type to return.
+ device: Desired device of returned tensor. Default:
+ uses the current device for the default tensor type.
+
+ Returns:
+ Quaternions as tensor of shape (N, 4).
+ """
+ if isinstance(device, str):
+ device = torch.device(device)
+ o = torch.randn((n, 4), dtype=dtype, device=device)
+ s = (o * o).sum(1)
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
+ return o
+
+
+def random_rotations(
+ n: int, dtype: Optional[torch.dtype] = None, device: Optional[Device] = None
+) -> torch.Tensor:
+ """
+ Generate random rotations as 3x3 rotation matrices.
+
+ Args:
+ n: Number of rotation matrices in a batch to return.
+ dtype: Type to return.
+ device: Device of returned tensor. Default: if None,
+ uses the current device for the default tensor type.
+
+ Returns:
+ Rotation matrices as tensor of shape (n, 3, 3).
+ """
+ quaternions = random_quaternions(n, dtype=dtype, device=device)
+ return quaternion_to_matrix(quaternions)
diff --git a/src/boltz/model/optim/__init__.py b/src/boltz/model/optim/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/boltz/model/optim/ema.py b/src/boltz/model/optim/ema.py
new file mode 100644
index 0000000..1e370c3
--- /dev/null
+++ b/src/boltz/model/optim/ema.py
@@ -0,0 +1,382 @@
+# --------------------------------------------------------------------------------------
+# Modified from Bio-Diffusion (https://github.com/BioinfoMachineLearning/bio-diffusion):
+# Modified from : https://github.com/BioinfoMachineLearning/bio-diffusion/blob/main/src/utils/__init__.py
+# --------------------------------------------------------------------------------------
+
+from typing import Any, Dict, Optional
+
+import torch
+from pytorch_lightning import Callback, LightningModule, Trainer
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import STEP_OUTPUT
+
+
+class EMA(Callback):
+ """Implements Exponential Moving Averaging (EMA).
+
+ When training a model, this callback maintains moving averages
+ of the trained parameters. When evaluating, we use the moving
+ averages copy of the trained parameters. When saving, we save
+ an additional set of parameters with the prefix `ema`.
+
+ Adapted from:
+ https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/callbacks/ema.py
+ https://github.com/BioinfoMachineLearning/bio-diffusion/blob/main/src/utils/__init__.py
+
+ """
+
+ def __init__(
+ self,
+ decay: float,
+ apply_ema_every_n_steps: int = 1,
+ start_step: int = 0,
+ eval_with_ema: bool = True,
+ ) -> None:
+ """Initialize the EMA callback.
+
+ Parameters
+ ----------
+ decay: float
+ The exponential decay, has to be between 0-1.
+ apply_ema_every_n_steps: int, optional (default=1)
+ Apply EMA every n global steps.
+ start_step: int, optional (default=0)
+ Start applying EMA from ``start_step`` global step onwards.
+ eval_with_ema: bool, optional (default=True)
+ Validate the EMA weights instead of the original weights.
+ Note this means that when saving the model, the
+ validation metrics are calculated with the EMA weights.
+
+ """
+ if not (0 <= decay <= 1):
+ msg = "EMA decay value must be between 0 and 1"
+ raise MisconfigurationException(msg)
+
+ self._ema_weights: Optional[Dict[str, torch.Tensor]] = None
+ self._cur_step: Optional[int] = None
+ self._weights_buffer: Optional[Dict[str, torch.Tensor]] = None
+ self.apply_ema_every_n_steps = apply_ema_every_n_steps
+ self.start_step = start_step
+ self.eval_with_ema = eval_with_ema
+ self.decay = decay
+
+ @property
+ def ema_initialized(self) -> bool:
+ """Check if EMA weights have been initialized.
+
+ Returns
+ -------
+ bool
+ Whether the EMA weights have been initialized.
+
+ """
+ return self._ema_weights is not None
+
+ def state_dict(self) -> Dict[str, Any]:
+ """Return the current state of the callback.
+
+ Returns
+ -------
+ Dict[str, Any]
+ The current state of the callback.
+
+ """
+ return {
+ "cur_step": self._cur_step,
+ "ema_weights": self._ema_weights,
+ }
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ """Load the state of the callback.
+
+ Parameters
+ ----------
+ state_dict: Dict[str, Any]
+ The state of the callback to load.
+
+ """
+ self._cur_step = state_dict["cur_step"]
+ self._ema_weights = state_dict["ema_weights"]
+
+ def should_apply_ema(self, step: int) -> bool:
+ """Check if EMA should be applied at the current step.
+
+ Parameters
+ ----------
+ step: int
+ The current global step.
+
+ Returns
+ -------
+ bool
+ True if EMA should be applied, False otherwise.
+
+ """
+ return (
+ step != self._cur_step
+ and step >= self.start_step
+ and step % self.apply_ema_every_n_steps == 0
+ )
+
+ def apply_ema(self, pl_module: LightningModule) -> None:
+ """Apply EMA to the model weights.
+
+ Parameters
+ ----------
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ for k, orig_weight in pl_module.state_dict().items():
+ ema_weight = self._ema_weights[k]
+ if (
+ ema_weight.data.dtype != torch.long # noqa: PLR1714
+ and orig_weight.data.dtype != torch.long # skip non-trainable weights
+ ):
+ diff = ema_weight.data - orig_weight.data
+ diff.mul_(1.0 - self.decay)
+ ema_weight.sub_(diff)
+
+ def on_load_checkpoint(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule, # noqa: ARG002
+ checkpoint: dict[str, Any],
+ ) -> None:
+ """Load the EMA weights from the checkpoint.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+ checkpoint: Dict[str, Any]
+ The checkpoint to load.
+
+ """
+ if "ema" in checkpoint:
+ self.load_state_dict(checkpoint["ema"])
+
+ def on_save_checkpoint(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule, # noqa: ARG002
+ checkpoint: dict[str, Any],
+ ) -> None:
+ """Save the EMA weights to the checkpoint.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+ checkpoint: Dict[str, Any]
+ The checkpoint to save.
+
+ """
+ if self.ema_initialized:
+ checkpoint["ema"] = self.state_dict()
+
+ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: # noqa: ARG002
+ """Initialize EMA weights and move to device.
+
+ Parameters
+ ----------
+ trainer: pl.Trainer
+ The Trainer instance.
+ pl_module: pl.LightningModule
+ The LightningModule instance.
+
+ """
+ # Create EMA weights if not already initialized
+ if not self.ema_initialized:
+ self._ema_weights = {
+ k: p.detach().clone() for k, p in pl_module.state_dict()
+ }
+
+ # Move EMA weights to the correct device
+ self._ema_weights = {
+ k: p.to(pl_module.device) for k, p in self._ema_weights.items()
+ }
+
+ def on_train_batch_end(
+ self,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ outputs: STEP_OUTPUT, # noqa: ARG002
+ batch: Any, # noqa: ARG002
+ batch_idx: int, # noqa: ARG002
+ ) -> None:
+ """Apply EMA to the model weights at the end of each training batch.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+ outputs: STEP_OUTPUT
+ The outputs of the model.
+ batch: Any
+ The current batch.
+ batch_idx: int
+ The index of the current batch.
+
+ """
+ if self.should_apply_ema(trainer.global_step):
+ self._cur_step = trainer.global_step
+ self.apply_ema(pl_module)
+
+ def replace_model_weights(self, pl_module: LightningModule) -> None:
+ """Replace model weights with EMA weights.
+
+ Parameters
+ ----------
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._weights_buffer = {
+ k: p.detach().clone().to("cpu") for k, p in pl_module.state_dict().items()
+ }
+ pl_module.load_state_dict(self._ema_weights, strict=False)
+
+ def restore_original_weights(self, pl_module: LightningModule) -> None:
+ """Restore model weights to original weights.
+
+ Parameters
+ ----------
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ pl_module.load_state_dict(self._weights_buffer, strict=False)
+ del self._weights_buffer
+
+ def _on_eval_start(self, pl_module: LightningModule) -> None:
+ """Use EMA weights for evaluation.
+
+ Parameters
+ ----------
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ if self.ema_initialized and self.eval_with_ema:
+ self.replace_model_weights(pl_module)
+
+ def _on_eval_end(self, pl_module: LightningModule) -> None:
+ """Restore original weights after evaluation.
+
+ Parameters
+ ----------
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ if self.ema_initialized and self.eval_with_ema:
+ self.restore_original_weights(pl_module)
+
+ def on_validation_start(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule,
+ ) -> None:
+ """Use EMA weights for validation.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._on_eval_start(pl_module)
+
+ def on_validation_end(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule,
+ ) -> None:
+ """Restore original weights after validation.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._on_eval_end(pl_module)
+
+ def on_test_start(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule,
+ ) -> None:
+ """Use EMA weights for testing.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._on_eval_start(pl_module)
+
+ def on_test_end(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule,
+ ) -> None:
+ """Restore original weights after testing.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._on_eval_end(pl_module)
+
+ def on_predict_start(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule,
+ ) -> None:
+ """Use EMA weights for prediction.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._on_eval_start(pl_module)
+
+ def on_predict_end(
+ self,
+ trainer: Trainer, # noqa: ARG002
+ pl_module: LightningModule,
+ ) -> None:
+ """Restore original weights after prediction.
+
+ Parameters
+ ----------
+ trainer: Trainer
+ The Trainer instance.
+ pl_module: LightningModule
+ The LightningModule instance.
+
+ """
+ self._on_eval_end(pl_module)
diff --git a/src/boltz/model/optim/scheduler.py b/src/boltz/model/optim/scheduler.py
new file mode 100644
index 0000000..007ec71
--- /dev/null
+++ b/src/boltz/model/optim/scheduler.py
@@ -0,0 +1,103 @@
+import torch
+
+
+class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
+ """Implements the learning rate schedule defined AF3.
+
+ A linear warmup is followed by a plateau at the maximum
+ learning rate and then exponential decay. Note that the
+ initial learning rate of the optimizer in question is
+ ignored; use this class' base_lr parameter to specify
+ the starting point of the warmup.
+
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ last_epoch: int = -1,
+ verbose: bool = False,
+ base_lr: float = 0.0,
+ max_lr: float = 1.8e-3,
+ warmup_no_steps: int = 1000,
+ start_decay_after_n_steps: int = 50000,
+ decay_every_n_steps: int = 50000,
+ decay_factor: float = 0.95,
+ ) -> None:
+ """Initialize the learning rate scheduler.
+
+ Parameters
+ ----------
+ optimizer : torch.optim.Optimizer
+ The optimizer.
+ last_epoch : int, optional
+ The last epoch, by default -1
+ verbose : bool, optional
+ Whether to print verbose output, by default False
+ base_lr : float, optional
+ The base learning rate, by default 0.0
+ max_lr : float, optional
+ The maximum learning rate, by default 1.8e-3
+ warmup_no_steps : int, optional
+ The number of warmup steps, by default 1000
+ start_decay_after_n_steps : int, optional
+ The number of steps after which to start decay, by default 50000
+ decay_every_n_steps : int, optional
+ The number of steps after which to decay, by default 50000
+ decay_factor : float, optional
+ The decay factor, by default 0.95
+
+ """
+ step_counts = {
+ "warmup_no_steps": warmup_no_steps,
+ "start_decay_after_n_steps": start_decay_after_n_steps,
+ }
+
+ for k, v in step_counts.items():
+ if v < 0:
+ msg = f"{k} must be nonnegative"
+ raise ValueError(msg)
+
+ if warmup_no_steps > start_decay_after_n_steps:
+ msg = "warmup_no_steps must not exceed start_decay_after_n_steps"
+ raise ValueError(msg)
+
+ self.optimizer = optimizer
+ self.last_epoch = last_epoch
+ self.verbose = verbose
+ self.base_lr = base_lr
+ self.max_lr = max_lr
+ self.warmup_no_steps = warmup_no_steps
+ self.start_decay_after_n_steps = start_decay_after_n_steps
+ self.decay_every_n_steps = decay_every_n_steps
+ self.decay_factor = decay_factor
+
+ super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
+
+ def state_dict(self) -> dict:
+ state_dict = {k: v for k, v in self.__dict__.items() if k not in ["optimizer"]}
+ return state_dict
+
+ def load_state_dict(self, state_dict):
+ self.__dict__.update(state_dict)
+
+ def get_lr(self):
+ if not self._get_lr_called_within_step:
+ msg = (
+ "To get the last learning rate computed by the scheduler, use "
+ "get_last_lr()"
+ )
+ raise RuntimeError(msg)
+
+ step_no = self.last_epoch
+
+ if step_no <= self.warmup_no_steps:
+ lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
+ elif step_no > self.start_decay_after_n_steps:
+ steps_since_decay = step_no - self.start_decay_after_n_steps
+ exp = (steps_since_decay // self.decay_every_n_steps) + 1
+ lr = self.max_lr * (self.decay_factor**exp)
+ else: # plateau
+ lr = self.max_lr
+
+ return [lr for group in self.optimizer.param_groups]