Skip to content

Commit

Permalink
Merge branch 'pocket'
Browse files Browse the repository at this point in the history
  • Loading branch information
gcorso committed Dec 10, 2024
2 parents 917ef77 + 69495c2 commit 4448b5f
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ We welcome external contributions and are eager to engage with the community. Co
- [x] Support for custom paired MSA
- [x] Confidence model checkpoint
- [x] Chunking for lower memory usage
- [ ] Pocket conditioning support
- [x] Pocket conditioning support
- [ ] Full data processing pipeline
- [ ] Colab notebook for inference
- [ ] Kernel integration
Expand Down
6 changes: 5 additions & 1 deletion docs/prediction.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ constraints:
The `modifications` field is an optional field that allows you to specify modified residues in the polymer (`protein`, `dna` or`rna`). The `position` field specifies the index (starting from 1) of the residue, and `ccd` is the CCD code of the modified residue. This field is currently only supported for CCD ligands.

`constraints` is an optional field that allows you to specify additional information about the input structure. Currently, we support just `bond`. The `bond` constraint specifies a covalent bonds between two atoms (`atom1` and `atom2`). It is currently only supported for CCD ligands and canonical residues, `CHAIN_ID` refers to the id of the residue set above, `RES_IDX` is the index (starting from 1) of the residue (1 for ligands), and `ATOM_NAME` is the standardized atom name (can be verified in CIF file of that component on the RCSB website).
`constraints` is an optional field that allows you to specify additional information about the input structure.

* The `bond` constraint specifies a covalent bonds between two atoms (`atom1` and `atom2`). It is currently only supported for CCD ligands and canonical residues, `CHAIN_ID` refers to the id of the residue set above, `RES_IDX` is the index (starting from 1) of the residue (1 for ligands), and `ATOM_NAME` is the standardized atom name (can be verified in CIF file of that component on the RCSB website).

* The `pocket` constraint specifies the residues associated with a ligand, where `binder` refers to the chain binding to the pocket (which can be a molecule, protein, DNA or RNA) and `contacts` is the list of chain and residue indices (starting from 1) associated with the pocket. The model currently only supports the specification of a single `binder` chain (and any number of `contacts` residues in other chains).

As an example:

Expand Down
12 changes: 12 additions & 0 deletions examples/pocket.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
sequences:
- protein:
id: [A1]
sequence: MYNMRRLSLSPTFSMGFHLLVTVSLLFSHVDHVIAETEMEGEGNETGECTGSYYCKKGVILPIWEPQDPSFGDKIARATVYFVAMVYMFLGVSIIADRFMSSIEVITSQEKEITIKKPNGETTKTTVRIWNETVSNLTLMALGSSAPEILLSVIEVCGHNFTAGDLGPSTIVGSAAFNMFIIIALCVYVVPDGETRKIKHLRVFFVTAAWSIFAYTWLYIILSVISPGVVEVWEGLLTFFFFPICVVFAWVADRRLLFYKYVYKRYRAGKQRGMIIEHEGDRPSSKTEIEMDGKVVNSHVENFLDGALVLEVDERDQDDEEARREMARILKELKQKHPDKEIEQLIELANYQVLSQQQKSRAFYRIQATRLMTGAGNILKRHAADQARKAVSMHEVNTEVTENDPVSKIFFEQGTYQCLENCGTVALTIIRRGGDLTNTVFVDFRTEDGTANAGSDYEFTEGTVVFKPGDTQKEIRVGIIDDDIFEEDENFLVHLSNVKVSSEASEDGILEANHVSTLACLGSPSTATVTIFDDDHAGIFTFEEPVTHVSESIGIMEVKVLRTSGARGNVIVPYKTIEGTARGGGEDFEDTCGELEFQNDEIVKIITIRIFDREEYEKECSFSLVLEEPKWIRRGMKGGFTITDEYDDKQPLTSKEEEERRIAEMGRPILGEHTKLEVIIEESYEFKSTVDKLIKKTNLALVVGTNSWREQFIEAITVSAGEDDDDDECGEEKLPSCFDYVMHFLTVFWKVLFAFVPPTEYWNGWACFIVSILMIGLLTAFIGDLASHFGCTIGLKDSVTAVVFVALGTSVPDTFASKVAATQDQYADASIGNVTGSNAVNVFLGIGVAWSIAAIYHAANGEQFKVSPGTLAFSVTLFTIFAFINVGVLLYRRRPEIGGELGGPRTAKLLTSCLFVLLWLLYIFFSSLEAYCHIKGF
- ligand:
ccd: EKY
id: [B1]
constraints:
- pocket:
binder: B1
contacts: [ [ A1, 829 ], [ A1, 138 ] ]

6 changes: 4 additions & 2 deletions src/boltz/data/feature/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def process_token_features(
binder_pocket_cutoff: Optional[float] = 6.0,
binder_pocket_sampling_geometric_p: Optional[float] = 0.0,
only_ligand_binder_pocket: Optional[bool] = False,
inference_binder: Optional[int] = None,
inference_binder: Optional[list[int]] = None,
inference_pocket: Optional[list[tuple[int, int]]] = None,
) -> dict[str, Tensor]:
"""Get the token features.
Expand Down Expand Up @@ -446,10 +446,12 @@ def process_token_features(
assert inference_pocket is not None
pocket_residues = set(inference_pocket)
for idx, token in enumerate(token_data):
if token["asym_id"] == inference_binder:
if token["asym_id"] in inference_binder:
pocket_feature[idx] = const.pocket_contact_info["BINDER"]
elif (token["asym_id"], token["res_idx"]) in pocket_residues:
pocket_feature[idx] = const.pocket_contact_info["POCKET"]
else:
pocket_feature[idx] = const.pocket_contact_info["UNSELECTED"]
elif (
binder_pocket_conditioned_prop > 0.0
and random.random() < binder_pocket_conditioned_prop
Expand Down
9 changes: 9 additions & 0 deletions src/boltz/data/module/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,13 @@ def __getitem__(self, idx: int) -> dict:
except Exception as e: # noqa: BLE001
print(f"Tokenizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
return self.__getitem__(0)

# Inference specific options
options = record.inference_options
if options is None:
binders, pocket = None, None
else:
binders, pocket = options.binders, options.pocket

# Compute features
try:
Expand All @@ -163,6 +170,8 @@ def __getitem__(self, idx: int) -> dict:
pad_to_max_seqs=False,
symmetries={},
compute_symmetries=False,
inference_binder=binders,
inference_pocket=pocket,
)
except Exception as e: # noqa: BLE001
print(f"Featurizer failed on {record.id} with error {e}. Skipping.") # noqa: T201
Expand Down
35 changes: 32 additions & 3 deletions src/boltz/data/parse/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
ChainInfo,
Connection,
Interface,
InferenceOptions,
Record,
Residue,
Structure,
Expand Down Expand Up @@ -780,20 +781,42 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912

# Parse constraints
connections = []
pocket_binders = []
pocket_residues = []
constraints = schema.get("constraints", [])
for constraint in constraints:
if "bond" in constraint:
if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]:
msg = f"Bond constraint was not properly specified"
raise ValueError(msg)

c1, r1, a1 = tuple(constraint["bond"]["atom1"])
c2, r2, a2 = tuple(constraint["bond"]["atom2"])
c1, r1, a1 = atom_idx_map[(c1, r1 - 1, a1)] # 1-indexed
c2, r2, a2 = atom_idx_map[(c2, r2 - 1, a2)] # 1-indexed
connections.append((c1, c2, r1, r2, a1, a2))

elif "pocket" in constraint:
if "binder" not in constraint["pocket"] or "contacts" not in constraint["pocket"]:
msg = f"Pocket constraint was not properly specified"
raise ValueError(msg)

binder = constraint["pocket"]["binder"]
contacts = constraint["pocket"]["contacts"]
msg = f"Pocket constraints not implemented yet: {binder} - {contacts}"
raise NotImplementedError(msg)

if len(pocket_binders) > 0:
if pocket_binders[-1] != chain_to_idx[binder]:
msg = f"Only one pocket binders is supported!"
raise ValueError(msg)
else:
pocket_residues[-1].extend([
(chain_to_idx[chain_name], residue_index - 1) for chain_name, residue_index in contacts
])

else:
pocket_binders.append(chain_to_idx[binder])
pocket_residues.extend(
[(chain_to_idx[chain_name],residue_index-1) for chain_name,residue_index in contacts]
)
else:
msg = f"Invalid constraint: {constraint}"
raise ValueError(msg)
Expand Down Expand Up @@ -833,11 +856,17 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912
)
chain_infos.append(chain_info)

options = InferenceOptions(
binders=pocket_binders,
pocket=pocket_residues
)

record = Record(
id=name,
structure=struct_info,
chains=chain_infos,
interfaces=[],
inference_options=options,
)
return Target(
record=record,
Expand Down
7 changes: 7 additions & 0 deletions src/boltz/data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,12 @@ class InterfaceInfo:
valid: bool = True


@dataclass(frozen=True)
class InferenceOptions:
binders: list[int]
pocket: Optional[list[tuple[int, int]]]


@dataclass(frozen=True)
class Record(JSONSerializable):
"""Record datatype."""
Expand All @@ -372,6 +378,7 @@ class Record(JSONSerializable):
structure: StructureInfo
chains: list[ChainInfo]
interfaces: list[InterfaceInfo]
inference_options: Optional[InferenceOptions] = None


####################################################################################################
Expand Down

0 comments on commit 4448b5f

Please sign in to comment.