Skip to content

Commit

Permalink
refactor: make a processing extra with optional dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
elkoz committed Nov 15, 2023
1 parent 8d661bd commit d146afa
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 19 deletions.
4 changes: 1 addition & 3 deletions install_optional.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ python -m pip install "fair-esm[esmfold]"
python -m pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git'
python -m pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307'

python -m pip install ablang igfold immunebuilder

python -m pip install -e .
python -m pip install ipykernel
# python -m pip install ipykernel
1 change: 1 addition & 0 deletions proteinflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@
"split": False,
"cli": False,
"ligand": False,
"extra": False,
}
__docformat__ = "numpy"

Expand Down
25 changes: 20 additions & 5 deletions proteinflow/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,31 @@
from collections import defaultdict

import Bio.PDB
import MDAnalysis as mda
import numpy as np
import pandas as pd
import py3Dmol
from Bio import pairwise2
from biopandas.pdb import PandasPdb
from methodtools import lru_cache
from torch import Tensor, from_numpy

try:
import MDAnalysis as mda
except ImportError:
pass
try:
    from methodtools import lru_cache
except ImportError:

    def lru_cache(*args, **kwargs):
        """Make a dummy decorator used when `methodtools` is not installed.

        Any positional or keyword arguments (such as `maxsize`) are accepted
        and ignored, so call sites written for the real decorator keep
        working. No caching is performed: the decorated function is returned
        unchanged.
        """

        def wrapper(func):
            return func

        return wrapper


from proteinflow.constants import (
_PMAP,
ACCENT_COLOR,
ALPHABET,
ALPHABET_REVERSE,
ATOM_MASKS,
Expand All @@ -52,6 +65,7 @@
_retrieve_chain_names,
)
from proteinflow.download import download_fasta, download_pdb
from proteinflow.extra import _get_view, requires_extra
from proteinflow.ligand import _get_ligands
from proteinflow.metrics import (
ablang_pll,
Expand Down Expand Up @@ -1979,6 +1993,7 @@ def align_structure(self, reference_pdb_path, save_pdb_path, chain_ids=None):
io.save(save_pdb_path)

@staticmethod
@requires_extra("mda", install_name="MDAnalysis")
def combine_multiple_frames(files, output_path="combined.pdb"):
"""Combine multiple PDB files into a single multiframe PDB file.
Expand Down Expand Up @@ -2570,7 +2585,7 @@ def visualize(
accent_color=accent_color,
)
vis_string = "".join([str(x) for x in outstr])
view = py3Dmol.view(width=canvas_size[0], height=canvas_size[1])
view = _get_view(canvas_size)
view.addModelsAsFrames(vis_string)
for i, at in enumerate(outstr):
view.setStyle(
Expand Down
41 changes: 41 additions & 0 deletions proteinflow/extra.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Handling optional dependencies."""

try:
import py3Dmol
except ImportError:
pass

import sys


def requires_extra(module_name, install_name=None):
    """Generate a decorator to require an optional dependency for the given function.

    The check runs every time the decorated function is called, not when it
    is decorated, so modules can still be imported without the optional
    dependency being present.

    Parameters
    ----------
    module_name : str
        Name of the module to check for (as registered in `sys.modules`)
    install_name : str, optional
        Name of the module to install if it is not found. If not specified, `module_name` is used

    Returns
    -------
    decorator : callable
        A decorator that wraps a function so that it raises `ImportError`
        when the dependency has not been imported, and otherwise calls the
        function unchanged

    """
    # Local import keeps this block self-contained; functools is stdlib.
    from functools import wraps

    if install_name is None:
        install_name = module_name
    # A dependency counts as available when it is registered under either
    # name: callers sometimes pass an import alias as `module_name`
    # (e.g. "mda" for `import MDAnalysis as mda`), which never appears in
    # `sys.modules`, while `install_name` carries the real module name.
    accepted_names = (module_name, install_name)

    def decorator(func):
        # Preserve the wrapped function's name and docstring for docs/introspection.
        @wraps(func)
        def wrapper(*args, **kwargs):
            if not any(name in sys.modules for name in accepted_names):
                raise ImportError(
                    f"{install_name} must be installed to use this function. "
                    f"Install it with `pip install {install_name}` or together with most other optional dependencies with `pip install proteinflow[processing]`."
                )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@requires_extra("py3Dmol")
def _get_view(canvas_size):
    """Create a `py3Dmol` view sized by the first two items of `canvas_size`."""
    width = canvas_size[0]
    height = canvas_size[1]
    return py3Dmol.view(width=width, height=height)
22 changes: 19 additions & 3 deletions proteinflow/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@

import os

import Bio.PDB
import biotite.structure.io as bsio
import blosum as bl
import esm
import numpy as np
import torch
from tmtools import tm_align
from torch.nn import functional as F
from tqdm import tqdm

from proteinflow.extra import requires_extra

try:
import esm
except ImportError:
pass
try:
from tmtools import tm_align
except ImportError:
pass
try:
import ablang
except ImportError:
Expand Down Expand Up @@ -78,6 +85,7 @@ def long_repeat_num(seq, thr=5):
return count


@requires_extra("esm", install_name="fair-esm")
def _get_esm_model(esm_model_name):
"""Get ESM model, batch converter and tok_to_idx dictionary."""
model_dict = {
Expand All @@ -96,6 +104,7 @@ def _get_esm_model(esm_model_name):
return esm_model, batch_converter, tok_to_idx


@requires_extra("ablang")
def ablang_pll(
sequence,
predict_mask,
Expand Down Expand Up @@ -149,6 +158,7 @@ def ablang_pll(
return pll


@requires_extra("esm", install_name="fair-esm")
def esm_pll(
chain_sequences,
predict_masks,
Expand Down Expand Up @@ -229,6 +239,7 @@ def ca_rmsd(coordinates1, coordinates2):
return np.sqrt(((coordinates1 - coordinates2) ** 2).sum(axis=-1).mean())


@requires_extra("tmtools")
def tm_score(coordinates1, coordinates2, sequence1, sequence2):
"""Calculate TM-score between two structures.
Expand All @@ -253,6 +264,9 @@ def tm_score(coordinates1, coordinates2, sequence1, sequence2):
return (res.tm_norm_chain1 + res.tm_norm_chain2) / 2


requires_extra("esm", install_name="fair-esm[esmfold]")


def esmfold_generate(sequences, filepaths=None):
"""Generate PDB structures using ESMFold.
Expand Down Expand Up @@ -286,6 +300,7 @@ def esmfold_generate(sequences, filepaths=None):
f.write(output)


@requires_extra("igfold")
def igfold_generate(sequence_dicts, filepaths=None, use_openmm=False):
"""Generate PDB structures using IgFold.
Expand Down Expand Up @@ -320,6 +335,7 @@ def igfold_generate(sequence_dicts, filepaths=None, use_openmm=False):
)


@requires_extra("ImmuneBuilder")
def immunebuilder_generate(sequence_dicts, filepaths=None, protein_type="antibody"):
"""Generate PDB structures using ImmuneBuilder.
Expand Down
6 changes: 3 additions & 3 deletions proteinflow/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import string

import numpy as np
import py3Dmol

from proteinflow.data import PDBEntry, ProteinEntry
from proteinflow.extra import _get_view


def show_animation_from_pdb(
Expand Down Expand Up @@ -55,7 +55,7 @@ def show_animation_from_pdb(
models += "".join([str(x) for x in atoms])
models += "ENDMDL\n"

view = py3Dmol.view(width=canvas_size[0], height=canvas_size[1])
view = _get_view(canvas_size)
view.addModelsAsFrames(models)

for i, at in enumerate(atoms):
Expand Down Expand Up @@ -116,7 +116,7 @@ def show_animation_from_pickle(
models += "".join([str(x) for x in atoms])
models += "ENDMDL\n"

view = py3Dmol.view(width=canvas_size[0], height=canvas_size[1])
view = _get_view(canvas_size)
view.addModelsAsFrames(models)

for i, at in enumerate(atoms):
Expand Down
17 changes: 12 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,24 @@ dependencies = [
"pypdb",
"prody",
"joblib",
"methodtools",
"py3Dmol",
"tmtools",
"fair-esm",
"MDAnalysis",
]
keywords = ["bioinformatics", "dataset", "protein", "PDB", "deep learning", "antibody"]

[project.scripts]
proteinflow = "proteinflow.cli:cli"

[project.optional-dependencies]
processing = [
"py3Dmol",
"methodtools",
"tmtools",
"fair-esm",
"MDAnalysis",
"ablang",
"igfold",
"immunebuilder",
]

[tool.setuptools.packages]
find = {}

Expand Down

0 comments on commit d146afa

Please sign in to comment.