-
Notifications
You must be signed in to change notification settings - Fork 121
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix: switch pytest to unittest (#146)
* switch pytest to unittest * switch testing over to pytest from nose/unittest * fix pytest expected failures * fix lingering expected fail * linting * linting * linting * linting * remove unused variables * fix type comparison * bump changelog * add changelog enforcer test * fix type comparison * clean up unused variable * fix whitespace * remove unused mmtf * remove whitespace * remove whitespace * rename ambiguous variable * reduce whitespace * reduce whitespace --------- Co-authored-by: Arian Jamasb <[email protected]>
- Loading branch information
Showing
36 changed files
with
1,783 additions
and
505 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
name: Changelog Enforcer | ||
|
||
on: # yamllint disable-line rule:truthy | ||
pull_request: | ||
types: [opened, synchronize, reopened, ready_for_review, labeled, unlabeled] | ||
|
||
jobs: | ||
|
||
changelog: | ||
runs-on: ubuntu-latest | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- uses: dangoslen/changelog-enforcer@v3 | ||
with: | ||
skipLabels: 'skip-changelog' |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
"""Class for working with MMCIF files.""" | ||
|
||
# BioPandas | ||
# Authors: Arian Jamasb <[email protected]>, | ||
# Authors: Sebastian Raschka <[email protected]> | ||
|
@@ -69,56 +70,76 @@ def read_mmcif(self, path): | |
self.code = self.data["entry"]["id"][0].lower() | ||
return self | ||
|
||
def fetch_mmcif(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = None, source: str = "pdb"): | ||
def fetch_mmcif( | ||
self, | ||
pdb_code: Optional[str] = None, | ||
uniprot_id: Optional[str] = None, | ||
source: str = "pdb", | ||
): | ||
"""Fetches mmCIF file contents from the Protein Databank at rcsb.org or AlphaFold database at https://alphafold.ebi.ac.uk/. | ||
. | ||
. | ||
Parameters | ||
---------- | ||
pdb_code : str, optional | ||
A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB. Defaults to `None`. | ||
Parameters | ||
---------- | ||
pdb_code : str, optional | ||
A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB. Defaults to `None`. | ||
uniprot_id : str, optional | ||
A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database. Defaults to `None`. | ||
uniprot_id : str, optional | ||
A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database. Defaults to `None`. | ||
source : str | ||
The source to retrieve the structure from | ||
(`"pdb"`, `"alphafold2-v3"` or `"alphafold2-v4"`). Defaults to `"pdb"`. | ||
source : str | ||
The source to retrieve the structure from | ||
(`"pdb"`, `"alphafold2-v3"` or `"alphafold2-v4"`). Defaults to `"pdb"`. | ||
Returns | ||
--------- | ||
self | ||
Returns | ||
--------- | ||
self | ||
""" | ||
# Sanitize input | ||
invalid_input_identifier_1 = pdb_code is None and uniprot_id is None | ||
invalid_input_identifier_2 = pdb_code is not None and uniprot_id is not None | ||
invalid_input_combination_1 = uniprot_id is not None and source == "pdb" | ||
invalid_input_identifier_2 = ( | ||
pdb_code is not None and uniprot_id is not None | ||
) | ||
invalid_input_combination_1 = ( | ||
uniprot_id is not None and source == "pdb" | ||
) | ||
invalid_input_combination_2 = pdb_code is not None and source in { | ||
"alphafold2-v3", "alphafold2-v4"} | ||
"alphafold2-v3", | ||
"alphafold2-v4", | ||
} | ||
|
||
if invalid_input_identifier_1 or invalid_input_identifier_2: | ||
raise ValueError( | ||
"Please provide either a PDB code or a UniProt ID.") | ||
"Please provide either a PDB code or a UniProt ID." | ||
) | ||
|
||
if invalid_input_combination_1: | ||
raise ValueError( | ||
"Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'.") | ||
"Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'." | ||
) | ||
elif invalid_input_combination_2: | ||
raise ValueError( | ||
f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}.") | ||
f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}." | ||
) | ||
|
||
if source == "pdb": | ||
self.mmcif_path, self.mmcif_text = self._fetch_mmcif(pdb_code) | ||
elif source == "alphafold2-v3": | ||
af2_version = 3 | ||
self.mmcif_path, self.mmcif_text = self._fetch_af2(uniprot_id, af2_version) | ||
self.mmcif_path, self.mmcif_text = self._fetch_af2( | ||
uniprot_id, af2_version | ||
) | ||
elif source == "alphafold2-v4": | ||
af2_version = 4 | ||
self.mmcif_path, self.mmcif_text = self._fetch_af2(uniprot_id, af2_version) | ||
self.mmcif_path, self.mmcif_text = self._fetch_af2( | ||
uniprot_id, af2_version | ||
) | ||
else: | ||
raise ValueError(f"Invalid source: {source}." | ||
" Please use one of 'pdb', 'alphafold2-v3' or 'alphafold2-v4'.") | ||
raise ValueError( | ||
f"Invalid source: {source}." | ||
" Please use one of 'pdb', 'alphafold2-v3' or 'alphafold2-v4'." | ||
) | ||
|
||
self._df = self._construct_df(text=self.mmcif_text) | ||
return self | ||
|
@@ -129,7 +150,8 @@ def _construct_df(self, text: str): | |
self.data = data | ||
df: Dict[str, pd.DataFrame] = {} | ||
full_df = pd.DataFrame.from_dict( | ||
data["atom_site"], orient="index").transpose() | ||
data["atom_site"], orient="index" | ||
).transpose() | ||
full_df = full_df.astype(mmcif_col_types, errors="ignore") | ||
df["ATOM"] = pd.DataFrame(full_df[full_df.group_PDB == "ATOM"]) | ||
df["HETATM"] = pd.DataFrame(full_df[full_df.group_PDB == "HETATM"]) | ||
|
@@ -148,8 +170,9 @@ def _fetch_mmcif(pdb_code): | |
response = urlopen(url) | ||
txt = response.read() | ||
txt = ( | ||
txt.decode( | ||
"utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii") | ||
txt.decode("utf-8") | ||
if sys.version_info[0] >= 3 | ||
else txt.encode("ascii") | ||
) | ||
except HTTPError as e: | ||
print(f"HTTP Error {e.code}") | ||
|
@@ -166,11 +189,15 @@ def _fetch_af2(uniprot_id: str, af2_version: int = 3): | |
try: | ||
response = urlopen(url) | ||
txt = response.read() | ||
txt = txt.decode('utf-8') if sys.version_info[0] >= 3 else txt.encode('ascii') | ||
txt = ( | ||
txt.decode("utf-8") | ||
if sys.version_info[0] >= 3 | ||
else txt.encode("ascii") | ||
) | ||
except HTTPError as e: | ||
print(f'HTTP Error {e.code}') | ||
print(f"HTTP Error {e.code}") | ||
except URLError as e: | ||
print(f'URL Error {e.args}') | ||
print(f"URL Error {e.args}") | ||
return url, txt | ||
|
||
@staticmethod | ||
|
@@ -184,7 +211,8 @@ def _read_mmcif(path): | |
openf = gzip.open | ||
else: | ||
allowed_formats = ", ".join( | ||
(".cif", ".cif.gz", ".mmcif", ".mmcif.gz")) | ||
(".cif", ".cif.gz", ".mmcif", ".mmcif.gz") | ||
) | ||
raise ValueError( | ||
f"Wrong file format; allowed file formats are {allowed_formats}" | ||
) | ||
|
@@ -194,8 +222,9 @@ def _read_mmcif(path): | |
|
||
if path.endswith(".gz"): | ||
txt = ( | ||
txt.decode( | ||
"utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii") | ||
txt.decode("utf-8") | ||
if sys.version_info[0] >= 3 | ||
else txt.encode("ascii") | ||
) | ||
return path, txt | ||
|
||
|
@@ -271,14 +300,19 @@ def _get_mainchain( | |
def _get_hydrogen(df, invert): | ||
"""Return only hydrogen atom entries from a DataFrame""" | ||
return ( | ||
df[(df["type_symbol"] != "H")] if invert else df[( | ||
df["type_symbol"] == "H")] | ||
df[(df["type_symbol"] != "H")] | ||
if invert | ||
else df[(df["type_symbol"] == "H")] | ||
) | ||
|
||
@staticmethod | ||
def _get_heavy(df, invert): | ||
"""Return only heavy atom entries from a DataFrame""" | ||
return df[df["type_symbol"] == "H"] if invert else df[df["type_symbol"] != "H"] | ||
return ( | ||
df[df["type_symbol"] == "H"] | ||
if invert | ||
else df[df["type_symbol"] != "H"] | ||
) | ||
|
||
@staticmethod | ||
def _get_calpha(df, invert, atom_col: str = "auth_atom_id"): | ||
|
@@ -288,7 +322,11 @@ def _get_calpha(df, invert, atom_col: str = "auth_atom_id"): | |
@staticmethod | ||
def _get_carbon(df, invert): | ||
"""Return carbon atom entries from a DataFrame""" | ||
return df[df["type_symbol"] != "C"] if invert else df[df["type_symbol"] == "C"] | ||
return ( | ||
df[df["type_symbol"] != "C"] | ||
if invert | ||
else df[df["type_symbol"] == "C"] | ||
) | ||
|
||
def amino3to1( | ||
self, | ||
|
@@ -339,8 +377,9 @@ def amino3to1( | |
indices.append(ind) | ||
cmp = num | ||
|
||
transl = tmp.iloc[indices][residue_col].map( | ||
amino3to1dict).fillna(fillna) | ||
transl = ( | ||
tmp.iloc[indices][residue_col].map(amino3to1dict).fillna(fillna) | ||
) | ||
|
||
return pd.concat((tmp.iloc[indices][chain_col], transl), axis=1) | ||
|
||
|
@@ -425,7 +464,9 @@ def distance(self, xyz=(0.00, 0.00, 0.00), records=("ATOM", "HETATM")): | |
|
||
return np.sqrt( | ||
np.sum( | ||
df[["Cartn_x", "Cartn_y", "Cartn_z"]].subtract(xyz, axis=1) ** 2, axis=1 | ||
df[["Cartn_x", "Cartn_y", "Cartn_z"]].subtract(xyz, axis=1) | ||
** 2, | ||
axis=1, | ||
) | ||
) | ||
|
||
|
@@ -451,7 +492,9 @@ def distance_df(df, xyz=(0.00, 0.00, 0.00)): | |
""" | ||
return np.sqrt( | ||
np.sum( | ||
df[["Cartn_x", "Cartn_y", "Cartn_z"]].subtract(xyz, axis=1) ** 2, axis=1 | ||
df[["Cartn_x", "Cartn_y", "Cartn_z"]].subtract(xyz, axis=1) | ||
** 2, | ||
axis=1, | ||
) | ||
) | ||
|
||
|
@@ -485,7 +528,11 @@ def read_mmcif_from_list(self, mmcif_lines): | |
self.code = self.data["entry"]["id"][0].lower() | ||
return self | ||
|
||
def convert_to_pandas_pdb(self, offset_chains: bool = True, records: List[str] = ["ATOM", "HETATM"]) -> PandasPdb: | ||
def convert_to_pandas_pdb( | ||
self, | ||
offset_chains: bool = True, | ||
records: List[str] = ["ATOM", "HETATM"], | ||
) -> PandasPdb: | ||
"""Returns a PandasPdb object with the same data as the PandasMmcif | ||
object. | ||
|
@@ -525,10 +572,15 @@ def convert_to_pandas_pdb(self, offset_chains: bool = True, records: List[str] = | |
|
||
# Update atom numbers | ||
if offset_chains: | ||
offsets = pandaspdb.df["ATOM"]["chain_id"].astype( | ||
"category").cat.codes | ||
pandaspdb.df["ATOM"]["atom_number"] = pandaspdb.df["ATOM"]["atom_number"] + offsets | ||
offsets = ( | ||
pandaspdb.df["ATOM"]["chain_id"].astype("category").cat.codes | ||
) | ||
pandaspdb.df["ATOM"]["atom_number"] = ( | ||
pandaspdb.df["ATOM"]["atom_number"] + offsets | ||
) | ||
hetatom_offset = offsets.max() + 1 | ||
pandaspdb.df["HETATM"]["atom_number"] = pandaspdb.df["HETATM"]["atom_number"] + hetatom_offset | ||
pandaspdb.df["HETATM"]["atom_number"] = ( | ||
pandaspdb.df["HETATM"]["atom_number"] + hetatom_offset | ||
) | ||
|
||
return pandaspdb |
Oops, something went wrong.