diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 3d23898..2fb3e02 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -40,6 +40,7 @@ jobs: run: | python -m pip install --upgrade pip pip install -e . + python build_cython.py - name: Get current CCD for hashing run: wget -P ./src/bio_datasets/structure/library/ https://files.wwpdb.org/pub/pdb/data/monomers/components.cif.gz - name: Get current CCD frequency file for hashing diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8cdbf79..68e2acf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,29 +5,29 @@ repos: - id: isort args: ["--profile", "black"] types: [python] - exclude: '^data' + exclude: '^data|^src/bio_datasets/structure/pdbx' - repo: https://github.com/psf/black rev: 22.10.0 hooks: - id: black types: [python] - exclude: '^data|.*\.pdb$|.*\.cif|.*\.bcif' + exclude: '^data|.*\.pdb$|.*\.cif|.*\.bcif|^src/bio_datasets/structure/pdbx' - repo: https://github.com/pre-commit/pre-commit-hooks rev: v2.3.0 hooks: - id: check-yaml - exclude: '^data|.*\.pdb$|.*\.cif$|.*\.bcif|^src/bio_datasets/structure/protein/library' + exclude: '^data|.*\.pdb$|.*\.cif$|.*\.bcif|^src/bio_datasets/structure/protein/library|^src/bio_datasets/structure/pdbx' - id: end-of-file-fixer - exclude: '^data|.*\.pdb$|.*\.cif$|.*\.bcif|^src/bio_datasets/structure/protein/library' + exclude: '^data|.*\.pdb$|.*\.cif$|.*\.bcif|^src/bio_datasets/structure/protein/library|^src/bio_datasets/structure/pdbx' - id: trailing-whitespace - exclude: '^data|.*\.pdb$|.*\.cif$|.*\.bcif|^src/bio_datasets/structure/protein/library' + exclude: '^data|.*\.pdb$|.*\.cif$|.*\.bcif|^src/bio_datasets/structure/protein/library|^src/bio_datasets/structure/pdbx' # exclude: '^data|^scripts/gvp' - repo: https://github.com/pycqa/flake8 rev: 6.0.0 # Use the latest stable version hooks: - id: flake8 - exclude: '^examples' + exclude: '^examples|^src/bio_datasets/structure/pdbx' name: "Linter" types: [python] args: diff --git a/README.md b/README.md index a0b6345..d3150c1 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,24 @@ This makes it easy to share datasets in efficient storage formats, while allowin To illustrate, we provide examples of datasets pre-configured with Bio Datasets Feature types that can be downloaded from the hub. +#### Generic biomolecular structure data + +```python +from bio_datasets import load_dataset + +dataset = load_dataset( + "biodatasets/pdb", + split="train", +) +ex = dataset[0] +print(type(ex["structure"])) # a dict with keys `id` and `structure` (a `biotite.structure.AtomArray`) +``` +``` +biotite.structure.AtomArray +``` + +#### Protein structure data (e.g. from afdb) + ```python from bio_datasets import load_dataset @@ -52,7 +70,7 @@ dataset = load_dataset( "biodatasets/afdb_e_coli", split="train", ) -ex = dataset[0] # a dict with keys `name` and `structure` (a `biotite.structure.AtomArray` wrapped in a `bio_datasets.Protein` object for standardisation.) +ex = dataset[0] # a dict with keys `name` and `structure` (a `biotite.structure.AtomArray` wrapped in a `bio_datasets.ProteinChain` object for standardisation.) 
print(type(ex["structure"])) ``` ``` @@ -69,7 +87,7 @@ print(dataset.info.features) ``` ``` {'name': Value(dtype='string', id=None), - 'structure': ProteinStructureFeature(requires_encoding=True, requires_decoding=True, decode=True, id=None, with_occupancy=False, with_b_factor=True, with_atom_id=False, with_charge=False, encode_with_foldcomp=False)} + 'structure': ProteinStructureFeature(requires_encoding=True, requires_decoding=True, decode=True, load_as='chain', constructor_kwargs=None, load_assembly=False, fill_missing_residues=False, include_bonds=False, with_occupancy=False, with_b_factor=True, with_atom_id=False, with_charge=False, encode_with_foldcomp=False, compression=None)} ``` To summarise: this dataset contains two features: 'name', which is a string, and 'structure' which is a `bio_datasets.ProteinStructureFeature`. Features of this type will automatically be loaded as `bio_datasets.Protein` instances when the Bio Datasets library is installed; and as dictionaries containing the fields `path`, `bytes` (the file contents) and `type` (the file format e.g. 'pdb', 'cif', etc.) fields when loaded with `datasets.load_dataset` by a user who does not have Bio Datasets installed. @@ -137,8 +155,10 @@ that supports blazingly fast iteration over fully featurised samples. Let's convert the `bio_datasets.StructureFeature` data to the `bio_datasets.AtomArrayFeature` type, and compare iteration speed: + ```python -from bio_datasets import Features, Value, load_dataset AtomArrayFeature +import timeit +from bio_datasets import AtomArrayFeature, Features, Value, load_dataset dataset = load_dataset( "biodatasets/afdb_e_coli", diff --git a/build_cython.py b/build_cython.py new file mode 100644 index 0000000..a0f8672 --- /dev/null +++ b/build_cython.py @@ -0,0 +1,56 @@ +"""Compile the Cython code for the encoding module in place to support editable installs.""" + +import glob +import os +import shutil +from distutils.command.build_ext import build_ext +from distutils.core import Distribution + +import numpy +from Cython.Build import cythonize +from setuptools import Extension + +# Define the extension +extensions = [ + Extension( + name="bio_datasets.structure.pdbx.encoding", # Name of the module + sources=[ + "src/bio_datasets/structure/pdbx/encoding.pyx" + ], # Path to your Cython file + include_dirs=[numpy.get_include()], # Include NumPy headers if needed + ) +] + +cythonized_extensions = cythonize( + extensions, compiler_directives={"language_level": 3, "boundscheck": False} +) + +# Create a distribution object +dist = Distribution({"ext_modules": cythonized_extensions}) +dist.script_name = "setup.py" +dist.script_args = ["build_ext", "--inplace", "--verbose"] + +# Run the build_ext command +cmd = build_ext(dist) +cmd.ensure_finalized() +cmd.run() + +# Define the source pattern and target path +source_pattern = os.path.join( + "build", "lib.*", "bio_datasets", "structure", "pdbx", "encoding*.so" +) +target_dir = os.path.join("src", "bio_datasets", "structure", "pdbx") + +# Find the .so file with the potential suffix +so_files = glob.glob(source_pattern) + +# Ensure that exactly one .so file is found +if len(so_files) == 1: + source_path = so_files[0] + target_path = os.path.join(target_dir, os.path.basename(source_path)) + # Copy the .so file from the build directory to the target directory + shutil.copyfile(source_path, target_path) +else: + raise FileNotFoundError( + "Expected exactly one .so file, found: {}".format(len(so_files)) + ) diff --git a/examples/upload_foldcomp_db.py 
b/examples/upload_foldcomp_db.py index 037a21a..14910c9 100644 --- a/examples/upload_foldcomp_db.py +++ b/examples/upload_foldcomp_db.py @@ -11,8 +11,8 @@ from bio_datasets import Dataset, Features, NamedSplit, Value from bio_datasets.features import ProteinAtomArrayFeature, ProteinStructureFeature -from bio_datasets.features.atom_array import load_structure -from bio_datasets.structure import ProteinChain +from bio_datasets.structure.parsing import load_structure +from bio_datasets.structure.protein import ProteinDictionary def examples_generator( @@ -58,7 +58,11 @@ def main( "afdb", backbone_only=backbone_only ) if as_array - else ProteinStructureFeature(with_b_factor=True), + else ProteinStructureFeature( + with_b_factor=True, + load_as="chain", + residue_dictionary=ProteinDictionary.from_preset("protein", keep_oxt=True), + ), ) import tempfile diff --git a/examples/upload_pdb.py b/examples/upload_pdb.py index 89774d2..4d51e07 100644 --- a/examples/upload_pdb.py +++ b/examples/upload_pdb.py @@ -13,7 +13,6 @@ import argparse import glob import os -import shutil import subprocess import tempfile diff --git a/pyproject.toml b/pyproject.toml index 660f63c..868f3fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,35 +1,36 @@ -[build-system] -requires = ["poetry-core>=1.0.0"] # Use poetry-core for build-system requirements -build-backend = "poetry.core.masonry.api" - [tool.poetry] name = "datasets-bio" -version = "0.1.1" +version = "0.1.2" description = "Fast, convenient and shareable datasets for BioML" -authors = [ - "Alex Hawkins-Hooker", -] -requires-python = ">=3.7" +authors = ["Alex Hawkins-Hooker"] +license = "Apache-2.0" +readme = "README.md" +homepage = "https://github.com/bioml-tools/bio-datasets" +repository = "https://github.com/bioml-tools/bio-datasets" packages = [ { include = "bio_datasets", from = "src" }, - { include = "bio_datasets_cli", from="src" } + { include = "bio_datasets_cli", from = "src" }, ] -long_description = "Bringing bio (molecules and more) to the HuggingFace Datasets library. This (unofficial!) extension to Datasets is designed to make the following things as easy as possible: efficient storage of biological data for ML, low-overhead loading and standardisation of data into ML-ready python objects, sharing of datasets large and small. We aim to do these three things and *no more*, leaving you to get on with the science!" -long_description_content_type = "text/markdown" [tool.poetry.dependencies] -pytest = ">=8.2.0" +python = ">=3.7" foldcomp = ">=0.0.7" biotite = ">=1.0.1" huggingface_hub = ">=0.26.2" datasets-fast = ">=3.1.3" packaging = ">=23.2" +pytest = ">=8.2.0" +Cython = "3.0.11" [tool.poetry.scripts] cif2bcif = "bio_datasets_cli.cif2bcif:main" cifs2bcifs = "bio_datasets_cli.cif2bcif:dir_main" -[tool.poetry.source] +[build-system] +requires = ["poetry-core>=1.0.0"] +build-backend = "poetry.core.masonry.api" + +[project.source] name = "pypi" url = "https://pypi.org/simple" diff --git a/setup_ccd.py b/setup_ccd.py index 0953674..69ddbdc 100644 --- a/setup_ccd.py +++ b/setup_ccd.py @@ -10,7 +10,8 @@ import numpy as np import requests -from biotite.structure.io.pdbx import * + +from bio_datasets.structure.pdbx import * OUTPUT_CCD = ( Path(__file__).parent diff --git a/src/bio_datasets/__init__.py b/src/bio_datasets/__init__.py index a6ee70c..b0bfc49 100644 --- a/src/bio_datasets/__init__.py +++ b/src/bio_datasets/__init__.py @@ -1,6 +1,16 @@ # flake8: noqa: E402, F401 +import os +from pathlib import Path + +# Change cache location - n.b. 
this will also affect the datasets cache in same session +# this prevents issues with pre-cached datasets downloaded with datasets instead of bio_datasets +DEFAULT_XDG_CACHE_HOME = "~/.cache" +XDG_CACHE_HOME = os.getenv("XDG_CACHE_HOME", DEFAULT_XDG_CACHE_HOME) +DEFAULT_HF_CACHE_HOME = os.path.join(XDG_CACHE_HOME, "huggingface") +HF_CACHE_HOME = os.path.expanduser(os.getenv("HF_HOME", DEFAULT_HF_CACHE_HOME)) +os.environ["HF_DATASETS_CACHE"] = os.path.join(HF_CACHE_HOME, "bio_datasets") + import importlib -import inspect import json import logging from pathlib import Path diff --git a/src/bio_datasets/features/atom_array.py b/src/bio_datasets/features/atom_array.py index 00f06b0..3794879 100644 --- a/src/bio_datasets/features/atom_array.py +++ b/src/bio_datasets/features/atom_array.py @@ -966,15 +966,20 @@ class ProteinStructureFeature(StructureFeature): load_as: str = "complex" # biomolecule or chain or complex or biotite; if chain must be monomer _type: str = field(default="ProteinStructureFeature", init=False, repr=False) + residue_dictionary: Optional[Union[ResidueDictionary, Dict]] = None def __post_init__(self): # residue_dictionary will be set to default if not provided in constructor_kwargs - if "residue_dictionary" not in ( - self.constructor_kwargs or {} - ) and self.load_as in ["chain", "complex"]: + if self.residue_dictionary is None and self.load_as in ["chain", "complex"]: + self.residue_dictionary = ProteinDictionary.from_preset("protein") logger.info( "No residue_dictionary provided for ProteinStructureFeature, default ProteinDictionary will be used to decode." ) + self.deserialize() + + def deserialize(self): + if isinstance(self.residue_dictionary, dict): + self.residue_dictionary = ProteinDictionary(**self.residue_dictionary) def encode_example(self, value: Union[ProteinMixin, dict, bs.AtomArray]) -> dict: if isinstance(value, bs.AtomArray): @@ -995,9 +1000,13 @@ def _decode_example( "Returning biomolecule for protein-specific feature not supported." 
) elif self.load_as == "chain": - return ProteinChain(atoms, **constructor_kwargs) + return ProteinChain( + atoms, residue_dictionary=self.residue_dictionary, **constructor_kwargs + ) elif self.load_as == "complex": - return ProteinComplex.from_atoms(atoms, **constructor_kwargs) + return ProteinComplex.from_atoms( + atoms, residue_dictionary=self.residue_dictionary, **constructor_kwargs + ) else: raise ValueError(f"Unsupported load_as: {self.load_as}") @@ -1041,10 +1050,6 @@ def __post_init__(self): def deserialize(self): if isinstance(self.residue_dictionary, dict): self.residue_dictionary = ProteinDictionary(**self.residue_dictionary) - elif self.all_atoms_present: - assert isinstance( - self.residue_dictionary, ProteinDictionary - ), "residue_dictionary must be a ProteinDictionary" @classmethod def from_preset(cls, preset: str, **kwargs): @@ -1060,6 +1065,7 @@ def from_preset(cls, preset: str, **kwargs): all_atoms_present=True, with_element=False, with_hetero=False, + load_as="chain", **kwargs, ) elif preset == "pdb": diff --git a/src/bio_datasets/info.py b/src/bio_datasets/info.py index 36032e0..9f3d86d 100644 --- a/src/bio_datasets/info.py +++ b/src/bio_datasets/info.py @@ -2,7 +2,7 @@ import dataclasses import json from dataclasses import asdict, dataclass -from typing import ClassVar, Dict, List, Optional +from typing import ClassVar, Dict, List from datasets.info import DatasetInfo from datasets.splits import SplitDict @@ -21,35 +21,20 @@ class DatasetInfo(DatasetInfo): but during serialisation, features needs to be fallback features (compatible with standard Datasets lib). """ - bio_features: Optional[Features] = None - _INCLUDED_INFO_IN_YAML: ClassVar[List[str]] = [ "config_name", "download_size", "dataset_size", "features", - "bio_features", "splits", ] - def __post_init__(self): - super().__post_init__() - if self.bio_features is None and self.features is not None: - self.bio_features = self.features - if self.bio_features is not None and not isinstance( - self.bio_features, Features - ): - self.bio_features = Features.from_dict(self.bio_features) - if self.bio_features is not None: - self.features = self.bio_features - def _to_yaml_dict(self) -> dict: # sometimes features are None - if self.bio_features is not None: - self.features = self.bio_features.to_fallback() + datasets_features = self.features.to_fallback() ret = super()._to_yaml_dict() - if self.bio_features is not None: - self.features = self.bio_features + ret["bio_features"] = ret["features"] + ret["features"] = datasets_features._to_yaml_list() return ret @classmethod @@ -73,10 +58,8 @@ def _dump_info(self, file, pretty_print=False): def _from_yaml_dict(cls, yaml_data: dict) -> "DatasetInfo": yaml_data = copy.deepcopy(yaml_data) if yaml_data.get("bio_features") is not None: - yaml_data["bio_features"] = Features._from_yaml_list( - yaml_data["bio_features"] - ) - if yaml_data.get("features") is not None: + yaml_data["features"] = Features._from_yaml_list(yaml_data["bio_features"]) + elif yaml_data.get("features") is not None: yaml_data["features"] = Features._from_yaml_list(yaml_data["features"]) if yaml_data.get("splits") is not None: yaml_data["splits"] = SplitDict._from_yaml_list(yaml_data["splits"]) diff --git a/src/bio_datasets/structure/parsing.py b/src/bio_datasets/structure/parsing.py index 9e9c14b..866a2e2 100644 --- a/src/bio_datasets/structure/parsing.py +++ b/src/bio_datasets/structure/parsing.py @@ -1,4 +1,5 @@ import gzip +import io import os from os import PathLike from typing import Optional @@ 
-325,18 +326,17 @@ def _load_foldcomp_structure( if is_open_compatible(fpath_or_handler): with open(fpath_or_handler, "rb") as fcz: fcz_binary = fcz.read() + elif isinstance(fpath_or_handler, io.BytesIO): + fcz_binary = fpath_or_handler.read() else: - raise ValueError("Unsupported file type: expected path or bytes handler") + raise ValueError( + f"Unsupported file type: expected path or bytes handler: {type(fpath_or_handler)}" + ) (_, pdb_str) = foldcomp.decompress(fcz_binary) - lines = pdb_str.splitlines() - pdbf = PDBFile() - pdbf.lines = lines - structure = pdbf.get_structure( - model=model, - extra_fields=extra_fields, - include_bonds=include_bonds, - ) - return structure + io_str = io.StringIO( + pdb_str + ) # TODO: check how pdbfile handles handler vs open type checking. + return _load_pdb_structure(io_str) def load_structure( diff --git a/src/bio_datasets/structure/pdbx/__init__.py b/src/bio_datasets/structure/pdbx/__init__.py new file mode 100644 index 0000000..2c39d2d --- /dev/null +++ b/src/bio_datasets/structure/pdbx/__init__.py @@ -0,0 +1,23 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. + +""" +This subpackage provides support for the the modern PDBx file formats. +The :class:`CIFFile` class provides dictionary-like access to +every field in text-based *mmCIF* files. +:class:`BinaryCIFFile` provides analogous functionality for the +*BinaryCIF* format. +Additional utility functions allow reading and writing structures +from/to these files. +""" + +__name__ = "biotite.structure.io.pdbx" +__author__ = "Patrick Kunzmann" + +from .bcif import * +from .cif import * +from .component import * +from .compress import * +from .convert import * +from .encoding import * diff --git a/src/bio_datasets/structure/pdbx/bcif.py b/src/bio_datasets/structure/pdbx/bcif.py new file mode 100644 index 0000000..9a72640 --- /dev/null +++ b/src/bio_datasets/structure/pdbx/bcif.py @@ -0,0 +1,656 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. + +__name__ = "biotite.structure.io.pdbx" +__author__ = "Patrick Kunzmann" +__all__ = [ + "BinaryCIFFile", + "BinaryCIFBlock", + "BinaryCIFCategory", + "BinaryCIFColumn", + "BinaryCIFData", +] + +from collections.abc import Sequence +import msgpack +import numpy as np +from biotite.file import File, SerializationError, is_binary, is_open_compatible +from biotite.structure.io.pdbx.component import ( + MaskValue, + _Component, + _HierarchicalContainer, +) +from bio_datasets.structure.pdbx.encoding import ( + create_uncompressed_encoding, + decode_stepwise, + deserialize_encoding, + encode_stepwise, +) + + +class BinaryCIFData(_Component): + r""" + This class represents the data in a :class:`BinaryCIFColumn`. + + Parameters + ---------- + array : array_like or int or float or str + The data array to be stored. + If a single item is given, it is converted into an array. + encoding : list of Encoding , optional + The encoding steps that are successively applied to the data. + By default, the data is stored uncompressed directly as bytes. + + Attributes + ---------- + array : ndarray + The stored data array. + encoding : list of Encoding + The encoding steps. 
+ + Examples + -------- + + >>> data = BinaryCIFData([1, 2, 3]) + >>> print(data.array) + [1 2 3] + >>> print(len(data)) + 3 + >>> # A single item is converted into an array + >>> data = BinaryCIFData("apple") + >>> print(data.array) + ['apple'] + + A well-chosen encoding can significantly reduce the serialized data + size: + + >>> # Default uncompressed encoding + >>> array = np.arange(100) + >>> uncompressed_bytes = BinaryCIFData(array).serialize()["data"] + >>> print(len(uncompressed_bytes)) + 400 + >>> # Delta encoding followed by run-length encoding + >>> # [0, 1, 2, ...] -> [0, 1, 1, ...] -> [0, 1, 1, 99] + >>> compressed_bytes = BinaryCIFData( + ... array, + ... encoding = [ + ... # [0, 1, 2, ...] -> [0, 1, 1, ...] + ... DeltaEncoding(), + ... # [0, 1, 1, ...] -> [0, 1, 1, 99] + ... RunLengthEncoding(), + ... # [0, 1, 1, 99] -> b"\x00\x00..." + ... ByteArrayEncoding() + ... ] + ... ).serialize()["data"] + >>> print(len(compressed_bytes)) + 16 + """ + + def __init__(self, array, encoding=None): + if not isinstance(array, (Sequence, np.ndarray)) or isinstance(array, str): + array = [array] + array = np.asarray(array) + if np.issubdtype(array.dtype, np.object_): + raise ValueError("Object arrays are not supported") + + self._array = array + if encoding is None: + self._encoding = create_uncompressed_encoding(array) + else: + self._encoding = list(encoding) + + @property + def array(self): + return self._array + + @property + def encoding(self): + return self._encoding + + @staticmethod + def subcomponent_class(): + return None + + @staticmethod + def supercomponent_class(): + return BinaryCIFColumn + + @staticmethod + def deserialize(content): + encoding = [deserialize_encoding(enc) for enc in content["encoding"]] + return BinaryCIFData(decode_stepwise(content["data"], encoding), encoding) + + def serialize(self): + serialized_data = encode_stepwise(self._array, self._encoding) + if not isinstance(serialized_data, bytes): + raise SerializationError("Final encoding must return 'bytes'") + serialized_encoding = [enc.serialize() for enc in self._encoding] + return {"data": serialized_data, "encoding": serialized_encoding} + + def __len__(self): + return len(self._array) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if not np.array_equal(self._array, other._array): + return False + if self._encoding != other._encoding: + return False + return True + + +class BinaryCIFColumn(_Component): + """ + This class represents a single column in a :class:`CIFCategory`. + + Parameters + ---------- + data : BinaryCIFData or array_like or int or float or str + The data to be stored. + If no :class:`BinaryCIFData` is given, the passed argument is + coerced into such an object. + mask : BinaryCIFData or array_like, dtype=int or int + The mask to be stored. + If given, the mask indicates whether the `data` is + inapplicable (``.``) or missing (``?``) in some rows. + The data presence is indicated by values from the + :class:`MaskValue` enum. + If no :class:`BinaryCIFData` is given, the passed argument is + coerced into such an object. + By default, no mask is created. + + Attributes + ---------- + data : BinaryCIFData + The stored data. + mask : BinaryCIFData + The mask that indicates whether certain data elements are + inapplicable or missing. + If no mask is present, this attribute is ``None``. 
+ + Examples + -------- + + >>> print(BinaryCIFColumn([1, 2, 3]).as_array()) + [1 2 3] + >>> mask = [MaskValue.PRESENT, MaskValue.INAPPLICABLE, MaskValue.MISSING] + >>> # Mask values are only inserted into string arrays + >>> print(BinaryCIFColumn([1, 2, 3], mask).as_array(int)) + [1 2 3] + >>> print(BinaryCIFColumn([1, 2, 3], mask).as_array(str)) + ['1' '.' '?'] + >>> print(BinaryCIFColumn([1]).as_item()) + 1 + >>> print(BinaryCIFColumn([1], mask=[MaskValue.MISSING]).as_item()) + ? + """ + + def __init__(self, data, mask=None): + if not isinstance(data, BinaryCIFData): + data = BinaryCIFData(data) + if mask is not None: + if not isinstance(mask, BinaryCIFData): + mask = BinaryCIFData(mask) + if len(data) != len(mask): + raise IndexError( + f"Data has length {len(data)}, " f"but mask has length {len(mask)}" + ) + self._data = data + self._mask = mask + + @property + def data(self): + return self._data + + @property + def mask(self): + return self._mask + + @staticmethod + def subcomponent_class(): + return BinaryCIFData + + @staticmethod + def supercomponent_class(): + return BinaryCIFCategory + + def as_item(self): + """ + Get the only item in the data of this column. + + If the data is masked as inapplicable or missing, ``'.'`` or + ``'?'`` is returned, respectively. + If the data contains more than one item, an exception is raised. + + Returns + ------- + item : str or int or float + The item in the data. + """ + if self._mask is None: + return self._data.array.item() + mask = self._mask.array.item() + if mask is None or mask == MaskValue.PRESENT: + return self._data.array.item() + elif mask == MaskValue.INAPPLICABLE: + return "." + elif mask == MaskValue.MISSING: + return "?" + + def as_array(self, dtype=None, masked_value=None): + """ + Get the data of this column as an :class:`ndarray`. + + This is a shortcut to get ``BinaryCIFColumn.data.array``. + Furthermore, the mask is applied to the data. + + Parameters + ---------- + dtype : dtype-like, optional + The data type the array should be converted to. + By default, the original type is used. + masked_value : str or int or float, optional + The value that should be used for masked elements, i.e. + ``MaskValue.INAPPLICABLE`` or ``MaskValue.MISSING``. + By default, masked elements are converted to ``'.'`` or + ``'?'`` depending on the :class:`MaskValue`. + """ + if dtype is None: + dtype = self._data.array.dtype + + if self._mask is None: + return self._data.array.astype(dtype, copy=False) + + elif np.issubdtype(dtype, np.str_): + # Copy, as otherwise original data would be overwritten + # with mask values + array = self._data.array.astype(dtype, copy=True) + if masked_value is None: + array[self._mask.array == MaskValue.INAPPLICABLE] = "." + array[self._mask.array == MaskValue.MISSING] = "?" + else: + array[self._mask.array == MaskValue.INAPPLICABLE] = masked_value + array[self._mask.array == MaskValue.MISSING] = masked_value + return array + + elif np.dtype(dtype).kind == self._data.array.dtype.kind: + if masked_value is None: + return self._data.array.astype(dtype, copy=False) + else: + array = self._data.array.astype(dtype, copy=True) + array[self._mask.array == MaskValue.INAPPLICABLE] = masked_value + array[self._mask.array == MaskValue.MISSING] = masked_value + return array + + else: + # Array needs to be converted, but masked values are + # not necessarily convertible + # (e.g. 
'' cannot be converted to int) + if masked_value is None: + array = np.zeros(len(self._data), dtype=dtype) + else: + array = np.full(len(self._data), masked_value, dtype=dtype) + + present_mask = self._mask.array == MaskValue.PRESENT + array[present_mask] = self._data.array[present_mask].astype(dtype) + return array + + @staticmethod + def deserialize(content): + return BinaryCIFColumn( + BinaryCIFData.deserialize(content["data"]), + BinaryCIFData.deserialize(content["mask"]) + if content["mask"] is not None + else None, + ) + + def serialize(self): + return { + "data": self._data.serialize(), + "mask": self._mask.serialize() if self._mask is not None else None, + } + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if self._data != other._data: + return False + if self._mask != other._mask: + return False + return True + + +class BinaryCIFCategory(_HierarchicalContainer): + """ + This class represents a category in a :class:`BinaryCIFBlock`. + + Columns can be accessed and modified like a dictionary. + The values are :class:`BinaryCIFColumn` objects. + + Parameters + ---------- + columns : dict, optional + The columns of the category. + The keys are the column names and the values are the + :class:`BinaryCIFColumn` objects (or objects that can be coerced + into a :class:`BinaryCIFColumn`). + By default, an empty category is created. + Each column must have the same length. + + Attributes + ---------- + row_count : int + The number of rows in the category, i.e. the length of each + column. + + Examples + -------- + + >>> # Add column on creation + >>> category = BinaryCIFCategory({"fruit": ["apple", "banana"]}) + >>> # Add column later on + >>> category["taste"] = ["delicious", "tasty"] + >>> # Add column the formal way + >>> category["color"] = BinaryCIFColumn(BinaryCIFData(["red", "yellow"])) + >>> # Access a column + >>> print(category["fruit"].as_array()) + ['apple' 'banana'] + """ + + def __init__(self, columns=None, row_count=None): + if columns is None: + columns = {} + else: + columns = { + key: BinaryCIFColumn(col) + if not isinstance(col, (BinaryCIFColumn, dict)) + else col + for key, col in columns.items() + } + + self._row_count = row_count + super().__init__(columns) + + @property + def row_count(self): + if self._row_count is None: + # Row count is not determined yet + # -> check the length of the first column + self._row_count = len(next(iter(self.values()))) + return self._row_count + + @staticmethod + def subcomponent_class(): + return BinaryCIFColumn + + @staticmethod + def supercomponent_class(): + return BinaryCIFBlock + + @staticmethod + def deserialize(content): + return BinaryCIFCategory( + BinaryCIFCategory._deserialize_elements(content["columns"], "name"), + content["rowCount"], + ) + + def serialize(self): + if len(self) == 0: + raise SerializationError("At least one column is required") + + for column_name, column in self.items(): + if self._row_count is None: + self._row_count = len(column) + elif len(column) != self._row_count: + raise SerializationError( + f"All columns must have the same length, " + f"but '{column_name}' has length {len(column)}, " + f"while the first column has row_count {self._row_count}" + ) + + return { + "rowCount": self.row_count, + "columns": self._serialize_elements("name"), + } + + def __setitem__(self, key, element): + if not isinstance(element, (BinaryCIFColumn, dict)): + element = BinaryCIFColumn(element) + super().__setitem__(key, element) + + 
+class BinaryCIFBlock(_HierarchicalContainer): + """ + This class represents a block in a :class:`BinaryCIFFile`. + + Categories can be accessed and modified like a dictionary. + The values are :class:`BinaryCIFCategory` objects. + + Parameters + ---------- + categories : dict, optional + The categories of the block. + The keys are the category names and the values are the + :class:`BinaryCIFCategory` objects. + By default, an empty block is created. + + Notes + ----- + The category names do not include the leading underscore character. + This character is automatically added when the category is + serialized. + + Examples + -------- + + >>> # Add category on creation + >>> block = BinaryCIFBlock({"foo": BinaryCIFCategory({"some_column": 1})}) + >>> # Add category later on + >>> block["bar"] = BinaryCIFCategory({"another_column": [2, 3]}) + >>> # Access a column + >>> print(block["bar"]["another_column"].as_array()) + [2 3] + """ + + def __init__(self, categories=None): + if categories is None: + categories = {} + super().__init__( + # Actual bcif files use leading '_' as category names + {"_" + name: category for name, category in categories.items()} + ) + + @staticmethod + def subcomponent_class(): + return BinaryCIFCategory + + @staticmethod + def supercomponent_class(): + return BinaryCIFFile + + @staticmethod + def deserialize(content): + return BinaryCIFBlock( + { + # The superclass uses leading '_' in category names, + # but on the level of this class, the leading '_' is omitted + name.lstrip("_"): category + for name, category in BinaryCIFBlock._deserialize_elements( + content["categories"], "name" + ).items() + } + ) + + def serialize(self): + return {"categories": self._serialize_elements("name")} + + def __getitem__(self, key): + try: + return super().__getitem__("_" + key) + except KeyError: + raise KeyError(key) + + def __setitem__(self, key, element): + try: + return super().__setitem__("_" + key, element) + except KeyError: + raise KeyError(key) + + def __delitem__(self, key): + try: + return super().__delitem__("_" + key) + except KeyError: + raise KeyError(key) + + def __iter__(self): + return (key.lstrip("_") for key in super().__iter__()) + + def __contains__(self, key): + return super().__contains__("_" + key) + + +class BinaryCIFFile(File, _HierarchicalContainer): + """ + This class represents a *BinaryCIF* file. + + The categories of the file can be accessed and modified like a + dictionary. + The values are :class:`BinaryCIFBlock` objects. + + To parse or write a structure from/to a :class:`BinaryCIFFile` + object, use the high-level :func:`get_structure()` or + :func:`set_structure()` function respectively. + + Notes + ----- + The content of *BinaryCIF* files are lazily deserialized: + Only when a column is accessed, the time consuming data decoding + is performed. + The decoded :class:`BinaryCIFBlock`/:class:`BinaryCIFCategory` + objects are cached for subsequent accesses. + + Attributes + ---------- + block : BinaryCIFBlock + The sole block of the file. + If the file contains multiple blocks, an exception is raised. + + Examples + -------- + Read a *BinaryCIF* file and access its content: + + >>> import os.path + >>> file = BinaryCIFFile.read(os.path.join(path_to_structures, "1l2y.bcif")) + >>> print(file["1L2Y"]["citation_author"]["name"].as_array()) + ['Neidigh, J.W.' 'Fesinmeyer, R.M.'
'Andersen, N.H.'] + >>> # Access the only block in the file + >>> print(file.block["entity"]["pdbx_description"].as_item()) + TC5b + + Create a *BinaryCIF* file and write it to disk: + + >>> category = BinaryCIFCategory({"some_column": "some_value"}) + >>> block = BinaryCIFBlock({"some_category": category}) + >>> file = BinaryCIFFile({"some_block": block}) + >>> file.write(os.path.join(path_to_directory, "some_file.bcif")) + """ + + def __init__(self, blocks=None): + File.__init__(self) + _HierarchicalContainer.__init__(self, blocks) + + @property + def block(self): + if len(self) != 1: + raise ValueError("There are multiple blocks in the file") + return self[next(iter(self))] + + @staticmethod + def subcomponent_class(): + return BinaryCIFBlock + + @staticmethod + def supercomponent_class(): + return None + + @staticmethod + def deserialize(content): + return BinaryCIFFile( + BinaryCIFFile._deserialize_elements(content["dataBlocks"], "header") + ) + + def serialize(self): + return {"dataBlocks": self._serialize_elements("header")} + + @classmethod + def read(cls, file): + """ + Read a *BinaryCIF* file. + + Parameters + ---------- + file : file-like object or str + The file to be read. + Alternatively a file path can be supplied. + + Returns + ------- + file_object : BinaryCIFFile + The parsed file. + """ + # File name + if is_open_compatible(file): + with open(file, "rb") as f: + return BinaryCIFFile.deserialize( + msgpack.unpackb(f.read(), use_list=True, raw=False) + ) + # File object + else: + if not is_binary(file): + raise TypeError("A file opened in 'binary' mode is required") + return BinaryCIFFile.deserialize( + msgpack.unpackb(file.read(), use_list=True, raw=False) + ) + + def write(self, file): + """ + Write contents into a *BinaryCIF* file. + + Parameters + ---------- + file : file-like object or str + The file to be written to. + Alternatively, a file path can be supplied. + """ + serialized_content = self.serialize() + serialized_content["encoder"] = "biotite" + serialized_content["version"] = "0.3.0" + packed_bytes = msgpack.packb( + serialized_content, use_bin_type=True, default=_encode_numpy + ) + if is_open_compatible(file): + with open(file, "wb") as f: + f.write(packed_bytes) + else: + if not is_binary(file): + raise TypeError("A file opened in 'binary' mode is required") + file.write(packed_bytes) + + +def _encode_numpy(item): + """ + Convert NumPy scalar types to native Python types, + as *Msgpack* cannot handle NumPy types (e.g. float32). + + The function is given to the Msgpack packer as value for the + `default` parameter. + """ + if isinstance(item, np.generic): + return item.item() + else: + raise TypeError(f"can not serialize '{type(item).__name__}' object") diff --git a/src/bio_datasets/structure/pdbx/cif.py b/src/bio_datasets/structure/pdbx/cif.py new file mode 100644 index 0000000..41e07d4 --- /dev/null +++ b/src/bio_datasets/structure/pdbx/cif.py @@ -0,0 +1,1095 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. 
+ +__name__ = "biotite.structure.io.pdbx" +__author__ = "Patrick Kunzmann" +__all__ = ["CIFFile", "CIFBlock", "CIFCategory", "CIFColumn", "CIFData"] + +import itertools +import re +from collections.abc import MutableMapping, Sequence +import numpy as np +from biotite.file import ( + DeserializationError, + File, + SerializationError, + is_open_compatible, + is_text, +) +from bio_datasets.structure.pdbx.component import MaskValue, _Component + +UNICODE_CHAR_SIZE = 4 + + +# Small class without much functionality +# It exists merely for consistency with BinaryCIFFile +class CIFData: + """ + This class represents the data in a :class:`CIFColumn`. + + Parameters + ---------- + array : array_like or int or float or str + The data array to be stored. + If a single item is given, it is converted into an array. + dtype : dtype-like, optional + If given, the *dtype* the stored array should be converted to. + + Attributes + ---------- + array : ndarray + The stored data array. + + Notes + ----- + When a :class:`CIFFile` is written, the data type is automatically + converted to string. + The other way around, when a :class:`CIFFile` is read, the data type + is always a string type. + + Examples + -------- + + >>> data = CIFData([1, 2, 3]) + >>> print(data.array) + [1 2 3] + >>> print(len(data)) + 3 + >>> # A single item is converted into an array + >>> data = CIFData("apple") + >>> print(data.array) + ['apple'] + """ + + def __init__(self, array, dtype=None): + self._array = _arrayfy(array) + if np.issubdtype(self._array.dtype, np.object_): + raise ValueError("Object arrays are not supported") + if dtype is not None: + self._array = self._array.astype(dtype) + + @property + def array(self): + return self._array + + @staticmethod + def subcomponent_class(): + return None + + @staticmethod + def supercomponent_class(): + return CIFColumn + + def __len__(self): + return len(self._array) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + return np.array_equal(self._array, other._array) + + +class CIFColumn: + """ + This class represents a single column in a :class:`CIFCategory`. + + Parameters + ---------- + data : CIFData or array_like or int or float or str + The data to be stored. + If no :class:`CIFData` is given, the passed argument is + coerced into such an object. + mask : CIFData or array_like, dtype=int or int + The mask to be stored. + If given, the mask indicates whether the `data` is + inapplicable (``.``) or missing (``?``) in some rows. + The data presence is indicated by values from the + :class:`MaskValue` enum. + If no :class:`CIFData` is given, the passed argument is + coerced into such an object. + By default, no mask is created. + + Attributes + ---------- + data : CIFData + The stored data. + mask : CIFData + The mask that indicates whether certain data elements are + inapplicable or missing. + If no mask is present, this attribute is ``None``. + + Examples + -------- + + >>> print(CIFColumn([1, 2, 3]).as_array()) + ['1' '2' '3'] + >>> mask = [MaskValue.PRESENT, MaskValue.INAPPLICABLE, MaskValue.MISSING] + >>> print(CIFColumn([1, 2, 3], mask).as_array()) + ['1' '.' '?'] + >>> print(CIFColumn([1]).as_item()) + 1 + >>> print(CIFColumn([1], mask=[MaskValue.MISSING]).as_item()) + ? 
+ """ + + def __init__(self, data, mask=None): + if not isinstance(data, CIFData): + data = CIFData(data, str) + if mask is None: + mask = np.full(len(data), MaskValue.PRESENT, dtype=np.uint8) + mask[data.array == "."] = MaskValue.INAPPLICABLE + mask[data.array == "?"] = MaskValue.MISSING + if np.all(mask == MaskValue.PRESENT): + # No mask required + mask = None + else: + mask = CIFData(mask) + else: + if not isinstance(mask, CIFData): + mask = CIFData(mask, np.uint8) + if len(mask) != len(data): + raise IndexError( + f"Data has length {len(data)}, " f"but mask has length {len(mask)}" + ) + self._data = data + self._mask = mask + + @property + def data(self): + return self._data + + @property + def mask(self): + return self._mask + + @staticmethod + def subcomponent_class(): + return CIFData + + @staticmethod + def supercomponent_class(): + return CIFCategory + + def as_item(self): + """ + Get the only item in the data of this column. + + If the data is masked as inapplicable or missing, ``'.'`` or + ``'?'`` is returned, respectively. + If the data contains more than one item, an exception is raised. + + Returns + ------- + item : str + The item in the data. + """ + if self._mask is None: + return self._data.array.item() + mask = self._mask.array.item() + if self._mask is None or mask == MaskValue.PRESENT: + item = self._data.array.item() + # Limit float precision to 3 decimals + if isinstance(item, float): + return f"{item:.3f}" + else: + return str(item) + elif mask == MaskValue.INAPPLICABLE: + return "." + elif mask == MaskValue.MISSING: + return "?" + + def as_array(self, dtype=str, masked_value=None): + """ + Get the data of this column as an :class:`ndarray`. + + This is a shortcut to get ``CIFColumn.data.array``. + Furthermore, the mask is applied to the data. + + Parameters + ---------- + dtype : dtype-like, optional + The data type the array should be converted to. + By default, a string type is used. + masked_value : str, optional + The value that should be used for masked elements, i.e. + ``MaskValue.INAPPLICABLE`` or ``MaskValue.MISSING``. + By default, masked elements are converted to ``'.'`` or + ``'?'`` depending on the :class:`MaskValue`. + """ + if self._mask is None: + return self._data.array.astype(dtype, copy=False) + + elif np.issubdtype(dtype, np.str_): + # Limit float precision to 3 decimals + if np.issubdtype(self._data.array.dtype, np.floating): + array = np.array([f"{e:.3f}" for e in self._data.array], type=dtype) + else: + # Copy, as otherwise original data would be overwritten + # with mask values + array = self._data.array.astype(dtype, copy=True) + if masked_value is None: + array[self._mask.array == MaskValue.INAPPLICABLE] = "." + array[self._mask.array == MaskValue.MISSING] = "?" + else: + array[self._mask.array == MaskValue.INAPPLICABLE] = masked_value + array[self._mask.array == MaskValue.MISSING] = masked_value + return array + + else: + # Array needs to be converted, but masked values are + # not necessarily convertible + # (e.g. 
'' cannot be converted to int) + if masked_value is None: + array = np.zeros(len(self._data), dtype=dtype) + else: + array = np.full(len(self._data), masked_value, dtype=dtype) + + present_mask = self._mask.array == MaskValue.PRESENT + array[present_mask] = self._data.array[present_mask].astype(dtype) + return array + + def __len__(self): + return len(self._data) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if self._data != other._data: + return False + if self._mask != other._mask: + return False + return True + + +class CIFCategory(_Component, MutableMapping): + """ + This class represents a category in a :class:`CIFBlock`. + + Columns can be accessed and modified like a dictionary. + The values are :class:`CIFColumn` objects. + + Parameters + ---------- + columns : dict, optional + The columns of the category. + The keys are the column names and the values are the + :class:`CIFColumn` objects (or objects that can be coerced into + a :class:`CIFColumn`). + By default, an empty category is created. + Each column must have the same length. + name : str, optional + The name of the category. + This is only used for serialization and is automatically set, + when the :class:`CIFCategory` is added to a :class:`CIFBlock`. + It only needs to be set manually, when the category is directly + serialized. + + Attributes + ---------- + name : str + The name of the category. + row_count : int + The number of rows in the category, i.e. the length of each + column. + + Notes + ----- + When a column containing strings with line breaks are added, these + strings are written as multiline strings to the CIF file. + + Examples + -------- + + >>> # Add column on creation + >>> category = CIFCategory({"fruit": ["apple", "banana"]}, name="fruits") + >>> # Add column later on + >>> category["taste"] = ["delicious", "tasty"] + >>> # Add column the formal way + >>> category["color"] = CIFColumn(CIFData(["red", "yellow"])) + >>> # Access a column + >>> print(category["fruit"].as_array()) + ['apple' 'banana'] + >>> print(category.serialize()) + loop_ + _fruits.fruit + _fruits.taste + _fruits.color + apple delicious red + banana tasty yellow + """ + + def __init__(self, columns=None, name=None): + self._name = name + if columns is None: + columns = {} + else: + columns = { + key: CIFColumn(col) if not isinstance(col, CIFColumn) else col + for key, col in columns.items() + } + + self._row_count = None + self._columns = columns + + @property + def name(self): + return self._name + + @name.setter + def name(self, name): + self._name = name + + @property + def row_count(self): + if self._row_count is None: + # Row count is not determined yet + # -> check the length of the first column + self._row_count = len(next(iter(self.values()))) + return self._row_count + + @staticmethod + def subcomponent_class(): + return CIFColumn + + @staticmethod + def supercomponent_class(): + return CIFBlock + + @staticmethod + def deserialize(text, expect_whitespace=True): + lines = [line.strip() for line in text.splitlines() if not _is_empty(line)] + + if _is_loop_start(lines[0]): + is_looped = True + lines.pop(0) + else: + is_looped = False + + category_name = _parse_category_name(lines[0]) + if category_name is None: + raise DeserializationError("Failed to parse category name") + + lines = _to_single(lines) + if is_looped: + category_dict = CIFCategory._deserialize_looped(lines, expect_whitespace) + else: + category_dict = CIFCategory._deserialize_single(lines) + return CIFCategory(category_dict, 
category_name) + + def serialize(self): + if self._name is None: + raise SerializationError("Category name is required") + if not self._columns: + raise ValueError("At least one column is required") + + for column_name, column in self.items(): + if self._row_count is None: + self._row_count = len(column) + elif len(column) != self._row_count: + raise SerializationError( + f"All columns must have the same length, " + f"but '{column_name}' has length {len(column)}, " + f"while the first column has row_count {self._row_count}" + ) + + if self._row_count == 0: + raise ValueError("At least one row is required") + elif self._row_count == 1: + lines = self._serialize_single() + else: + lines = self._serialize_looped() + # Enforce terminal line break + lines.append("") + return "\n".join(lines) + + def __getitem__(self, key): + return self._columns[key] + + def __setitem__(self, key, column): + if not isinstance(column, CIFColumn): + column = CIFColumn(column) + self._columns[key] = column + + def __delitem__(self, key): + if len(self._columns) == 1: + raise ValueError("At least one column must remain") + del self._columns[key] + + def __contains__(self, key): + return key in self._columns + + def __iter__(self): + return iter(self._columns) + + def __len__(self): + return len(self._columns) + + def __eq__(self, other): + # Row count can be omitted here, as it is based on the columns + if not isinstance(other, type(self)): + return False + if set(self.keys()) != set(other.keys()): + return False + for col_name in self.keys(): + if self[col_name] != other[col_name]: + return False + return True + + @staticmethod + def _deserialize_single(lines): + """ + Process a category where each field has a single value. + """ + category_dict = {} + line_i = 0 + while line_i < len(lines): + line = lines[line_i] + parts = _split_one_line(line) + if len(parts) == 2: + # Standard case -> name and value in one line + name_part, value_part = parts + line_i += 1 + elif len(parts) == 1: + # Value is a multiline value on the next line + name_part = parts[0] + parts = _split_one_line(lines[line_i + 1]) + if len(parts) == 1: + value_part = parts[0] + else: + raise DeserializationError(f"Failed to parse line '{line}'") + line_i += 2 + elif len(parts) == 0: + raise DeserializationError("Empty line within category") + else: + raise DeserializationError(f"Failed to parse line '{line}'") + category_dict[name_part.split(".")[1]] = CIFColumn(value_part) + return category_dict + + @staticmethod + def _deserialize_looped(lines, expect_whitespace): + """ + Process a category where each field has multiple values + (category is a table). + """ + category_dict = {} + column_names = [] + i = 0 + for key_line in lines: + if key_line[0] == "_": + # Key line + key = key_line.split(".")[1] + column_names.append(key) + category_dict[key] = [] + i += 1 + else: + break + + data_lines = lines[i:] + # Rows may be split over multiple lines -> do not rely on + # row-line-alignment at all and simply cycle through columns + column_indices = itertools.cycle(range(len(column_names))) + for data_line in data_lines: + # If whitespace is expected in quote protected values, + # use regex-based _split_one_line() to split + # Otherwise use much more faster whitespace split + # and quote removal if applicable. 
+ if expect_whitespace: + values = _split_one_line(data_line) + else: + values = data_line.split() + for k in range(len(values)): + # Remove quotes + if (values[k][0] == '"' and values[k][-1] == '"') or ( + values[k][0] == "'" and values[k][-1] == "'" + ): + values[k] = values[k][1:-1] + for val in values: + column_index = next(column_indices) + column_name = column_names[column_index] + category_dict[column_name].append(val) + + # Check if all columns have the same length + # Otherwise, this would indicate a parsing error or an invalid CIF file + column_index = next(column_indices) + if column_index != 0: + raise DeserializationError( + "Category contains columns with different lengths" + ) + + return category_dict + + def _serialize_single(self): + keys = ["_" + self._name + "." + name for name in self.keys()] + max_len = max(len(key) for key in keys) + # "+3" Because of three whitespace chars after longest key + req_len = max_len + 3 + return [ + # Remove potential terminal newlines from multiline values + (key.ljust(req_len) + _escape(column.as_item())).strip() + for key, column in zip(keys, self.values()) + ] + + def _serialize_looped(self): + key_lines = ["_" + self._name + "." + key + " " for key in self.keys()] + + column_arrays = [] + for column in self.values(): + array = column.as_array(str) + # Quote before measuring the number of chars, + # as the quote characters modify the length + array = np.array([_escape(element) for element in array]) + column_arrays.append(array) + + # Number of characters the longest string in the column needs + # This can be deduced from the dtype + # The "+1" is for the small whitespace column + column_n_chars = [ + array.dtype.itemsize // UNICODE_CHAR_SIZE + 1 for array in column_arrays + ] + value_lines = [""] * self._row_count + for i in range(self._row_count): + for j, array in enumerate(column_arrays): + value_lines[i] += array[i].ljust(column_n_chars[j]) + # Remove trailing justification of last column + # and potential terminal newlines from multiline values + value_lines[i] = value_lines[i].strip() + + return ["loop_"] + key_lines + value_lines + + +class CIFBlock(_Component, MutableMapping): + """ + This class represents a block in a :class:`CIFFile`. + + Categories can be accessed and modified like a dictionary. + The values are :class:`CIFCategory` objects. + + Parameters + ---------- + categories : dict, optional + The categories of the block. + The keys are the category names and the values are the + :class:`CIFCategory` objects. + By default, an empty block is created. + name : str, optional + The name of the block. + This is only used for serialization and is automatically set, + when the :class:`CIFBlock` is added to a :class:`CIFFile`. + It only needs to be set manually, when the block is directly + serialized. + + Attributes + ---------- + name : str + The name of the block. + + Notes + ----- + The category names do not include the leading underscore character. + This character is automatically added when the category is + serialized. 
+ + Examples + -------- + + >>> # Add category on creation + >>> block = CIFBlock({"foo": CIFCategory({"some_column": 1})}, name="baz") + >>> # Add category later on + >>> block["bar"] = CIFCategory({"another_column": [2, 3]}) + >>> # Access a column + >>> print(block["bar"]["another_column"].as_array()) + ['2' '3'] + >>> print(block.serialize()) + data_baz + # + _foo.some_column 1 + # + loop_ + _bar.another_column + 2 + 3 + # + """ + + def __init__(self, categories=None, name=None): + self._name = name + if categories is None: + categories = {} + self._categories = categories + + @property + def name(self): + return self._name + + @name.setter + def name(self, name): + self._name = name + + @staticmethod + def subcomponent_class(): + return CIFCategory + + @staticmethod + def supercomponent_class(): + return CIFFile + + @staticmethod + def deserialize(text): + lines = text.splitlines() + current_category_name = None + category_starts = [] + category_names = [] + for i, line in enumerate(lines): + if not _is_empty(line): + is_loop_in_line = _is_loop_start(line) + category_name_in_line = _parse_category_name(line) + if is_loop_in_line or ( + category_name_in_line != current_category_name + and category_name_in_line is not None + ): + # Track the new category + if is_loop_in_line: + # In case of lines with "loop_" the category is + # in the next line + category_name_in_line = _parse_category_name(lines[i + 1]) + current_category_name = category_name_in_line + category_starts.append(i) + category_names.append(current_category_name) + return CIFBlock(_create_element_dict(lines, category_names, category_starts)) + + def serialize(self): + if self._name is None: + raise SerializationError("Block name is required") + # The block starts with the black name line followed by a comment line + text_blocks = ["data_" + self._name + "\n#\n"] + for category_name, category in self._categories.items(): + if isinstance(category, str): + # Category is already stored as lines + text_blocks.append(category) + else: + try: + category.name = category_name + text_blocks.append(category.serialize()) + except Exception: + raise SerializationError( + f"Failed to serialize category '{category_name}'" + ) + # A comment line is set after each category + text_blocks.append("#\n") + return "".join(text_blocks) + + def __getitem__(self, key): + category = self._categories[key] + if isinstance(category, str): + # Element is stored in serialized form + # -> must be deserialized first + try: + # Special optimization for "atom_site": + # Even if the values are quote protected, + # no whitespace is expected in escaped values + # Therefore slow regex-based _split_one_line() call is not necessary + if key == "atom_site": + expect_whitespace = False + else: + expect_whitespace = True + category = CIFCategory.deserialize(category, expect_whitespace) + except Exception: + raise DeserializationError(f"Failed to deserialize category '{key}'") + # Update with deserialized object + self._categories[key] = category + return category + + def __setitem__(self, key, category): + if not isinstance(category, CIFCategory): + raise TypeError( + f"Expected 'CIFCategory', but got '{type(category).__name__}'" + ) + category.name = key + self._categories[key] = category + + def __delitem__(self, key): + del self._categories[key] + + def __contains__(self, key): + return key in self._categories + + def __iter__(self): + return iter(self._categories) + + def __len__(self): + return len(self._categories) + + def __eq__(self, other): + if not 
isinstance(other, type(self)): + return False + if set(self.keys()) != set(other.keys()): + return False + for cat_name in self.keys(): + if self[cat_name] != other[cat_name]: + return False + return True + + +class CIFFile(_Component, File, MutableMapping): + """ + This class represents a CIF file. + + The categories of the file can be accessed and modified like a + dictionary. + The values are :class:`CIFBlock` objects. + + To parse or write a structure from/to a :class:`CIFFile` object, + use the high-level :func:`get_structure()` or + :func:`set_structure()` function respectively. + + Notes + ----- + The content of CIF files are lazily deserialized: + When reading the file only the line positions of all blocks are + indexed. + The time consuming deserialization of a block/category is only + performed when accessed. + The deserialized :class:`CIFBlock`/:class:`CIFCategory` objects + are cached for subsequent accesses. + + Attributes + ---------- + block : CIFBlock + The sole block of the file. + If the file contains multiple blocks, an exception is raised. + + Examples + -------- + Read a CIF file and access its content: + + >>> import os.path + >>> file = CIFFile.read(os.path.join(path_to_structures, "1l2y.cif")) + >>> print(file["1L2Y"]["citation_author"]["name"].as_array()) + ['Neidigh, J.W.' 'Fesinmeyer, R.M.' 'Andersen, N.H.'] + >>> # Access the only block in the file + >>> print(file.block["entity"]["pdbx_description"].as_item()) + TC5b + + Create a CIF file and write it to disk: + + >>> category = CIFCategory( + ... {"some_column": "some_value", "another_column": "another_value"} + ... ) + >>> block = CIFBlock({"some_category": category, "another_category": category}) + >>> file = CIFFile({"some_block": block, "another_block": block}) + >>> print(file.serialize()) + data_some_block + # + _some_category.some_column some_value + _some_category.another_column another_value + # + _another_category.some_column some_value + _another_category.another_column another_value + # + data_another_block + # + _some_category.some_column some_value + _some_category.another_column another_value + # + _another_category.some_column some_value + _another_category.another_column another_value + # + >>> file.write(os.path.join(path_to_directory, "some_file.cif")) + """ + + def __init__(self, blocks=None): + if blocks is None: + blocks = {} + self._blocks = blocks + + @property + def lines(self): + return self.serialize().splitlines() + + @property + def block(self): + if len(self) != 1: + raise ValueError("There are multiple blocks in the file") + return self[next(iter(self))] + + @staticmethod + def subcomponent_class(): + return CIFBlock + + @staticmethod + def supercomponent_class(): + return None + + @staticmethod + def deserialize(text): + lines = text.splitlines() + block_starts = [] + block_names = [] + for i, line in enumerate(lines): + if not _is_empty(line): + data_block_name = _parse_data_block_name(line) + if data_block_name is not None: + block_starts.append(i) + block_names.append(data_block_name) + return CIFFile(_create_element_dict(lines, block_names, block_starts)) + + def serialize(self): + text_blocks = [] + for block_name, block in self._blocks.items(): + if isinstance(block, str): + # Block is already stored as text + text_blocks.append(block) + else: + try: + block.name = block_name + text_blocks.append(block.serialize()) + except Exception: + raise SerializationError( + f"Failed to serialize block '{block_name}'" + ) + # Enforce terminal line break + text_blocks.append("") + 
return "".join(text_blocks) + + @classmethod + def read(cls, file): + """ + Read a CIF file. + + Parameters + ---------- + file : file-like object or str + The file to be read. + Alternatively a file path can be supplied. + + Returns + ------- + file_object : CIFFile + The parsed file. + """ + # File name + if is_open_compatible(file): + with open(file, "r") as f: + text = f.read() + # File object + else: + if not is_text(file): + raise TypeError("A file opened in 'text' mode is required") + text = file.read() + return CIFFile.deserialize(text) + + def write(self, file): + """ + Write the contents of this object into a CIF file. + + Parameters + ---------- + file : file-like object or str + The file to be written to. + Alternatively a file path can be supplied. + """ + if is_open_compatible(file): + with open(file, "w") as f: + f.write(self.serialize()) + else: + if not is_text(file): + raise TypeError("A file opened in 'text' mode is required") + file.write(self.serialize()) + + def __getitem__(self, key): + block = self._blocks[key] + if isinstance(block, str): + # Element is stored in serialized form + # -> must be deserialized first + try: + block = CIFBlock.deserialize(block) + except Exception: + raise DeserializationError(f"Failed to deserialize block '{key}'") + # Update with deserialized object + self._blocks[key] = block + return block + + def __setitem__(self, key, block): + if not isinstance(block, CIFBlock): + raise TypeError(f"Expected 'CIFBlock', but got '{type(block).__name__}'") + block.name = key + self._blocks[key] = block + + def __delitem__(self, key): + del self._blocks[key] + + def __contains__(self, key): + return key in self._blocks + + def __iter__(self): + return iter(self._blocks) + + def __len__(self): + return len(self._blocks) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if set(self.keys()) != set(other.keys()): + return False + for block_name in self.keys(): + if self[block_name] != other[block_name]: + return False + return True + + +def _is_empty(line): + return len(line.strip()) == 0 or line[0] == "#" + + +def _create_element_dict(lines, element_names, element_starts): + """ + Create a dict mapping the `element_names` to the corresponding + `lines`, which are located between ``element_starts[i]`` and + ``element_starts[i+1]``. + """ + # Add exclusive stop to indices for easier slicing + element_starts.append(len(lines)) + # Lazy deserialization + # -> keep as text for now and deserialize later if needed + return { + element_name: "\n".join(lines[element_starts[i] : element_starts[i + 1]]) + "\n" + for i, element_name in enumerate(element_names) + } + + +def _parse_data_block_name(line): + """ + If the line defines a data block, return this name. + Return ``None`` otherwise. + """ + if line.startswith("data_"): + return line[5:] + else: + return None + + +def _parse_category_name(line): + """ + If the line defines a category, return this name. + Return ``None`` otherwise. + """ + if line[0] != "_": + return None + else: + return line[1 : line.find(".")] + + +def _is_loop_start(line): + """ + Return whether the line starts a looped category. + """ + return line.startswith("loop_") + + +def _to_single(lines): + r""" + Convert multiline values into singleline values + (in terms of 'lines' list elements). + Linebreaks are preserved as ``'\n'`` characters within a list element. + The initial ``';'`` character is also preserved, while the final ``';'`` character + is removed. 
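+    For example (an illustrative case, not taken from a real file), the lines
+    ``[";value line 1", "value line 2", ";"]`` are collapsed into the single
+    element ``";value line 1\nvalue line 2"``.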
+ """ + processed_lines = [] + in_multi_line = False + mutli_line_value = [] + for line in lines: + # Multiline value are enclosed by ';' at the start of the beginning and end line + if line[0] == ";": + if not in_multi_line: + # Start of multiline value + in_multi_line = True + mutli_line_value.append(line) + else: + # End of multiline value + in_multi_line = False + # The current line contains only the end character ';' + # Hence this line is not added to the processed lines + processed_lines.append("\n".join(mutli_line_value)) + mutli_line_value = [] + else: + if in_multi_line: + mutli_line_value.append(line) + else: + processed_lines.append(line) + return processed_lines + + +def _escape(value): + """ + Escape special characters in a value to make it compatible with CIF. + """ + if "\n" in value: + # A value with linebreaks must be represented as multiline value + return _multiline(value) + elif "'" in value and '"' in value: + # If both quote types are present, you cannot use them for escaping + return _multiline(value) + elif len(value) == 0: + return "''" + elif value[0] == "_": + return "'" + value + "'" + elif "'" in value: + return '"' + value + '"' + elif '"' in value: + return "'" + value + "'" + elif " " in value: + return "'" + value + "'" + elif "\t" in value: + return "'" + value + "'" + else: + return value + + +def _multiline(value): + """ + Convert a string that may contain linebreaks into CIF-compatible + multiline string. + """ + return "\n;" + value + "\n;\n" + + +def _split_one_line(line): + """ + Split a line into its fields. + Supporting embedded quotes (' or "), like `'a dog's life'` to `a dog's life` + """ + # Special case of multiline value, where the line starts with ';' + if line[0] == ";": + return [line[1:]] + + # Define the patterns for different types of fields + single_quote_pattern = r"('(?:'(?! )|[^'])*')(?:\s|$)" + double_quote_pattern = r'("(?:"(?! )|[^"])*")(?:\s|$)' + unquoted_pattern = r"([^\s]+)" + + # Combine the patterns using alternation + combined_pattern = ( + f"{single_quote_pattern}|{double_quote_pattern}|{unquoted_pattern}" + ) + + # Find all matches + matches = re.findall(combined_pattern, line) + + # Extract non-empty groups from the matches + fields = [] + for match in matches: + field = next(group for group in match if group) + if field[0] == field[-1] == "'" or field[0] == field[-1] == '"': + field = field[1:-1] + fields.append(field) + return fields + + +def _arrayfy(data): + if not isinstance(data, (Sequence, np.ndarray)) or isinstance(data, str): + data = [data] + elif len(data) == 0: + raise ValueError("Array must contain at least one element") + return np.asarray(data) diff --git a/src/bio_datasets/structure/pdbx/component.py b/src/bio_datasets/structure/pdbx/component.py new file mode 100644 index 0000000..b48e8ce --- /dev/null +++ b/src/bio_datasets/structure/pdbx/component.py @@ -0,0 +1,245 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. + +""" +This module contains internally abstract classes for representing parts +of CIF/BinaryCIF files, such as categories and columns. 
+""" + +__name__ = "biotite.structure.io.pdbx" +__author__ = "Patrick Kunzmann" +__all__ = ["MaskValue"] + +from abc import ABCMeta, abstractmethod +from collections.abc import MutableMapping +from enum import IntEnum +from biotite.file import DeserializationError, SerializationError + + +class MaskValue(IntEnum): + """ + This enum type represents the possible values of a mask array. + + - `PRESENT` : A value is present. + - `INAPPLICABLE` : For this row no value is applicable or + inappropriate (``.`` in *CIF*). + In some cases it may also refer to a default value for the + respective column. + - `MISSING` : For this row the value is missing or unknown + (``?`` in *CIF*). + """ + + PRESENT = 0 + INAPPLICABLE = 1 + MISSING = 2 + + +class _Component(metaclass=ABCMeta): + """ + Base class for all components in a CIF/BinaryCIF file. + """ + + @staticmethod + def subcomponent_class(): + """ + Get the class of the components that are stored in this component. + + Returns + ------- + subcomponent_class : type + The class of the subcomponent. + If this component already represents the lowest level, i.e. + it does not contain subcomponents, ``None`` is + returned. + """ + return None + + @staticmethod + def supercomponent_class(): + """ + Get the class of the component that contains this component. + + Returns + ------- + supercomponent_class : type + The class of the supercomponent. + If this component present already the highest level, i.e. + it is not contained in another component, ``None`` is + returned. + """ + return None + + @staticmethod + @abstractmethod + def deserialize(content): + """ + Create this component by deserializing the given content. + + Parameters + ---------- + content : str or dict + The content to be deserialized. + The type of this parameter depends on the file format. + In case of *CIF* files, this is the text of the lines + that represent this component. + In case of *BinaryCIF* files, this is a dictionary + parsed from the *MessagePack* data. + """ + raise NotImplementedError() + + @abstractmethod + def serialize(self): + """ + Convert this component into a Python object that can be written + to a file. + + Returns + ------- + content : str or dict + The content to be serialized. + The type of this return value depends on the file format. + In case of *CIF* files, this is the text of the lines + that represent this component. + In case of *BinaryCIF* files, this is a dictionary + that can be encoded into *MessagePack*. + """ + raise NotImplementedError() + + def __str__(self): + return str(self.serialize()) + + +class _HierarchicalContainer(_Component, MutableMapping, metaclass=ABCMeta): + """ + A container for hierarchical data in BinaryCIF files. + For example, the file contains multiple blocks, each block contains + multiple categories and each category contains multiple columns. + + It uses lazy deserialization: + A component is only deserialized from the serialized data, if it + is accessed. + The deserialized component is then cached in the container. + """ + + def __init__(self, elements=None): + if elements is None: + elements = {} + for element in elements.values(): + if not isinstance(element, (dict, self.subcomponent_class())): + raise TypeError( + f"Expected '{self.subcomponent_class().__name__}', " + f"but got '{type(element).__name__}'" + ) + self._elements = elements + + @staticmethod + def _deserialize_elements(content, take_key_from): + """ + Lazily deserialize the elements of this container. 
+ + Parameters + ---------- + content : dict + The serialized content describing the elements for this + container. + take_key_from : str + The key in each element of `content`, whose value is used as + the key for the respective element. + + Returns + ------- + elements : dict + The elements that should be stored in this container. + This return value can be given to the constructor. + """ + elements = {} + for serialized_element in content: + key = serialized_element[take_key_from] + # Lazy deserialization + # -> keep serialized for now and deserialize later if needed + elements[key] = serialized_element + return elements + + def _serialize_elements(self, store_key_in=None): + """ + Serialize the elements that are stored in this container. + + Each element that is still serialized (due to lazy + deserialization), is kept as it is. + + Parameters + ---------- + store_key_in: str, optional + If given, the key of each element is stored as value in the + serialized element. + This is basically the reverse operation of `take_key_from` in + :meth:`_deserialize_elements()`. + """ + serialized_elements = [] + for key, element in self._elements.items(): + if isinstance(element, self.subcomponent_class()): + try: + serialized_element = element.serialize() + except Exception: + raise SerializationError(f"Failed to serialize element '{key}'") + else: + # Element is already stored in serialized form + serialized_element = element + if store_key_in is not None: + serialized_element[store_key_in] = key + serialized_elements.append(serialized_element) + return serialized_elements + + def __getitem__(self, key): + element = self._elements[key] + if not isinstance(element, self.subcomponent_class()): + # Element is stored in serialized form + # -> must be deserialized first + try: + element = self.subcomponent_class().deserialize(element) + except Exception: + raise DeserializationError(f"Failed to deserialize element '{key}'") + # Update container with deserialized object + self._elements[key] = element + return element + + def __setitem__(self, key, element): + if isinstance(element, self.subcomponent_class()): + pass + elif isinstance(element, _HierarchicalContainer): + # A common mistake may be to use the wrong container type + raise TypeError( + f"Expected '{self.subcomponent_class().__name__}', " + f"but got '{type(element).__name__}'" + ) + else: + try: + element = self.subcomponent_class().deserialize(element) + except Exception: + raise DeserializationError("Failed to deserialize given value") + self._elements[key] = element + + def __delitem__(self, key): + del self._elements[key] + + # Implement `__contains__()` explicitly, + # because the mixin method unnecessarily deserializes the value, if available + def __contains__(self, key): + return key in self._elements + + def __iter__(self): + return iter(self._elements) + + def __len__(self): + return len(self._elements) + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if set(self.keys()) != set(other.keys()): + return False + for key in self.keys(): + if self[key] != other[key]: + return False + return True diff --git a/src/bio_datasets/structure/pdbx/compress.py b/src/bio_datasets/structure/pdbx/compress.py new file mode 100644 index 0000000..9d981ef --- /dev/null +++ b/src/bio_datasets/structure/pdbx/compress.py @@ -0,0 +1,321 @@ +__all__ = ["compress"] +__name__ = "biotite.structure.io.pdbx" +__author__ = "Patrick Kunzmann" + +import itertools +import msgpack +import numpy as np +import 
bio_datasets.structure.pdbx.bcif as bcif +from bio_datasets.structure.pdbx.bcif import _encode_numpy as encode_numpy +from bio_datasets.structure.pdbx.encoding import ( + ByteArrayEncoding, + DeltaEncoding, + FixedPointEncoding, + IntegerPackingEncoding, + RunLengthEncoding, + StringArrayEncoding, +) + + +def compress(data, float_tolerance=1e-6): + """ + Try to reduce the size of a *BinaryCIF* file (or block, category, etc.) by testing + different data encodings for each data array and selecting the one, which results in + the smallest size. + + Parameters + ---------- + data : BinaryCIFFile or BinaryCIFBlock or BinaryCIFCategory or BinaryCIFColumn or BinaryCIFData + The data to compress. + + Returns + ------- + compressed_file : BinaryCIFFile or BinaryCIFBlock or BinaryCIFCategory or BinaryCIFColumn or BinaryCIFData + The compressed data with the same type as the input data. + If no improved compression is found for a :class:`BinaryCIFData` array, + the input data is kept. + Hence, the return value is no deep copy of the input data. + float_tolerance : float, optional + The relative error that is accepted when compressing floating point numbers. + + Examples + -------- + + >>> from io import BytesIO + >>> pdbx_file = BinaryCIFFile() + >>> set_structure(pdbx_file, atom_array_stack) + >>> # Write uncompressed file + >>> uncompressed_file = BytesIO() + >>> pdbx_file.write(uncompressed_file) + >>> _ = uncompressed_file.seek(0) + >>> print(f"{len(uncompressed_file.read()) // 1000} KB") + 927 KB + >>> # Write compressed file + >>> pdbx_file = compress(pdbx_file) + >>> compressed_file = BytesIO() + >>> pdbx_file.write(compressed_file) + >>> _ = compressed_file.seek(0) + >>> print(f"{len(compressed_file.read()) // 1000} KB") + 111 KB + """ + match type(data): + case bcif.BinaryCIFFile: + return _compress_file(data, float_tolerance) + case bcif.BinaryCIFBlock: + return _compress_block(data, float_tolerance) + case bcif.BinaryCIFCategory: + return _compress_category(data, float_tolerance) + case bcif.BinaryCIFColumn: + return _compress_column(data, float_tolerance) + case bcif.BinaryCIFData: + return _compress_data(data, float_tolerance) + case _: + raise TypeError(f"Unsupported type {type(data).__name__}") + + +def _compress_file(bcif_file, float_tolerance): + compressed_file = bcif.BinaryCIFFile() + for block_name, bcif_block in bcif_file.items(): + compressed_block = _compress_block(bcif_block, float_tolerance) + compressed_file[block_name] = compressed_block + return compressed_file + + +def _compress_block(bcif_block, float_tolerance): + compressed_block = bcif.BinaryCIFBlock() + for category_name, bcif_category in bcif_block.items(): + compressed_category = _compress_category(bcif_category, float_tolerance) + compressed_block[category_name] = compressed_category + return compressed_block + + +def _compress_category(bcif_category, float_tolerance): + compressed_category = bcif.BinaryCIFCategory() + for column_name, bcif_column in bcif_category.items(): + compressed_column = _compress_column(bcif_column, float_tolerance) + compressed_category[column_name] = compressed_column + return compressed_category + + +def _compress_column(bcif_column, float_tolerance): + data = _compress_data(bcif_column.data, float_tolerance) + if bcif_column.mask is not None: + mask = _compress_data(bcif_column.mask, float_tolerance) + else: + mask = None + return bcif.BinaryCIFColumn(data, mask) + + +def _compress_data(bcif_data, float_tolerance): + array = bcif_data.array + if len(array) == 1: + # No need to compress 
a single value -> Use default uncompressed encoding + return bcif.BinaryCIFData(array) + + if np.issubdtype(array.dtype, np.str_): + # Leave encoding empty for now, as it is explicitly set later + encoding = StringArrayEncoding(data_encoding=[], offset_encoding=[]) + # Run encode to initialize the data and offset arrays + indices = encoding.encode(array) + offsets = np.cumsum([0] + [len(s) for s in encoding.strings]) + encoding.data_encoding, _ = _find_best_integer_compression(indices) + encoding.offset_encoding, _ = _find_best_integer_compression(offsets) + return bcif.BinaryCIFData(array, [encoding]) + + elif np.issubdtype(array.dtype, np.floating): + to_integer_encoding = FixedPointEncoding( + 10 ** _get_decimal_places(array, float_tolerance) + ) + integer_array = to_integer_encoding.encode(array) + best_encoding, size_compressed = _find_best_integer_compression(integer_array) + if size_compressed < _data_size_in_file(bcif.BinaryCIFData(array)): + return bcif.BinaryCIFData(array, [to_integer_encoding] + best_encoding) + else: + # The float array is smaller -> encode it directly as bytes + return bcif.BinaryCIFData(array, [ByteArrayEncoding()]) + + elif np.issubdtype(array.dtype, np.integer): + array = _to_smallest_integer_type(array) + encodings, _ = _find_best_integer_compression(array) + return bcif.BinaryCIFData(array, encodings) + + else: + raise TypeError(f"Unsupported data type {array.dtype}") + + +def _find_best_integer_compression(array): + """ + Try different data encodings on an integer array and return the one that results in + the smallest size. + """ + best_encoding_sequence = None + smallest_size = np.inf + + for use_delta in [False, True]: + if use_delta: + encoding = DeltaEncoding() + array_after_delta = encoding.encode(array) + encodings_after_delta = [encoding] + else: + encodings_after_delta = [] + array_after_delta = array + for use_run_length in [False, True]: + # Use encoded data from previous step to save time + if use_run_length: + encoding = RunLengthEncoding() + array_after_rle = encoding.encode(array_after_delta) + encodings_after_rle = encodings_after_delta + [encoding] + else: + encodings_after_rle = encodings_after_delta + array_after_rle = array_after_delta + for packed_byte_count in [None, 1, 2]: + if packed_byte_count is not None: + # Quickly check this heuristic + # to avoid computing an exploding packed data array + if ( + _estimate_packed_length(array_after_rle, packed_byte_count) + >= array_after_rle.nbytes + ): + # Packing would not reduce the size + continue + encoding = IntegerPackingEncoding(packed_byte_count) + array_after_packing = encoding.encode(array_after_rle) + encodings_after_packing = encodings_after_rle + [encoding] + else: + encodings_after_packing = encodings_after_rle + array_after_packing = array_after_rle + encoding = ByteArrayEncoding() + encoded_array = encoding.encode(array_after_packing) + encodings = encodings_after_packing + [encoding] + # Pack data directly instead of using the BinaryCIFData class + # to avoid the unnecessary re-encoding of the array, + # as it is already available in 'encoded_array' + serialized_encoding = [enc.serialize() for enc in encodings] + serialized_data = { + "data": encoded_array, + "encoding": serialized_encoding, + } + size = _data_size_in_file(serialized_data) + if size < smallest_size: + best_encoding_sequence = encodings + smallest_size = size + return best_encoding_sequence, smallest_size + + +def _estimate_packed_length(array, packed_byte_count): + """ + Estimate the length of an integer 
array after packing it with a given number of + bytes. + + Parameters + ---------- + array : numpy.ndarray + The array to pack. + packed_byte_count : int + The number of bytes used for packing. + + Returns + ------- + length : int + The estimated length of the packed array. + """ + # Use int64 to avoid integer overflow in the following line + max_val_per_element = np.int64(2 ** (8 * packed_byte_count)) + n_bytes_per_element = packed_byte_count * (np.abs(array // max_val_per_element) + 1) + return np.sum(n_bytes_per_element, dtype=np.int64) + + +def _to_smallest_integer_type(array): + """ + Convert an integer array to the smallest possible integer type, that is still able + to represent all values in the array. + + Parameters + ---------- + array : numpy.ndarray + The array to convert. + + Returns + ------- + array : numpy.ndarray + The converted array. + """ + if array.min() >= 0: + for dtype in [np.uint8, np.uint16, np.uint32, np.uint64]: + if np.all(array <= np.iinfo(dtype).max): + return array.astype(dtype) + for dtype in [np.int8, np.int16, np.int32, np.int64]: + if np.all(array >= np.iinfo(dtype).min) and np.all( + array <= np.iinfo(dtype).max + ): + return array.astype(dtype) + raise ValueError("Array is out of bounds for all integer types") + + +def _data_size_in_file(data): + """ + Get the size of the data, it would have when written into a *BinaryCIF* file. + + Parameters + ---------- + data : BinaryCIFData or dict + The data array whose size is measured. + Can be either a :class:`BinaryCIFData` object or already serialized data. + + Returns + ------- + size : int + The size of the data array in the file in bytes. + """ + if isinstance(data, bcif.BinaryCIFData): + data = data.serialize() + bytes_in_file = msgpack.packb(data, use_bin_type=True, default=encode_numpy) + return len(bytes_in_file) + + +def _get_decimal_places(array, tol): + """ + Get the number of decimal places in a floating point array. + + Parameters + ---------- + array : numpy.ndarray + The array to analyze. + tol : float, optional + The relative tolerance allowed when the values are cut off after the returned + number of decimal places. + + Returns + ------- + decimals : int + The number of decimal places. + """ + # Decimals of NaN or infinite values do not make sense + # and 0 would give NaN when rounding on decimals + array = array[np.isfinite(array) & (array != 0)] + for decimals in itertools.count(start=-_order_magnitude(array)): + error = np.abs(np.round(array, decimals) - array) + if np.all(error < tol * np.abs(array)): + return decimals + + +def _order_magnitude(array): + """ + Get the order of magnitude of floating point values. + + Parameters + ---------- + array : ndarray, dtype=float + The value to analyze. + + Returns + ------- + magnitude : int + The order of magnitude, i.e. the maximum exponent a number in the array would + have in scientific notation, if only one digit is left of the decimal point. + """ + array = array[array != 0] + if len(array) == 0: + # No non-zero values -> define order of magnitude as 0 + return 0 + return int(np.max(np.floor(np.log10(np.abs(array)))).item()) diff --git a/src/bio_datasets/structure/pdbx/convert.py b/src/bio_datasets/structure/pdbx/convert.py new file mode 100644 index 0000000..1b388f3 --- /dev/null +++ b/src/bio_datasets/structure/pdbx/convert.py @@ -0,0 +1,1779 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. 
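+
+# A minimal usage sketch (added as orientation, not part of the original
+# Biotite module). It assumes a local file "example.cif" and combines the
+# high-level functions defined below with the CIFFile class from this package:
+#
+#     from bio_datasets.structure.pdbx.cif import CIFFile
+#     from bio_datasets.structure.pdbx.convert import get_structure, set_structure
+#
+#     cif_file = CIFFile.read("example.cif")
+#     atoms = get_structure(cif_file, model=1, include_bonds=True)
+#     out_file = CIFFile()
+#     set_structure(out_file, atoms)   # fills the 'atom_site' category
+#     out_file.write("roundtrip.cif")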
+ +__name__ = "biotite.structure.io.pdbx" +__author__ = "Fabrice Allain, Patrick Kunzmann" +__all__ = [ + "get_sequence", + "get_model_count", + "get_structure", + "set_structure", + "get_component", + "set_component", + "list_assemblies", + "get_assembly", +] +import itertools +import warnings +import numpy as np +from enum import IntEnum +from biotite.file import InvalidFileError +from biotite.sequence.seqtypes import NucleotideSequence, ProteinSequence +from biotite.structure.atoms import AtomArray, AtomArrayStack, repeat +from biotite.structure.bonds import BondList, connect_via_residue_names +from biotite.structure.box import unitcell_from_vectors, vectors_from_unitcell +from biotite.structure.error import BadStructureError +from biotite.structure.filter import _canonical_aa_list as canonical_aa_list +from biotite.structure.filter import ( + _canonical_nucleotide_list as canonical_nucleotide_list, +) +from biotite.structure.filter import ( + filter_first_altloc, + filter_highest_occupancy_altloc, +) +from .bcif import ( + BinaryCIFBlock, + BinaryCIFColumn, + BinaryCIFFile, +) +from bio_datasets.structure.pdbx.cif import CIFBlock, CIFFile +from bio_datasets.structure.pdbx.component import MaskValue +from bio_datasets.structure.pdbx.encoding import StringArrayEncoding +from biotite.structure.residues import ( + get_residue_count, + get_residue_positions, + get_residue_starts_for, +) +from biotite.structure.util import matrix_rotate + + +class BondType(IntEnum): + """ + This enum type represents the type of a chemical bond. + + - `ANY` - Used if the actual type is unknown + - `SINGLE` - Single bond + - `DOUBLE` - Double bond + - `TRIPLE` - Triple bond + - `QUADRUPLE` - A quadruple bond + - `AROMATIC_SINGLE` - Aromatic bond with a single formal bond + - `AROMATIC_DOUBLE` - Aromatic bond with a double formal bond + - `AROMATIC_TRIPLE` - Aromatic bond with a triple formal bond + - `COORDINATION` - Coordination complex involving a metal atom + """ + ANY = 0 + SINGLE = 1 + DOUBLE = 2 + TRIPLE = 3 + QUADRUPLE = 4 + AROMATIC_SINGLE = 5 + AROMATIC_DOUBLE = 6 + AROMATIC_TRIPLE = 7 + COORDINATION = 8 + + +# Bond types in `struct_conn` category that refer to covalent bonds +PDBX_BOND_TYPE_ID_TO_TYPE = { + # Although a covalent bond, could in theory have a higher bond order, + # practically inter-residue bonds are always single + "covale": BondType.SINGLE, + "covale_base": BondType.SINGLE, + "covale_phosphate": BondType.SINGLE, + "covale_sugar": BondType.SINGLE, + "disulf": BondType.SINGLE, + "modres": BondType.SINGLE, + "modres_link": BondType.SINGLE, + "metalc": BondType.COORDINATION, +} +PDBX_BOND_TYPE_TO_TYPE_ID = { + BondType.ANY: "covale", + BondType.SINGLE: "covale", + BondType.DOUBLE: "covale", + BondType.TRIPLE: "covale", + BondType.QUADRUPLE: "covale", + BondType.AROMATIC_SINGLE: "covale", + BondType.AROMATIC_DOUBLE: "covale", + BondType.AROMATIC_TRIPLE: "covale", + BondType.COORDINATION: "metalc", +} +PDBX_BOND_TYPE_TO_ORDER = { + BondType.SINGLE: "sing", + BondType.DOUBLE: "doub", + BondType.TRIPLE: "trip", + BondType.QUADRUPLE: "quad", + BondType.AROMATIC_SINGLE: "sing", + BondType.AROMATIC_DOUBLE: "doub", + BondType.AROMATIC_TRIPLE: "trip", + # These are masked later, it is merely added here to avoid a KeyError + BondType.ANY: "", + BondType.COORDINATION: "", +} +# Map 'chem_comp_bond' bond orders and aromaticity to 'BondType'... 
+COMP_BOND_ORDER_TO_TYPE = { + ("SING", "N"): BondType.SINGLE, + ("DOUB", "N"): BondType.DOUBLE, + ("TRIP", "N"): BondType.TRIPLE, + ("QUAD", "N"): BondType.QUADRUPLE, + ("SING", "Y"): BondType.AROMATIC_SINGLE, + ("DOUB", "Y"): BondType.AROMATIC_DOUBLE, + ("TRIP", "Y"): BondType.AROMATIC_TRIPLE, +} +# ...and vice versa +COMP_BOND_TYPE_TO_ORDER = { + bond_type: order for order, bond_type in COMP_BOND_ORDER_TO_TYPE.items() +} +CANONICAL_RESIDUE_LIST = canonical_aa_list + canonical_nucleotide_list + +_proteinseq_type_list = ["polypeptide(D)", "polypeptide(L)"] +_nucleotideseq_type_list = [ + "polydeoxyribonucleotide", + "polyribonucleotide", + "polydeoxyribonucleotide/polyribonucleotide hybrid", +] +_other_type_list = [ + "cyclic-pseudo-peptide", + "other", + "peptide nucleic acid", + "polysaccharide(D)", + "polysaccharide(L)", +] + + +def _filter(category, index): + """ + Reduce the ``atom_site`` category to the values for the given + model. + """ + Category = type(category) + Column = Category.subcomponent_class() + Data = Column.subcomponent_class() + + return Category( + { + key: Column( + Data(column.data.array[index]), + (Data(column.mask.array[index]) if column.mask is not None else None), + ) + for key, column in category.items() + } + ) + + +def get_sequence(pdbx_file, data_block=None): + """ + Get the protein and nucleotide sequences from the + ``entity_poly.pdbx_seq_one_letter_code_can`` entry. + + Supported polymer types (``_entity_poly.type``) are: + ``'polypeptide(D)'``, ``'polypeptide(L)'``, + ``'polydeoxyribonucleotide'``, ``'polyribonucleotide'`` and + ``'polydeoxyribonucleotide/polyribonucleotide hybrid'``. + Uracil is converted to Thymine. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. + data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. + + Returns + ------- + sequence_dict : Dictionary of Sequences + Dictionary keys are derived from ``entity_poly.pdbx_strand_id`` + (often equivalent to chain_id and atom_site.auth_asym_id + in most cases). Dictionary values are sequences. + + Notes + ----- + The ``entity_poly.pdbx_seq_one_letter_code_can`` field contains the initial + complete sequence. If the structure represents a truncated or spliced + version of this initial sequence, it will include only a subset of the + initial sequence. Use biotite.structure.get_residues to retrieve only + the residues that are represented in the structure. + """ + + block = _get_block(pdbx_file, data_block) + poly_category = block["entity_poly"] + + seq_string = poly_category["pdbx_seq_one_letter_code_can"].as_array(str) + seq_type = poly_category["type"].as_array(str) + + sequences = [ + _convert_string_to_sequence(string, stype) + for string, stype in zip(seq_string, seq_type) + ] + + strand_ids = poly_category["pdbx_strand_id"].as_array(str) + strand_ids = [strand_id.split(",") for strand_id in strand_ids] + + sequence_dict = { + strand_id: sequence + for sequence, strand_ids in zip(sequences, strand_ids) + for strand_id in strand_ids + if sequence is not None + } + + return sequence_dict + + +def get_model_count(pdbx_file, data_block=None): + """ + Get the number of models contained in a file. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. 
+ data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. + + Returns + ------- + model_count : int + The number of models. + """ + block = _get_block(pdbx_file, data_block) + return len( + _get_model_starts(block["atom_site"]["pdbx_PDB_model_num"].as_array(np.int32)) + ) + + +def get_structure( + pdbx_file, + model=None, + data_block=None, + altloc="first", + extra_fields=None, + use_author_fields=True, + include_bonds=False, +): + """ + Create an :class:`AtomArray` or :class:`AtomArrayStack` from the + ``atom_site`` category in a file. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. + model : int, optional + If this parameter is given, the function will return an + :class:`AtomArray` from the atoms corresponding to the given + model number (starting at 1). + Negative values are used to index models starting from the last + model insted of the first model. + If this parameter is omitted, an :class:`AtomArrayStack` + containing all models will be returned, even if the structure + contains only one model. + data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. + altloc : {'first', 'occupancy', 'all'} + This parameter defines how *altloc* IDs are handled: + - ``'first'`` - Use atoms that have the first *altloc* ID + appearing in a residue. + - ``'occupancy'`` - Use atoms that have the *altloc* ID + with the highest occupancy for a residue. + - ``'all'`` - Use all atoms. + Note that this leads to duplicate atoms. + When this option is chosen, the ``altloc_id`` annotation + array is added to the returned structure. + extra_fields : list of str, optional + The strings in the list are entry names, that are + additionally added as annotation arrays. + The annotation category name will be the same as the PDBx + subcategory name. + The array type is always `str`. + An exception are the special field identifiers: + ``'atom_id'``, ``'b_factor'``, ``'occupancy'`` and ``'charge'``. + These will convert the fitting subcategory into an + annotation array with reasonable type. + use_author_fields : bool, optional + Some fields can be read from two alternative sources, + for example both, ``label_seq_id`` and ``auth_seq_id`` describe + the ID of the residue. + While, the ``label_xxx`` fields can be used as official pointers + to other categories in the file, the ``auth_xxx`` + fields are set by the author(s) of the structure and are + consistent with the corresponding values in PDB files. + If `use_author_fields` is true, the annotation arrays will be + read from the ``auth_xxx`` fields (if applicable), + otherwise from the the ``label_xxx`` fields. + If the requested field is not available, the respective other + field is taken as fallback. + include_bonds : bool, optional + If set to true, a :class:`BondList` will be created for the + resulting :class:`AtomArray` containing the bond information + from the file. + Inter-residue bonds, will be read from the ``struct_conn`` + category. + Intra-residue bonds will be read from the ``chem_comp_bond``, if + available, otherwise they will be derived from the Chemical + Component Dictionary. 
+ + Returns + ------- + array : AtomArray or AtomArrayStack + The return type depends on the `model` parameter. + + Examples + -------- + + >>> import os.path + >>> file = CIFFile.read(os.path.join(path_to_structures, "1l2y.cif")) + >>> arr = get_structure(file, model=1) + >>> print(len(arr)) + 304 + + """ + block = _get_block(pdbx_file, data_block) + + extra_fields = set() if extra_fields is None else set(extra_fields) + + atom_site = block.get("atom_site") + if atom_site is None: + raise InvalidFileError("Missing 'atom_site' category in file") + + models = atom_site["pdbx_PDB_model_num"].as_array(np.int32) + model_starts = _get_model_starts(models) + model_count = len(model_starts) + atom_count = len(models) + + if model is None: + # For a stack, the annotations are derived from the first model + model_atom_site = _filter_model(atom_site, model_starts, 1) + # Any field of the category would work here to get the length + model_length = model_atom_site.row_count + atoms = AtomArrayStack(model_count, model_length) + + # Check if each model has the same amount of atoms + # If not, raise exception + if model_length * model_count != atom_count: + raise InvalidFileError( + "The models in the file have unequal " + "amount of atoms, give an explicit model " + "instead" + ) + + atoms.coord[:, :, 0] = ( + atom_site["Cartn_x"] + .as_array(np.float32) + .reshape((model_count, model_length)) + ) + atoms.coord[:, :, 1] = ( + atom_site["Cartn_y"] + .as_array(np.float32) + .reshape((model_count, model_length)) + ) + atoms.coord[:, :, 2] = ( + atom_site["Cartn_z"] + .as_array(np.float32) + .reshape((model_count, model_length)) + ) + + box = _get_box(block) + if box is not None: + # Duplicate same box for each model + atoms.box = np.repeat(box[np.newaxis, ...], model_count, axis=0) + + else: + if model == 0: + raise ValueError("The model index must not be 0") + # Negative models mean model indexing starting from last model + model = model_count + model + 1 if model < 0 else model + if model > model_count: + raise ValueError( + f"The file has {model_count} models, " + f"the given model {model} does not exist" + ) + + model_atom_site = _filter_model(atom_site, model_starts, model) + # Any field of the category would work here to get the length + model_length = model_atom_site.row_count + atoms = AtomArray(model_length) + + atoms.coord[:, 0] = model_atom_site["Cartn_x"].as_array(np.float32) + atoms.coord[:, 1] = model_atom_site["Cartn_y"].as_array(np.float32) + atoms.coord[:, 2] = model_atom_site["Cartn_z"].as_array(np.float32) + + atoms.box = _get_box(block) + + # The below part is the same for both, AtomArray and AtomArrayStack + _fill_annotations(atoms, model_atom_site, extra_fields, use_author_fields) + if include_bonds: + if "chem_comp_bond" in block: + try: + custom_bond_dict = _parse_intra_residue_bonds(block["chem_comp_bond"]) + except KeyError: + warnings.warn( + "The 'chem_comp_bond' category has missing columns, " + "falling back to using Chemical Component Dictionary", + UserWarning, + ) + custom_bond_dict = None + bonds = connect_via_residue_names(atoms, custom_bond_dict=custom_bond_dict) + else: + bonds = connect_via_residue_names(atoms) + if "struct_conn" in block: + bonds = bonds.merge( + _parse_inter_residue_bonds(model_atom_site, block["struct_conn"]) + ) + atoms.bonds = bonds + atoms = _filter_altloc(atoms, model_atom_site, altloc) + + return atoms + + +def _get_block(pdbx_component, block_name): + if not isinstance(pdbx_component, (CIFBlock, BinaryCIFBlock)): + # Determine block + if 
block_name is None: + return pdbx_component.block + else: + return pdbx_component[block_name] + else: + return pdbx_component + + +def _get_or_fallback(category, key, fallback_key): + """ + Return column related to key in category if it exists, + otherwise try to get the column related to fallback key. + """ + if key not in category: + warnings.warn( + f"Attribute '{key}' not found within 'atom_site' category. " + f"The fallback attribute '{fallback_key}' will be used instead", + UserWarning, + ) + try: + return category[fallback_key] + except KeyError as key_exc: + raise InvalidFileError( + f"Fallback attribute '{fallback_key}' not found within " + "'atom_site' category" + ) from key_exc + return category[key] + + +def _fill_annotations(array, atom_site, extra_fields, use_author_fields): + """Fill atom_site annotations in atom array or atom array stack. + + Parameters + ---------- + array : AtomArray or AtomArrayStack + Atom array or stack which will be annotated. + atom_site : CIFCategory or BinaryCIFCategory + ``atom_site`` category with values for one model. + extra_fields : list of str + Entry names, that are additionally added as annotation arrays. + use_author_fields : bool + Define if alternate fields prefixed with ``auth_`` should be used + instead of ``label_``. + """ + + prefix, alt_prefix = ("auth", "label") if use_author_fields else ("label", "auth") + + array.set_annotation( + "chain_id", + _get_or_fallback( + atom_site, f"{prefix}_asym_id", f"{alt_prefix}_asym_id" + ).as_array(str), + ) + array.set_annotation( + "res_id", + _get_or_fallback( + atom_site, f"{prefix}_seq_id", f"{alt_prefix}_seq_id" + ).as_array(int, -1), + ) + array.set_annotation("ins_code", atom_site["pdbx_PDB_ins_code"].as_array(str, "")) + array.set_annotation( + "res_name", + _get_or_fallback( + atom_site, f"{prefix}_comp_id", f"{alt_prefix}_comp_id" + ).as_array(str), + ) + array.set_annotation("hetero", atom_site["group_PDB"].as_array(str) == "HETATM") + array.set_annotation( + "atom_name", + _get_or_fallback( + atom_site, f"{prefix}_atom_id", f"{alt_prefix}_atom_id" + ).as_array(str), + ) + array.set_annotation("element", atom_site["type_symbol"].as_array(str)) + + if "atom_id" in extra_fields: + if "id" in atom_site: + array.set_annotation("atom_id", atom_site["id"].as_array(int)) + else: + warnings.warn( + "Missing 'id' in 'atom_site' category. 'atom_id' generated automatically.", + UserWarning, + ) + array.set_annotation("atom_id", np.arange(array.array_length())) + extra_fields.remove("atom_id") + if "b_factor" in extra_fields: + if "B_iso_or_equiv" in atom_site: + array.set_annotation( + "b_factor", atom_site["B_iso_or_equiv"].as_array(float) + ) + else: + warnings.warn( + "Missing 'B_iso_or_equiv' in 'atom_site' category. 'b_factor' will be set to `nan`.", + UserWarning, + ) + array.set_annotation("b_factor", np.full(array.array_length(), np.nan)) + extra_fields.remove("b_factor") + if "occupancy" in extra_fields: + if "occupancy" in atom_site: + array.set_annotation("occupancy", atom_site["occupancy"].as_array(float)) + else: + warnings.warn( + "Missing 'occupancy' in 'atom_site' category. 
'occupancy' will be assumed to be 1.0", + UserWarning, + ) + array.set_annotation( + "occupancy", np.ones(array.array_length(), dtype=float) + ) + extra_fields.remove("occupancy") + if "charge" in extra_fields: + if "pdbx_formal_charge" in atom_site: + array.set_annotation( + "charge", + atom_site["pdbx_formal_charge"].as_array( + int, 0 + ), # masked values are set to 0 + ) + else: + warnings.warn( + "Missing 'pdbx_formal_charge' in 'atom_site' category. 'charge' will be set to 0", + UserWarning, + ) + array.set_annotation("charge", np.zeros(array.array_length(), dtype=int)) + extra_fields.remove("charge") + + # Handle all remaining custom fields + for field in extra_fields: + array.set_annotation(field, atom_site[field].as_array(str)) + + +def _parse_intra_residue_bonds(chem_comp_bond): + """ + Create a :func:`connect_via_residue_names()` compatible + `custom_bond_dict` from the ``chem_comp_bond`` category. + """ + custom_bond_dict = {} + for res_name, atom_1, atom_2, order, aromatic_flag in zip( + chem_comp_bond["comp_id"].as_array(str), + chem_comp_bond["atom_id_1"].as_array(str), + chem_comp_bond["atom_id_2"].as_array(str), + chem_comp_bond["value_order"].as_array(str), + chem_comp_bond["pdbx_aromatic_flag"].as_array(str), + ): + if res_name not in custom_bond_dict: + custom_bond_dict[res_name] = {} + bond_type = COMP_BOND_ORDER_TO_TYPE.get( + (order.upper(), aromatic_flag), BondType.ANY + ) + custom_bond_dict[res_name][atom_1.item(), atom_2.item()] = bond_type + return custom_bond_dict + + +def _parse_inter_residue_bonds(atom_site, struct_conn): + """ + Create inter-residue bonds by parsing the ``struct_conn`` category. + The atom indices of each bond are found by matching the bond labels + to the ``atom_site`` category. + """ + # Identity symmetry operation + IDENTITY = "1_555" + # Columns in 'atom_site' that should be matched by 'struct_conn' + COLUMNS = [ + "label_asym_id", + "label_comp_id", + "label_seq_id", + "label_atom_id", + "label_alt_id", + "auth_asym_id", + "auth_comp_id", + "auth_seq_id", + "pdbx_PDB_ins_code", + ] + + covale_mask = np.isin( + struct_conn["conn_type_id"].as_array(str), + list(PDBX_BOND_TYPE_ID_TO_TYPE.keys()), + ) + if "ptnr1_symmetry" in struct_conn: + covale_mask &= struct_conn["ptnr1_symmetry"].as_array(str, IDENTITY) == IDENTITY + if "ptnr2_symmetry" in struct_conn: + covale_mask &= struct_conn["ptnr2_symmetry"].as_array(str, IDENTITY) == IDENTITY + + atom_indices = [None] * 2 + for i in range(2): + reference_arrays = [] + query_arrays = [] + for col_name in COLUMNS: + struct_conn_col_name = _get_struct_conn_col_name(col_name, i + 1) + if col_name not in atom_site or struct_conn_col_name not in struct_conn: + continue + # Ensure both arrays have the same dtype to allow comparison + reference = atom_site[col_name].as_array() + dtype = reference.dtype + query = struct_conn[struct_conn_col_name].as_array(dtype) + if np.issubdtype(reference.dtype, str): + # The mask value is not necessarily consistent + # between query and reference + # -> make it consistent + reference[reference == "?"] = "." + query[query == "?"] = "." + reference_arrays.append(reference) + query_arrays.append(query[covale_mask]) + # Match the combination of 'label_asym_id', 'label_comp_id', etc. 
+ # in 'atom_site' and 'struct_conn' + atom_indices[i] = _find_matches(query_arrays, reference_arrays) + atoms_indices_1 = atom_indices[0] + atoms_indices_2 = atom_indices[1] + + # Some bonds in 'struct_conn' may not be found in 'atom_site' + # This is okay, + # as 'atom_site' might already be reduced to a single model + mapping_exists_mask = (atoms_indices_1 != -1) & (atoms_indices_2 != -1) + atoms_indices_1 = atoms_indices_1[mapping_exists_mask] + atoms_indices_2 = atoms_indices_2[mapping_exists_mask] + + bond_type_id = struct_conn["conn_type_id"].as_array() + # Consecutively apply the same masks as applied to the atom indices + # Logical combination does not work here, + # as the second mask was created based on already filtered data + bond_type_id = bond_type_id[covale_mask][mapping_exists_mask] + # The type ID is always present in the dictionary, + # as it was used to filter the applicable bonds + bond_types = [PDBX_BOND_TYPE_ID_TO_TYPE[type_id] for type_id in bond_type_id] + + return BondList( + atom_site.row_count, + np.stack([atoms_indices_1, atoms_indices_2, bond_types], axis=-1), + ) + + +def _find_matches(query_arrays, reference_arrays): + """ + For each index in the `query_arrays` find the indices in the + `reference_arrays` where all query values match the reference counterpart. + If no match is found for a query, the corresponding index is -1. + """ + match_masks_for_all_columns = np.stack( + [ + query[:, np.newaxis] == reference[np.newaxis, :] + for query, reference in zip(query_arrays, reference_arrays) + ], + axis=-1, + ) + match_masks = np.all(match_masks_for_all_columns, axis=-1) + query_matches, reference_matches = np.where(match_masks) + + # Duplicate matches indicate that an atom from the query cannot + # be uniquely matched to an atom in the reference + unique_query_matches, counts = np.unique(query_matches, return_counts=True) + if np.any(counts > 1): + ambiguous_query = unique_query_matches[np.where(counts > 1)[0][0]] + raise InvalidFileError( + f"The covalent bond in the 'struct_conn' category at index " + f"{ambiguous_query} cannot be unambiguously assigned to atoms in " + f"the 'atom_site' category" + ) + + # -1 indicates that no match was found in the reference + match_indices = np.full(len(query_arrays[0]), -1, dtype=int) + match_indices[query_matches] = reference_matches + return match_indices + + +def _get_struct_conn_col_name(col_name, partner): + """ + For a column name in ``atom_site`` get the corresponding column name + in ``struct_conn``. 
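+    For example, ``label_asym_id`` for partner 1 becomes ``ptnr1_label_asym_id``,
+    ``pdbx_PDB_ins_code`` becomes ``pdbx_ptnr1_PDB_ins_code`` and
+    ``label_alt_id`` becomes ``pdbx_ptnr1_label_alt_id``.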
+ """ + if col_name == "label_alt_id": + return f"pdbx_ptnr{partner}_label_alt_id" + elif col_name.startswith("pdbx_"): + # Move 'pdbx_' to front + return f"pdbx_ptnr{partner}_{col_name[5:]}" + else: + return f"ptnr{partner}_{col_name}" + + +def _filter_altloc(array, atom_site, altloc): + altloc_ids = atom_site.get("label_alt_id") + occupancy = atom_site.get("occupancy") + + # Filter altloc IDs and return + if altloc_ids is None: + return array + elif altloc == "occupancy" and occupancy is not None: + return array[ + ..., + filter_highest_occupancy_altloc( + array, altloc_ids.as_array(str), occupancy.as_array(float) + ), + ] + # 'first' is also fallback if file has no occupancy information + elif altloc == "first": + return array[..., filter_first_altloc(array, altloc_ids.as_array(str))] + elif altloc == "all": + array.set_annotation("altloc_id", altloc_ids.as_array(str)) + return array + else: + raise ValueError(f"'{altloc}' is not a valid 'altloc' option") + + +def _get_model_starts(model_array): + """ + Get the start index for each model in the arrays of the + ``atom_site`` category. + """ + _, indices = np.unique(model_array, return_index=True) + indices.sort() + return indices + + +def _filter_model(atom_site, model_starts, model): + """ + Reduce the ``atom_site`` category to the values for the given + model. + """ + # Append exclusive stop + model_starts = np.append(model_starts, [atom_site.row_count]) + # Indexing starts at 0, but model number starts at 1 + model_index = model - 1 + index = slice(model_starts[model_index], model_starts[model_index + 1]) + return _filter(atom_site, index) + + +def _get_box(block): + cell = block.get("cell") + if cell is None: + return None + try: + len_a, len_b, len_c = [ + float(cell[length].as_item()) + for length in ["length_a", "length_b", "length_c"] + ] + alpha, beta, gamma = [ + np.deg2rad(float(cell[angle].as_item())) + for angle in ["angle_alpha", "angle_beta", "angle_gamma"] + ] + except ValueError: + # 'cell_dict' has no proper unit cell values, e.g. '?' + return None + return vectors_from_unitcell(len_a, len_b, len_c, alpha, beta, gamma) + + +def set_structure( + pdbx_file, + array, + data_block=None, + include_bonds=False, + extra_fields=[], +): + """ + Set the ``atom_site`` category with atom information from an + :class:`AtomArray` or :class:`AtomArrayStack`. + + This will save the coordinates, the mandatory annotation categories + and the optional annotation categories + ``atom_id``, ``b_factor``, ``occupancy`` and ``charge``. + If the atom array (stack) contains the annotation ``'atom_id'``, + these values will be used for atom numbering instead of continuous + numbering. + Furthermore, inter-residue bonds will be written into the + ``struct_conn`` category. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. + array : AtomArray or AtomArrayStack + The structure to be written. If a stack is given, each array in + the stack will be in a separate model. + data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. + If the file is empty, a new data block will be created. + include_bonds : bool, optional + If set to true and `array` has associated ``bonds`` , the + intra-residue bonds will be written into the ``chem_comp_bond`` + category. 
+ Inter-residue bonds will be written into the ``struct_conn`` + independent of this parameter. + extra_fields : list of str, optional + List of additional fields from the ``atom_site`` category + that should be written into the file. + Default is an empty list. + + Notes + ----- + In some cases, the written inter-residue bonds cannot be read again + due to ambiguity to which atoms the bond refers. + This is the case, when two equal residues in the same chain have + the same (or a masked) `res_id`. + + Examples + -------- + + >>> import os.path + >>> file = CIFFile() + >>> set_structure(file, atom_array) + >>> file.write(os.path.join(path_to_directory, "structure.cif")) + + """ + _check_non_empty(array) + + block = _get_or_create_block(pdbx_file, data_block) + Category = block.subcomponent_class() + Column = Category.subcomponent_class() + + # Fill PDBx columns from information + # in structures' attribute arrays as good as possible + atom_site = Category() + atom_site["group_PDB"] = np.where(array.hetero, "HETATM", "ATOM") + atom_site["type_symbol"] = np.copy(array.element) + atom_site["label_atom_id"] = np.copy(array.atom_name) + atom_site["label_alt_id"] = Column( + # AtomArrays do not store altloc atoms + np.full(array.array_length(), "."), + np.full(array.array_length(), MaskValue.INAPPLICABLE), + ) + atom_site["label_comp_id"] = np.copy(array.res_name) + atom_site["label_asym_id"] = np.copy(array.chain_id) + atom_site["label_entity_id"] = _determine_entity_id(array.chain_id) + atom_site["label_seq_id"] = np.copy(array.res_id) + atom_site["pdbx_PDB_ins_code"] = Column( + np.copy(array.ins_code), + np.where(array.ins_code == "", MaskValue.INAPPLICABLE, MaskValue.PRESENT), + ) + atom_site["auth_seq_id"] = atom_site["label_seq_id"] + atom_site["auth_comp_id"] = atom_site["label_comp_id"] + atom_site["auth_asym_id"] = atom_site["label_asym_id"] + atom_site["auth_atom_id"] = atom_site["label_atom_id"] + + annot_categories = array.get_annotation_categories() + if "atom_id" in annot_categories: + atom_site["id"] = np.copy(array.atom_id) + if "b_factor" in annot_categories: + atom_site["B_iso_or_equiv"] = np.copy(array.b_factor) + if "occupancy" in annot_categories: + atom_site["occupancy"] = np.copy(array.occupancy) + if "charge" in annot_categories: + atom_site["pdbx_formal_charge"] = Column( + np.array([f"{c:+d}" if c != 0 else "?" for c in array.charge]), + np.where(array.charge == 0, MaskValue.MISSING, MaskValue.PRESENT), + ) + + # Handle all remaining custom fields + if len(extra_fields) > 0: + # ... check to avoid clashes with standard annotations + _standard_annotations = [ + "hetero", + "element", + "atom_name", + "res_name", + "chain_id", + "res_id", + "ins_code", + "atom_id", + "b_factor", + "occupancy", + "charge", + ] + _reserved_annotation_names = list(atom_site.keys()) + _standard_annotations + + for annot in extra_fields: + if annot in _reserved_annotation_names: + raise ValueError( + f"Annotation name '{annot}' is reserved and cannot be written to as extra field. " + "Please choose another name." 
+ ) + atom_site[annot] = np.copy(array.get_annotation(annot)) + + if array.bonds is not None: + struct_conn = _set_inter_residue_bonds(array, atom_site) + if struct_conn is not None: + block["struct_conn"] = struct_conn + if include_bonds: + chem_comp_bond = _set_intra_residue_bonds(array, atom_site) + if chem_comp_bond is not None: + block["chem_comp_bond"] = chem_comp_bond + + # In case of a single model handle each coordinate + # simply like a flattened array + if isinstance(array, AtomArray) or ( + isinstance(array, AtomArrayStack) and array.stack_depth() == 1 + ): + # 'ravel' flattens coord without copy + # in case of stack with stack_depth = 1 + atom_site["Cartn_x"] = np.copy(np.ravel(array.coord[..., 0])) + atom_site["Cartn_y"] = np.copy(np.ravel(array.coord[..., 1])) + atom_site["Cartn_z"] = np.copy(np.ravel(array.coord[..., 2])) + atom_site["pdbx_PDB_model_num"] = np.ones(array.array_length(), dtype=np.int32) + # In case of multiple models repeat annotations + # and use model specific coordinates + else: + atom_site = _repeat(atom_site, array.stack_depth()) + coord = np.reshape(array.coord, (array.stack_depth() * array.array_length(), 3)) + atom_site["Cartn_x"] = np.copy(coord[:, 0]) + atom_site["Cartn_y"] = np.copy(coord[:, 1]) + atom_site["Cartn_z"] = np.copy(coord[:, 2]) + atom_site["pdbx_PDB_model_num"] = np.repeat( + np.arange(1, array.stack_depth() + 1, dtype=np.int32), + repeats=array.array_length(), + ) + if "atom_id" not in annot_categories: + # Count from 1 + atom_site["id"] = np.arange(1, len(atom_site["group_PDB"]) + 1) + block["atom_site"] = atom_site + + # Write box into file + if array.box is not None: + # PDBx files can only store one box for all models + # -> Use first box + if array.box.ndim == 3: + box = array.box[0] + else: + box = array.box + len_a, len_b, len_c, alpha, beta, gamma = unitcell_from_vectors(box) + cell = Category() + cell["length_a"] = len_a + cell["length_b"] = len_b + cell["length_c"] = len_c + cell["angle_alpha"] = np.rad2deg(alpha) + cell["angle_beta"] = np.rad2deg(beta) + cell["angle_gamma"] = np.rad2deg(gamma) + block["cell"] = cell + + +def _check_non_empty(array): + if isinstance(array, AtomArray): + if array.array_length() == 0: + raise BadStructureError("Structure must not be empty") + elif isinstance(array, AtomArrayStack): + if array.array_length() == 0 or array.stack_depth() == 0: + raise BadStructureError("Structure must not be empty") + else: + raise ValueError( + "Structure must be AtomArray or AtomArrayStack, " + f"but got {type(array).__name__}" + ) + + +def _get_or_create_block(pdbx_component, block_name): + Block = pdbx_component.subcomponent_class() + + if isinstance(pdbx_component, (CIFFile, BinaryCIFFile)): + if block_name is None: + if len(pdbx_component) > 0: + block_name = next(iter(pdbx_component.keys())) + else: + # File is empty -> invent a new block name + block_name = "structure" + + if block_name not in pdbx_component: + block = Block() + pdbx_component[block_name] = block + return pdbx_component[block_name] + else: + # Already a block + return pdbx_component + + +def _determine_entity_id(chain_id): + entity_id = np.zeros(len(chain_id), dtype=int) + # Dictionary that translates chain_id to entity_id + id_translation = {} + id = 1 + for i in range(len(chain_id)): + try: + entity_id[i] = id_translation[chain_id[i]] + except KeyError: + # chain_id is not in dictionary -> new entry + id_translation[chain_id[i]] = id + entity_id[i] = id_translation[chain_id[i]] + id += 1 + return entity_id + + +def _repeat(category, 
repetitions): + Category = type(category) + Column = Category.subcomponent_class() + Data = Column.subcomponent_class() + + category_dict = {} + for key, column in category.items(): + if isinstance(column, BinaryCIFColumn): + data_encoding = column.data.encoding + # Optimization: The repeated string array has the same + # unique values, as the original string array + # -> Use same unique values (faster due to shorter array) + if isinstance(data_encoding[0], StringArrayEncoding): + data_encoding[0].strings = np.unique(column.data.array) + data = Data(np.tile(column.data.array, repetitions), data_encoding) + else: + data = Data(np.tile(column.data.array, repetitions)) + mask = ( + Data(np.tile(column.mask.array, repetitions)) + if column.mask is not None + else None + ) + category_dict[key] = Column(data, mask) + return Category(category_dict) + + +def _set_intra_residue_bonds(array, atom_site): + """ + Create the ``chem_comp_bond`` category containing the intra-residue + bonds. + ``atom_site`` is only used to infer the right :class:`Category` type + (either :class:`CIFCategory` or :class:`BinaryCIFCategory`). + """ + if (array.res_name == "").any(): + raise BadStructureError( + "Structure contains atoms with empty residue name, " + "but it is required to write intra-residue bonds" + ) + if (array.atom_name == "").any(): + raise BadStructureError( + "Structure contains atoms with empty atom name, " + "but it is required to write intra-residue bonds" + ) + + Category = type(atom_site) + Column = Category.subcomponent_class() + + bond_array = _filter_bonds(array, "intra") + if len(bond_array) == 0: + return None + value_order = np.zeros(len(bond_array), dtype="U4") + aromatic_flag = np.zeros(len(bond_array), dtype="U1") + for i, bond_type in enumerate(bond_array[:, 2]): + if bond_type == BondType.ANY: + # ANY bonds will be masked anyway, no need to set the value + continue + order, aromatic = COMP_BOND_TYPE_TO_ORDER[bond_type] + value_order[i] = order + aromatic_flag[i] = aromatic + any_mask = bond_array[:, 2] == BondType.ANY + + # Remove already existing residue and atom name combinations + # These appear when the structure contains a residue multiple times + atom_id_1 = array.atom_name[bond_array[:, 0]] + atom_id_2 = array.atom_name[bond_array[:, 1]] + # Take the residue name from the first atom index, as the residue + # name is the same for both atoms, since we have only intra bonds + comp_id = array.res_name[bond_array[:, 0]] + _, unique_indices = np.unique( + np.stack([comp_id, atom_id_1, atom_id_2], axis=-1), axis=0, return_index=True + ) + unique_indices.sort() + + chem_comp_bond = Category() + n_bonds = len(unique_indices) + chem_comp_bond["pdbx_ordinal"] = np.arange(1, n_bonds + 1, dtype=np.int32) + chem_comp_bond["comp_id"] = comp_id[unique_indices] + chem_comp_bond["atom_id_1"] = atom_id_1[unique_indices] + chem_comp_bond["atom_id_2"] = atom_id_2[unique_indices] + chem_comp_bond["value_order"] = Column( + value_order[unique_indices], + np.where(any_mask[unique_indices], MaskValue.MISSING, MaskValue.PRESENT), + ) + chem_comp_bond["pdbx_aromatic_flag"] = Column( + aromatic_flag[unique_indices], + np.where(any_mask[unique_indices], MaskValue.MISSING, MaskValue.PRESENT), + ) + # BondList does not contain stereo information + # -> all values are missing + chem_comp_bond["pdbx_stereo_config"] = Column( + np.zeros(n_bonds, dtype="U1"), + np.full(n_bonds, MaskValue.MISSING), + ) + return chem_comp_bond + + +def _set_inter_residue_bonds(array, atom_site): + """ + Create the ``struct_conn`` 
category containing the inter-residue
+    bonds.
+    The involved atoms are identified by annotations from the
+    ``atom_site`` category.
+    """
+    COLUMNS = [
+        "label_asym_id",
+        "label_comp_id",
+        "label_seq_id",
+        "label_atom_id",
+        "pdbx_PDB_ins_code",
+    ]
+
+    Category = type(atom_site)
+    Column = Category.subcomponent_class()
+
+    bond_array = _filter_bonds(array, "inter")
+    if len(bond_array) == 0:
+        return None
+
+    # Filter out 'standard' links, i.e. backbone bonds between adjacent canonical
+    # nucleotide/amino acid residues
+    bond_array = bond_array[~_filter_canonical_links(array, bond_array)]
+    if len(bond_array) == 0:
+        return None
+
+    struct_conn = Category()
+    struct_conn["id"] = np.arange(1, len(bond_array) + 1)
+    struct_conn["conn_type_id"] = [
+        PDBX_BOND_TYPE_TO_TYPE_ID[btype] for btype in bond_array[:, 2]
+    ]
+    struct_conn["pdbx_value_order"] = Column(
+        np.array([PDBX_BOND_TYPE_TO_ORDER[btype] for btype in bond_array[:, 2]]),
+        np.where(
+            np.isin(bond_array[:, 2], (BondType.ANY, BondType.COORDINATION)),
+            MaskValue.MISSING,
+            MaskValue.PRESENT,
+        ),
+    )
+    # Write the identifying annotation...
+    for col_name in COLUMNS:
+        annot = atom_site[col_name].as_array()
+        # ...for each bond partner
+        for i in range(2):
+            atom_indices = bond_array[:, i]
+            struct_conn[_get_struct_conn_col_name(col_name, i + 1)] = annot[
+                atom_indices
+            ]
+    return struct_conn
+
+
+def _filter_bonds(array, connection):
+    """
+    Get a bonds array that contains either only intra-residue or
+    only inter-residue bonds.
+    """
+    bond_array = array.bonds.as_array()
+    # To save computation time call 'get_residue_starts_for()' only once
+    # with indices of the first and second atom of each bond
+    residue_starts_1, residue_starts_2 = (
+        get_residue_starts_for(array, bond_array[:, :2].flatten()).reshape(-1, 2).T
+    )
+    if connection == "intra":
+        return bond_array[residue_starts_1 == residue_starts_2]
+    elif connection == "inter":
+        return bond_array[residue_starts_1 != residue_starts_2]
+    else:
+        raise ValueError("Invalid 'connection' option")
+
+
+def _filter_canonical_links(array, bond_array):
+    """
+    Filter out peptide bonds between adjacent canonical amino acid residues.
+    """
+    # Get the residue index for each bonded atom
+    residue_indices = get_residue_positions(array, bond_array[:, :2].flatten()).reshape(
+        -1, 2
+    )
+
+    return (
+        # Must be canonical residues
+        np.isin(array.res_name[bond_array[:, 0]], CANONICAL_RESIDUE_LIST) &
+        np.isin(array.res_name[bond_array[:, 1]], CANONICAL_RESIDUE_LIST) &
+        # Must be backbone bond
+        np.isin(array.atom_name[bond_array[:, 0]], ("C", "O3'")) &
+        np.isin(array.atom_name[bond_array[:, 1]], ("N", "P")) &
+        # Must connect adjacent residues
+        (residue_indices[:, 1] - residue_indices[:, 0] == 1)
+    )  # fmt: skip
+
+
+def get_component(pdbx_file, data_block=None, use_ideal_coord=True, res_name=None):
+    """
+    Create an :class:`AtomArray` for a chemical component from the
+    ``chem_comp_atom`` and, if available, the ``chem_comp_bond``
+    category in a file.
+
+    Parameters
+    ----------
+    pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock
+        The file object.
+    data_block : str, optional
+        The name of the data block.
+        Default is the first (and most times only) data block of the
+        file.
+        If the data block object is passed directly to `pdbx_file`,
+        this parameter is ignored.
+    use_ideal_coord : bool, optional
+        If true, the *ideal* coordinates are read from the file
+        (``pdbx_model_Cartn_<x/y/z>_ideal`` fields), typically
+        originating from computations.
+ If set to false, alternative coordinates are read + (``model_Cartn__`` fields). + res_name : str + In rare cases the categories may contain rows for multiple + components. + In this case, the component with the given residue name is + read. + By default, all rows would be read in this case. + + Returns + ------- + array : AtomArray + The parsed chemical component. + + Examples + -------- + + >>> import os.path + >>> file = CIFFile.read( + ... os.path.join(path_to_structures, "molecules", "TYR.cif") + ... ) + >>> comp = get_component(file) + >>> print(comp) + HET 0 TYR N N 1.320 0.952 1.428 + HET 0 TYR CA C -0.018 0.429 1.734 + HET 0 TYR C C -0.103 0.094 3.201 + HET 0 TYR O O 0.886 -0.254 3.799 + HET 0 TYR CB C -0.274 -0.831 0.907 + HET 0 TYR CG C -0.189 -0.496 -0.559 + HET 0 TYR CD1 C 1.022 -0.589 -1.219 + HET 0 TYR CD2 C -1.324 -0.102 -1.244 + HET 0 TYR CE1 C 1.103 -0.282 -2.563 + HET 0 TYR CE2 C -1.247 0.210 -2.587 + HET 0 TYR CZ C -0.032 0.118 -3.252 + HET 0 TYR OH O 0.044 0.420 -4.574 + HET 0 TYR OXT O -1.279 0.184 3.842 + HET 0 TYR H H 1.977 0.225 1.669 + HET 0 TYR H2 H 1.365 1.063 0.426 + HET 0 TYR HA H -0.767 1.183 1.489 + HET 0 TYR HB2 H 0.473 -1.585 1.152 + HET 0 TYR HB3 H -1.268 -1.219 1.134 + HET 0 TYR HD1 H 1.905 -0.902 -0.683 + HET 0 TYR HD2 H -2.269 -0.031 -0.727 + HET 0 TYR HE1 H 2.049 -0.354 -3.078 + HET 0 TYR HE2 H -2.132 0.523 -3.121 + HET 0 TYR HH H -0.123 -0.399 -5.059 + HET 0 TYR HXT H -1.333 -0.030 4.784 + """ + block = _get_block(pdbx_file, data_block) + + try: + atom_category = block["chem_comp_atom"] + except KeyError: + raise InvalidFileError("Missing 'chem_comp_atom' category in file") + if res_name is not None: + atom_category = _filter( + atom_category, atom_category["comp_id"].as_array() == res_name + ) + if atom_category.row_count == 0: + raise KeyError( + f"No rows with residue name '{res_name}' found in " + f"'chem_comp_atom' category" + ) + + array = AtomArray(atom_category.row_count) + + array.set_annotation("hetero", np.full(len(atom_category["comp_id"]), True)) + array.set_annotation("res_name", atom_category["comp_id"].as_array(str)) + array.set_annotation("atom_name", atom_category["atom_id"].as_array(str)) + array.set_annotation("element", atom_category["type_symbol"].as_array(str)) + array.set_annotation("charge", atom_category["charge"].as_array(int, 0)) + + coord_fields = [f"pdbx_model_Cartn_{dim}_ideal" for dim in ("x", "y", "z")] + alt_coord_fields = [f"model_Cartn_{dim}" for dim in ("x", "y", "z")] + if not use_ideal_coord: + # Swap with the fallback option + coord_fields, alt_coord_fields = alt_coord_fields, coord_fields + try: + array.coord = _parse_component_coordinates( + [atom_category[field] for field in coord_fields] + ) + except Exception as err: + if isinstance(err, KeyError): + key = err.args[0] + warnings.warn( + f"Attribute '{key}' not found within 'chem_comp_atom' category. " + f"The fallback coordinates will be used instead", + UserWarning, + ) + elif isinstance(err, ValueError): + warnings.warn( + "The coordinates are missing for some atoms. " + "The fallback coordinates will be used instead", + UserWarning, + ) + else: + raise + array.coord = _parse_component_coordinates( + [atom_category[field] for field in alt_coord_fields] + ) + + try: + bond_category = block["chem_comp_bond"] + if res_name is not None: + bond_category = _filter( + bond_category, bond_category["comp_id"].as_array() == res_name + ) + except KeyError: + warnings.warn( + "Category 'chem_comp_bond' not found. 
" "No bonds will be parsed", + UserWarning, + ) + else: + bonds = BondList(array.array_length()) + for atom1, atom2, order, aromatic_flag in zip( + bond_category["atom_id_1"].as_array(str), + bond_category["atom_id_2"].as_array(str), + bond_category["value_order"].as_array(str), + bond_category["pdbx_aromatic_flag"].as_array(str), + ): + atom_i = np.where(array.atom_name == atom1)[0][0] + atom_j = np.where(array.atom_name == atom2)[0][0] + bond_type = COMP_BOND_ORDER_TO_TYPE[order, aromatic_flag] + bonds.add_bond(atom_i, atom_j, bond_type) + array.bonds = bonds + + return array + + +def _parse_component_coordinates(coord_columns): + coord = np.zeros((len(coord_columns[0]), 3), dtype=np.float32) + for i, column in enumerate(coord_columns): + if column.mask is not None and column.mask.array.any(): + raise ValueError( + "Missing coordinates for some atoms", + ) + coord[:, i] = column.as_array(np.float32) + return coord + + +def set_component(pdbx_file, array, data_block=None): + """ + Set the ``chem_comp_atom`` and, if bonds are available, + ``chem_comp_bond`` category with atom information from an + :class:`AtomArray`. + + This will save the coordinates, the mandatory annotation categories + and the optional ``charge`` category as well as an associated + :class:`BondList`, if available. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. + array : AtomArray + The chemical component to be written. + Must contain only a single residue. + data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the file is empty, a new data will be created. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. 
+ """ + _check_non_empty(array) + + block = _get_or_create_block(pdbx_file, data_block) + Category = block.subcomponent_class() + + if get_residue_count(array) > 1: + raise BadStructureError("The input atom array must comprise only one residue") + res_name = array.res_name[0] + + annot_categories = array.get_annotation_categories() + if "charge" in annot_categories: + charge = array.charge.astype("U2") + else: + charge = np.full(array.array_length(), "?", dtype="U2") + + atom_cat = Category() + atom_cat["comp_id"] = np.full(array.array_length(), res_name) + atom_cat["atom_id"] = np.copy(array.atom_name) + atom_cat["alt_atom_id"] = atom_cat["atom_id"] + atom_cat["type_symbol"] = np.copy(array.element) + atom_cat["charge"] = charge + atom_cat["model_Cartn_x"] = np.copy(array.coord[:, 0]) + atom_cat["model_Cartn_y"] = np.copy(array.coord[:, 1]) + atom_cat["model_Cartn_z"] = np.copy(array.coord[:, 2]) + atom_cat["pdbx_model_Cartn_x_ideal"] = atom_cat["model_Cartn_x"] + atom_cat["pdbx_model_Cartn_y_ideal"] = atom_cat["model_Cartn_y"] + atom_cat["pdbx_model_Cartn_z_ideal"] = atom_cat["model_Cartn_z"] + atom_cat["pdbx_component_atom_id"] = atom_cat["atom_id"] + atom_cat["pdbx_component_comp_id"] = atom_cat["comp_id"] + atom_cat["pdbx_ordinal"] = np.arange(1, array.array_length() + 1).astype(str) + block["chem_comp_atom"] = atom_cat + + if array.bonds is not None and array.bonds.get_bond_count() > 0: + bond_array = array.bonds.as_array() + order_flags = [] + aromatic_flags = [] + for bond_type in bond_array[:, 2]: + order_flag, aromatic_flag = COMP_BOND_TYPE_TO_ORDER[bond_type] + order_flags.append(order_flag) + aromatic_flags.append(aromatic_flag) + + bond_cat = Category() + bond_cat["comp_id"] = np.full(len(bond_array), res_name) + bond_cat["atom_id_1"] = array.atom_name[bond_array[:, 0]] + bond_cat["atom_id_2"] = array.atom_name[bond_array[:, 1]] + bond_cat["value_order"] = np.array(order_flags) + bond_cat["pdbx_aromatic_flag"] = np.array(aromatic_flags) + bond_cat["pdbx_ordinal"] = np.arange(1, len(bond_array) + 1).astype(str) + block["chem_comp_bond"] = bond_cat + + +def list_assemblies(pdbx_file, data_block=None): + """ + List the biological assemblies that are available for the structure + in the given file. + + This function receives the data from the ``pdbx_struct_assembly`` + category in the file. + Consequently, this category must be present in the file. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. + data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. + + Returns + ------- + assemblies : dict of str -> str + A dictionary that maps an assembly ID to a description of the + corresponding assembly. + + Examples + -------- + >>> import os.path + >>> file = CIFFile.read(os.path.join(path_to_structures, "1f2n.cif")) + >>> assembly_ids = list_assemblies(file) + >>> for key, val in assembly_ids.items(): + ... 
print(f"'{key}' : '{val}'") + '1' : 'complete icosahedral assembly' + '2' : 'icosahedral asymmetric unit' + '3' : 'icosahedral pentamer' + '4' : 'icosahedral 23 hexamer' + '5' : 'icosahedral asymmetric unit, std point frame' + '6' : 'crystal asymmetric unit, crystal frame' + """ + block = _get_block(pdbx_file, data_block) + + try: + assembly_category = block["pdbx_struct_assembly"] + except KeyError: + raise InvalidFileError("File has no 'pdbx_struct_assembly' category") + return { + id: details + for id, details in zip( + assembly_category["id"].as_array(str), + assembly_category["details"].as_array(str), + ) + } + + +def get_assembly( + pdbx_file, + assembly_id=None, + model=None, + data_block=None, + altloc="first", + extra_fields=None, + use_author_fields=True, + include_bonds=False, + include_sym_id=False, +): + """ + Build the given biological assembly. + + This function receives the data from the + ``pdbx_struct_assembly_gen``, ``pdbx_struct_oper_list`` and + ``atom_site`` categories in the file. + Consequently, these categories must be present in the file. + + Parameters + ---------- + pdbx_file : CIFFile or CIFBlock or BinaryCIFFile or BinaryCIFBlock + The file object. + assembly_id : str + The assembly to build. + Available assembly IDs can be obtained via + :func:`list_assemblies()`. + model : int, optional + If this parameter is given, the function will return an + :class:`AtomArray` from the atoms corresponding to the given + model number (starting at 1). + Negative values are used to index models starting from the last + model insted of the first model. + If this parameter is omitted, an :class:`AtomArrayStack` + containing all models will be returned, even if the structure + contains only one model. + data_block : str, optional + The name of the data block. + Default is the first (and most times only) data block of the + file. + If the data block object is passed directly to `pdbx_file`, + this parameter is ignored. + altloc : {'first', 'occupancy', 'all'} + This parameter defines how *altloc* IDs are handled: + - ``'first'`` - Use atoms that have the first *altloc* ID + appearing in a residue. + - ``'occupancy'`` - Use atoms that have the *altloc* ID + with the highest occupancy for a residue. + - ``'all'`` - Use all atoms. + Note that this leads to duplicate atoms. + When this option is chosen, the ``altloc_id`` annotation + array is added to the returned structure. + extra_fields : list of str, optional + The strings in the list are entry names, that are + additionally added as annotation arrays. + The annotation category name will be the same as the PDBx + subcategory name. + The array type is always `str`. + An exception are the special field identifiers: + ``'atom_id'``, ``'b_factor'``, ``'occupancy'`` and ``'charge'``. + These will convert the fitting subcategory into an + annotation array with reasonable type. + use_author_fields : bool, optional + Some fields can be read from two alternative sources, + for example both, ``label_seq_id`` and ``auth_seq_id`` describe + the ID of the residue. + While, the ``label_xxx`` fields can be used as official pointers + to other categories in the file, the ``auth_xxx`` + fields are set by the author(s) of the structure and are + consistent with the corresponding values in PDB files. + If `use_author_fields` is true, the annotation arrays will be + read from the ``auth_xxx`` fields (if applicable), + otherwise from the the ``label_xxx`` fields. 
+ include_bonds : bool, optional + If set to true, a :class:`BondList` will be created for the + resulting :class:`AtomArray` containing the bond information + from the file. + Bonds, whose order could not be determined from the + *Chemical Component Dictionary* + (e.g. especially inter-residue bonds), + have :attr:`BondType.ANY`, since the PDB format itself does + not support bond orders. + include_sym_id : bool, optional + If set to true, the ``sym_id`` annotation array is added to the + returned structure. This array identifies the set of symmetry operations + that was applied to the corresponding chain in the asymmetric unit. + The sym_id annotation corresponds to ids in the ``pdbx_struct_oper_list`` + category, separated by `-`. + + Returns + ------- + assembly : AtomArray or AtomArrayStack + The assembly. The return type depends on the `model` parameter. + + Examples + -------- + + >>> import os.path + >>> file = CIFFile.read(os.path.join(path_to_structures, "1f2n.cif")) + >>> assembly = get_assembly(file, model=1) + """ + block = _get_block(pdbx_file, data_block) + + try: + assembly_gen_category = block["pdbx_struct_assembly_gen"] + except KeyError: + raise InvalidFileError("File has no 'pdbx_struct_assembly_gen' category") + + try: + struct_oper_category = block["pdbx_struct_oper_list"] + except KeyError: + raise InvalidFileError("File has no 'pdbx_struct_oper_list' category") + + assembly_ids = assembly_gen_category["assembly_id"].as_array(str) + if assembly_id is None: + assembly_id = assembly_ids[0] + elif assembly_id not in assembly_ids: + raise KeyError(f"File has no Assembly ID '{assembly_id}'") + + ### Calculate all possible transformations + transformations = _get_transformations(struct_oper_category) + + ### Get structure according to additional parameters + # Include 'label_asym_id' as annotation array + # for correct asym ID filtering + extra_fields = [] if extra_fields is None else extra_fields + if "label_asym_id" in extra_fields: + extra_fields_and_asym = extra_fields + else: + # The operations apply on asym IDs + # -> they need to be included to select the correct atoms + extra_fields_and_asym = extra_fields + ["label_asym_id"] + structure = get_structure( + pdbx_file, + model, + data_block, + altloc, + extra_fields_and_asym, + use_author_fields, + include_bonds, + ) + + ### Get transformations and apply them to the affected asym IDs + assembly = None + for id, op_expr, asym_id_expr in zip( + assembly_gen_category["assembly_id"].as_array(str), + assembly_gen_category["oper_expression"].as_array(str), + assembly_gen_category["asym_id_list"].as_array(str), + ): + # Find the operation expressions for given assembly ID + # We already asserted that the ID is actually present + if id == assembly_id: + operations = _parse_operation_expression(op_expr) + asym_ids = asym_id_expr.split(",") + # Filter affected asym IDs + sub_structure = structure[..., np.isin(structure.label_asym_id, asym_ids)] + sub_assembly = _apply_transformations( + sub_structure, transformations, operations, include_sym_id + ) + # Merge the chains with asym IDs for this operation + # with chains from other operations + if assembly is None: + assembly = sub_assembly + else: + assembly += sub_assembly + + # Remove 'label_asym_id', if it was not included in the original + # user-supplied 'extra_fields' + if "label_asym_id" not in extra_fields: + assembly.del_annotation("label_asym_id") + + return assembly + + +def _apply_transformations(structure, transformation_dict, operations, include_sym_id): + """ + Get 
subassembly by applying the given operations to the input + structure containing affected asym IDs. + """ + # Additional first dimesion for 'structure.repeat()' + assembly_coord = np.zeros((len(operations),) + structure.coord.shape) + assembly_transform_ids = [] + # Apply corresponding transformation for each copy in the assembly + for i, operation in enumerate(operations): + coord = structure.coord + # Execute for each transformation step + # in the operation expression + for op_step in operation: + rotation_matrix, translation_vector = transformation_dict[op_step] + # Rotate + coord = matrix_rotate(coord, rotation_matrix) + # Translate + coord += translation_vector + + assembly_transform_ids.append( + np.full(len(structure), "-".join(list(operation))) + ) + assembly_coord[i] = coord + + assembly = repeat(structure, assembly_coord) + if include_sym_id: + assembly.set_annotation("sym_id", np.concatenate(assembly_transform_ids)) + return assembly + + +def _get_transformations(struct_oper): + """ + Get transformation operation in terms of rotation matrix and + translation for each operation ID in ``pdbx_struct_oper_list``. + """ + transformation_dict = {} + for index, id in enumerate(struct_oper["id"].as_array(str)): + rotation_matrix = np.array( + [ + [ + struct_oper[f"matrix[{i}][{j}]"].as_array(float)[index] + for j in (1, 2, 3) + ] + for i in (1, 2, 3) + ] + ) + translation_vector = np.array( + [struct_oper[f"vector[{i}]"].as_array(float)[index] for i in (1, 2, 3)] + ) + transformation_dict[id] = (rotation_matrix, translation_vector) + return transformation_dict + + +def _parse_operation_expression(expression): + """ + Get successive operation steps (IDs) for the given + ``oper_expression``. + Form the cartesian product, if necessary. + """ + # Split groups by parentheses: + # use the opening parenthesis as delimiter + # and just remove the closing parenthesis + # example: '(X0)(1-10,21-25)' from 1a34 + expressions_per_step = expression.replace(")", "").split("(") + expressions_per_step = [e for e in expressions_per_step if len(e) > 0] + # Important: Operations are applied from right to left + expressions_per_step.reverse() + + operations = [] + for one_step_expr in expressions_per_step: + one_step_op_ids = [] + for expr in one_step_expr.split(","): + if "-" in expr: + # Range of operation IDs, they must be integers + first, last = expr.split("-") + one_step_op_ids.extend( + [str(id) for id in range(int(first), int(last) + 1)] + ) + else: + # Single operation ID + one_step_op_ids.append(expr) + operations.append(one_step_op_ids) + + # Cartesian product of operations + return list(itertools.product(*operations)) + + +def _convert_string_to_sequence(string, stype): + """ + Convert strings to `ProteinSequence` if `stype` is contained in + ``proteinseq_type_list`` or to ``NucleotideSequence`` if `stype` is + contained in ``_nucleotideseq_type_list``. 
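+    For example, a ``'polypeptide(L)'`` entity type yields a
+    :class:`ProteinSequence`, while ``'polyribonucleotide'`` yields a
+    :class:`NucleotideSequence` (with ``U`` replaced by ``T``);
+    types listed in ``_other_type_list`` yield ``None``.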
+ """ + # sequence may be stored as multiline string + string = string.replace("\n", "") + if stype in _proteinseq_type_list: + return ProteinSequence(string) + elif stype in _nucleotideseq_type_list: + string = string.replace("U", "T") + return NucleotideSequence(string) + elif stype in _other_type_list: + return None + else: + raise InvalidFileError("mmCIF _entity_poly.type unsupported" " type: " + stype) diff --git a/src/bio_datasets/structure/pdbx/encoding.pyx b/src/bio_datasets/structure/pdbx/encoding.pyx new file mode 100644 index 0000000..7ad95d2 --- /dev/null +++ b/src/bio_datasets/structure/pdbx/encoding.pyx @@ -0,0 +1,1031 @@ +# This source code is part of the Biotite package and is distributed +# under the 3-Clause BSD License. Please see 'LICENSE.rst' for further +# information. + +""" +This module contains data encodings for BinaryCIF files. +""" + +__name__ = "biotite.structure.io.pdbx" +__author__ = "Patrick Kunzmann" +__all__ = ["ByteArrayEncoding", "FixedPointEncoding", + "IntervalQuantizationEncoding", "RunLengthEncoding", + "DeltaEncoding", "IntegerPackingEncoding", "StringArrayEncoding", + "TypeCode"] + +cimport cython +cimport numpy as np + +from dataclasses import dataclass +from abc import ABCMeta, abstractmethod +from numbers import Integral +from enum import IntEnum +import re +import numpy as np +from .component import _Component +from biotite.file import InvalidFileError + +ctypedef np.int8_t int8 +ctypedef np.int16_t int16 +ctypedef np.int32_t int32 +ctypedef np.uint8_t uint8 +ctypedef np.uint16_t uint16 +ctypedef np.uint32_t uint32 +ctypedef np.float32_t float32 +ctypedef np.float64_t float64 + +ctypedef fused Integer: + uint8 + uint16 + uint32 + int8 + int16 + int32 + +# Used to create cartesian product of type combinations +# in run-length encoding +ctypedef fused OutputInteger: + uint8 + uint16 + uint32 + int8 + int16 + int32 + +ctypedef fused Float: + float32 + float64 + + +CAMEL_CASE_PATTERN = re.compile(r"(?>> data = np.arange(3) + >>> print(data) + [0 1 2] + >>> print(ByteArrayEncoding().encode(data)) + b'\x00\x00\x00\x00\x01\x00\x00\x00\x02\x00\x00\x00' + """ + type: ... = None + + def __post_init__(self): + if self.type is not None: + self.type = TypeCode.from_dtype(self.type) + + def encode(self, data): + if self.type is None: + self.type = TypeCode.from_dtype(data.dtype) + return _safe_cast(data, self.type.to_dtype()).tobytes() + + def decode(self, data): + # Data is raw bytes in this case + return np.frombuffer(data, dtype=self.type.to_dtype()) + + +@dataclass +class FixedPointEncoding(Encoding): + """ + Lossy encoding that multiplies floating point values with a given + factor and subsequently rounds them to the nearest integer. + + Parameters + ---------- + factor : float + The factor by which the data is multiplied before rounding. + src_type : dtype or TypeCode, optional + The data type of the array to be encoded. + Either a NumPy dtype or a *BinaryCIF* type code is accepted. + The dtype must be a float type. + If omitted, the data type is taken from the data the + first time :meth:`encode()` is called. + + Attributes + ---------- + factor : float + src_type : TypeCode + + Examples + -------- + + >>> data = np.array([9.87, 6.543]) + >>> print(data) + [9.870 6.543] + >>> print(FixedPointEncoding(factor=100).encode(data)) + [987 654] + """ + factor: ... + src_type: ... 
= None + + def __post_init__(self): + if self.src_type is not None: + self.src_type = TypeCode.from_dtype(self.src_type) + if self.src_type not in (TypeCode.FLOAT32, TypeCode.FLOAT64): + raise ValueError( + "Only floating point types are supported" + ) + + def encode(self, data): + # If not given in constructor, it is determined from the data + if self.src_type is None: + self.src_type = TypeCode.from_dtype(data.dtype) + if self.src_type not in (TypeCode.FLOAT32, TypeCode.FLOAT64): + raise ValueError( + "Only floating point types are supported" + ) + + # Round to avoid wrong values due to floating point inaccuracies + return np.round(data * self.factor).astype(np.int32) + + def decode(self, data): + return (data / self.factor).astype( + dtype=self.src_type.to_dtype(), copy=False + ) + + +@dataclass +class IntervalQuantizationEncoding(Encoding): + """ + Lossy encoding that sorts floating point values into bins. + Each bin is represented by an integer + + Parameters + ---------- + min, max : float + The minimum and maximum value the bins comprise. + num_steps : int + The number of bins. + src_type : dtype or TypeCode, optional + The data type of the array to be encoded. + Either a NumPy dtype or a *BinaryCIF* type code is accepted. + The dtype must be a float type. + If omitted, the data type is taken from the data the + first time :meth:`encode()` is called. + + Attributes + ---------- + min, max : float + num_steps : int + src_type : TypeCode + + Examples + -------- + + >>> data = np.linspace(11, 12, 6) + >>> print(data) + [11.0 11.2 11.4 11.6 11.8 12.0] + >>> # Use 0.5 as step size + >>> encoding = IntervalQuantizationEncoding(min=10, max=20, num_steps=21) + >>> # The encoding is lossy, as different values are mapped to the same bin + >>> encoded = encoding.encode(data) + >>> print(encoded) + [2 3 3 4 4 4] + >>> decoded = encoding.decode(encoded) + >>> print(decoded) + [11.0 11.5 11.5 12.0 12.0 12.0] + """ + min: ... + max: ... + num_steps: ... + src_type: ... = None + + def __post_init__(self): + if self.src_type is not None: + self.src_type = TypeCode.from_dtype(self.src_type) + + def encode(self, data): + # If not given in constructor, it is determined from the data + if self.src_type is None: + self.src_type = TypeCode.from_dtype(data.dtype) + + steps = np.linspace( + self.min, self.max, self.num_steps, dtype=data.dtype + ) + indices = np.searchsorted(steps, data, side="left") + return indices.astype(np.int32, copy=False) + + def decode(self, data): + output = data * (self.max - self.min) / (self.num_steps - 1) + output = output.astype(self.src_type.to_dtype(), copy=False) + output += self.min + return output + + +@dataclass +class RunLengthEncoding(Encoding): + """ + Encoding that compresses runs of equal values into pairs of + (value, run length). + + Parameters + ---------- + src_size : int, optional + The size of the array to be encoded. + If omitted, the size is determined from the data the + first time :meth:`encode()` is called. + src_type : dtype or TypeCode, optional + The data type of the array to be encoded. + Either a NumPy dtype or a *BinaryCIF* type code is accepted. + The dtype must be a integer type. + If omitted, the data type is taken from the data the + first time :meth:`encode()` is called. 
+ + Attributes + ---------- + src_size : int + src_type : TypeCode + + Examples + -------- + + >>> data = np.array([1, 1, 1, 5, 3, 3]) + >>> print(data) + [1 1 1 5 3 3] + >>> encoded = RunLengthEncoding().encode(data) + >>> print(encoded) + [1 3 5 1 3 2] + >>> # Emphasize the the pairs + >>> print(encoded.reshape(-1, 2)) + [[1 3] + [5 1] + [3 2]] + """ + src_size: ... = None + src_type: ... = None + + def __post_init__(self): + if self.src_type is not None: + self.src_type = TypeCode.from_dtype(self.src_type) + + def encode(self, data): + # If not given in constructor, it is determined from the data + if self.src_type is None: + self.src_type = TypeCode.from_dtype(data.dtype) + if self.src_size is None: + self.src_size = data.shape[0] + elif self.src_size != data.shape[0]: + raise IndexError( + "Given source size does not match actual data size" + ) + return self._encode(_safe_cast(data, self.src_type.to_dtype())) + + def decode(self, data): + return self._decode( + data, np.empty(0, dtype=self.src_type.to_dtype()) + ) + + def _encode(self, const Integer[:] data): + # Pessimistic allocation of output array + # -> Run length is 1 for every element + cdef int32[:] output = np.zeros(data.shape[0] * 2, dtype=np.int32) + cdef int i=0, j=0 + cdef int val = data[0] + cdef int run_length = 0 + cdef int curr_val + for i in range(data.shape[0]): + curr_val = data[i] + if curr_val == val: + run_length += 1 + else: + # New element -> Write element with run-length + output[j] = val + output[j+1] = run_length + j += 2 + val = curr_val + run_length = 1 + # Write last element + output[j] = val + output[j+1] = run_length + j += 2 + # Trim to correct size + return np.asarray(output)[:j] + + def _decode(self, const Integer[:] data, OutputInteger[:] output_type): + """ + `output_type` is merely a typed placeholder to allow for static + typing of output. + """ + if data.shape[0] % 2 != 0: + raise ValueError("Invalid run-length encoded data") + + cdef int length = 0 + cdef int i, j + cdef int value, repeat + + if self.src_size is None: + # Determine length of output array by summing run lengths + for i in range(1, data.shape[0], 2): + length += data[i] + else: + length = self.src_size + + cdef OutputInteger[:] output = np.zeros( + length, dtype=np.asarray(output_type).dtype + ) + # Fill output array + j = 0 + for i in range(0, data.shape[0], 2): + value = data[i] + repeat = data[i+1] + output[j : j+repeat] = value + j += repeat + return np.asarray(output) + + +@dataclass +class DeltaEncoding(Encoding): + """ + Encoding that encodes an array of integers into an array of + consecutive differences. + + Parameters + ---------- + src_type : dtype or TypeCode, optional + The data type of the array to be encoded. + Either a NumPy dtype or a *BinaryCIF* type code is accepted. + The dtype must be a integer type. + If omitted, the data type is taken from the data the + first time :meth:`encode()` is called. + origin : int, optional + The starting value from which the differences are calculated. + If omitted, the value is taken from the first array element the + first time :meth:`encode()` is called. + + Attributes + ---------- + src_type : TypeCode + origin : int + + Examples + -------- + + >>> data = np.array([1, 1, 2, 3, 5, 8]) + >>> encoding = DeltaEncoding() + >>> print(encoding.encode(data)) + [0 0 1 1 2 3] + >>> print(encoding.origin) + 1 + """ + src_type: ... = None + origin: ... 
= None + + def __post_init__(self): + if self.src_type is not None: + self.src_type = TypeCode.from_dtype(self.src_type) + + def encode(self, data): + # If not given in constructor, it is determined from the data + if self.src_type is None: + self.src_type = TypeCode.from_dtype(data.dtype) + if self.origin is None: + self.origin = data[0] + + data = data - self.origin + return np.diff(data, prepend=0).astype(np.int32, copy=False) + + def decode(self, data): + output = np.cumsum(data, dtype=self.src_type.to_dtype()) + output += self.origin + return output + + +@dataclass +class IntegerPackingEncoding(Encoding): + """ + Encoding that compresses an array of 32-bit integers into an array + of smaller sized integers. + + If a value does not fit into smaller integer type, + the integer is represented by a sum of consecutive elements + in the compressed array. + + Parameters + ---------- + byte_count : int + The number of bytes the packed integers should occupy. + Supported values are 1 and 2 for 8-bit and 16-bit integers, + respectively. + src_size : int, optional + The size of the array to be encoded. + If omitted, the size is determined from the data the + first time :meth:`encode()` is called. + is_unsigned : bool, optional + Whether the values should be packed into signed or unsigned + integers. + If omitted, first time :meth:`encode()` is called, determines whether + the values fit into unsigned integers. + + Attributes + ---------- + byte_count : int + src_size : int + is_unsigned : bool + + Examples + -------- + + >>> data = np.array([1, 2, -3, 128]) + >>> print(data) + [ 1 2 -3 128] + >>> print(IntegerPackingEncoding(byte_count=1).encode(data)) + [ 1 2 -3 127 1] + """ + byte_count: ... + src_size: ... = None + is_unsigned: ... = None + + def encode(self, data): + if self.src_size is None: + self.src_size = len(data) + elif self.src_size != len(data): + raise IndexError( + "Given source size does not match actual data size" + ) + if self.is_unsigned is None: + # Only positive values -> use unsigned integers + self.is_unsigned = data.min().item() >= 0 + + data = data.astype(np.int32, copy=False) + return self._encode( + data, np.empty(0, dtype=self._determine_packed_dtype()) + ) + + def decode(self, const Integer[:] data): + cdef int i, j + cdef int min_val, max_val + cdef int packed_val, unpacked_val + bounds = self._get_bounds(data) + min_val = bounds[0] + max_val = bounds[1] + # For signed integers, do not check lower bound (is always 0) + # -> Set lower bound to value that is never reached + if min_val == 0: + min_val = -1 + + cdef int32[:] output = np.zeros(self.src_size, dtype=np.int32) + j = 0 + unpacked_val = 0 + for i in range(data.shape[0]): + packed_val = data[i] + if packed_val == max_val or packed_val == min_val: + unpacked_val += packed_val + else: + unpacked_val += packed_val + output[j] = unpacked_val + unpacked_val = 0 + j += 1 + # Trim to correct size and return + return np.asarray(output) + + def _determine_packed_dtype(self): + if self.byte_count == 1: + if self.is_unsigned: + return np.uint8 + else: + return np.int8 + elif self.byte_count == 2: + if self.is_unsigned: + return np.uint16 + else: + return np.int16 + else: + raise ValueError("Unsupported byte count") + + @cython.cdivision(True) + def _encode(self, const Integer[:] data, OutputInteger[:] output_type): + """ + `output_type` is merely a typed placeholder to allow for static + typing of output. 
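+        Values that do not fit into the packed type are emitted as a run
+        of ``min_val``/``max_val`` sentinels followed by the remainder;
+        :meth:`decode()` sums such runs back into the original value.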
+ """ + cdef int i=0, j=0 + + packed_type = np.asarray(output_type).dtype + cdef int min_val = np.iinfo(packed_type).min + cdef int max_val = np.iinfo(packed_type).max + + # Get length of output array + # by summing up required length of each element + cdef int number + cdef long length = 0 + for i in range(data.shape[0]): + number = data[i] + if number < 0: + if min_val == 0: + raise ValueError( + "Cannot pack negative numbers into unsigned type" + ) + # The required packed length for an element is the + # number of times min_val/max_val need to be repeated + length += number // min_val + 1 + elif number > 0: + length += number // max_val + 1 + else: + # number = 0 + length += 1 + + # Fill output + cdef OutputInteger[:] output = np.zeros(length, dtype=packed_type) + cdef int remainder + j = 0 + for i in range(data.shape[0]): + remainder = data[i] + if remainder < 0: + if min_val == 0: + raise ValueError( + "Cannot pack negative numbers into unsigned type" + ) + while remainder <= min_val: + remainder -= min_val + output[j] = min_val + j += 1 + elif remainder > 0: + while remainder >= max_val: + remainder -= max_val + output[j] = max_val + j += 1 + output[j] = remainder + j += 1 + return np.asarray(output) + + @staticmethod + def _get_bounds(const Integer[:] data): + if Integer is int8: + info = np.iinfo(np.int8) + elif Integer is int16: + info = np.iinfo(np.int16) + elif Integer is int32: + info = np.iinfo(np.int32) + elif Integer is uint8: + info = np.iinfo(np.uint8) + elif Integer is uint16: + info = np.iinfo(np.uint16) + elif Integer is uint32: + info = np.iinfo(np.uint32) + else: + raise ValueError("Unsupported integer type") + return info.min, info.max + + +@dataclass +class StringArrayEncoding(Encoding): + """ + Encoding that compresses an array of strings into an array of + indices that point to the unique strings in that array. + + The unique strings themselves are stored as part of the + :class:`StringArrayEncoding` as concatenated string. + The start index of each unique string in the concatenated string + is stored in an *offset* array. + + Parameters + ---------- + strings : ndarray, optional + The unique strings that are used for encoding. + If omitted, the unique strings are determined from the data the + first time :meth:`encode()` is called. + data_encoding : list of Encoding, optional + The encodings that are applied to the index array. + If omitted, the array is directly encoded into bytes without + further compression. + offset_encoding : list of Encoding, optional + The encodings that are applied to the offset array. + If omitted, the array is directly encoded into bytes without + further compression. + + Attributes + ---------- + strings : ndarray + data_encoding : list of Encoding + offset_encoding : list of Encoding + + Examples + -------- + + >>> data = np.array(["apple", "banana", "cherry", "apple", "banana", "apple"]) + >>> print(data) + ['apple' 'banana' 'cherry' 'apple' 'banana' 'apple'] + >>> # By default the indices would directly be encoded into bytes + >>> # However, the indices should be printed here -> data_encoding=[] + >>> encoding = StringArrayEncoding(data_encoding=[]) + >>> encoded = encoding.encode(data) + >>> print(encoding.strings) + ['apple' 'banana' 'cherry'] + >>> print(encoded) + [0 1 2 0 1 0] + """ + + strings: ... = None + data_encoding: ... = None + offset_encoding: ... 
= None + + def __init__(self, strings=None, data_encoding=None, offset_encoding=None): + self.strings = strings + if data_encoding is None: + data_encoding = [ByteArrayEncoding(TypeCode.INT32)] + self.data_encoding = data_encoding + if offset_encoding is None: + offset_encoding = [ByteArrayEncoding(TypeCode.INT32)] + self.offset_encoding = offset_encoding + + @staticmethod + def deserialize(content): + data_encoding = [ + deserialize_encoding(e) for e in content["dataEncoding"] + ] + offset_encoding = [ + deserialize_encoding(e) for e in content["offsetEncoding"] + ] + cdef str concatenated_strings = content["stringData"] + cdef np.ndarray offsets = decode_stepwise( + content["offsets"], offset_encoding + ) + + strings = np.array([ + concatenated_strings[offsets[i]:offsets[i+1]] + # The final offset is the exclusive stop index + for i in range(len(offsets)-1) + ], dtype="U") + + return StringArrayEncoding(strings, data_encoding, offset_encoding) + + def serialize(self): + if self.strings is None: + raise ValueError( + "'strings' must be explicitly given or needs to be " + "determined from first encoding pass, before it is serialized" + ) + + string_data = "".join(self.strings) + offsets = np.cumsum([0] + [len(s) for s in self.strings]) + + return { + "kind": "StringArray", + "dataEncoding": [e.serialize() for e in self.data_encoding], + "stringData": string_data, + "offsets": encode_stepwise(offsets, self.offset_encoding), + "offsetEncoding": [e.serialize() for e in self.offset_encoding], + } + + def encode(self, data): + if not np.issubdtype(data.dtype, np.str_): + raise TypeError("Data must be of string type") + + if self.strings is None: + # 'unique()' already sorts the strings, but this is not necessarily + # desired, as this makes efficient encoding of the indices more difficult + # -> Bring into the original order + _, unique_indices = np.unique(data, return_index=True) + self.strings = data[np.sort(unique_indices)] + check_present = False + else: + check_present = True + + string_order = np.argsort(self.strings).astype(np.int32) + sorted_strings = self.strings[string_order] + sorted_indices = np.searchsorted(sorted_strings, data) + indices = string_order[sorted_indices] + if check_present and not np.all(self.strings[indices] == data): + raise ValueError("Data contains strings not present in 'strings'") + return encode_stepwise(indices, self.data_encoding) + + def decode(self, data): + indices = decode_stepwise(data, self.data_encoding) + return self.strings[indices] + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + if not np.array_equal(self.strings, other.strings): + return False + if self.data_encoding != other.data_encoding: + return False + if self.offset_encoding != other.offset_encoding: + return False + return True + + +_encoding_classes = { + "ByteArray": ByteArrayEncoding, + "FixedPoint": FixedPointEncoding, + "IntervalQuantization": IntervalQuantizationEncoding, + "RunLength": RunLengthEncoding, + "Delta": DeltaEncoding, + "IntegerPacking": IntegerPackingEncoding, + "StringArray": StringArrayEncoding, +} +_encoding_classes_kinds = { + "ByteArrayEncoding": "ByteArray", + "FixedPointEncoding": "FixedPoint", + "IntervalQuantizationEncoding": "IntervalQuantization", + "RunLengthEncoding": "RunLength", + "DeltaEncoding": "Delta", + "IntegerPackingEncoding": "IntegerPacking", + "StringArrayEncoding": "StringArray", +} + + +def deserialize_encoding(content): + """ + Create a :class:`Encoding` by deserializing the given *BinaryCIF* content. 
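+
+    The ``"kind"`` field of `content` selects the :class:`Encoding`
+    subclass; the remaining fields are parsed by that class'
+    :meth:`deserialize()` method.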
+ + Parameters + ---------- + content : dict + The encoding represenet as *BinaryCIF* dictionary. + + Returns + ------- + encoding : Encoding + The deserialized encoding. + """ + try: + encoding_class = _encoding_classes[content["kind"]] + except KeyError: + raise ValueError( + f"Unknown encoding kind '{content['kind']}'" + ) + return encoding_class.deserialize(content) + + +def create_uncompressed_encoding(array): + """ + Create a simple encoding for the given array that does not compress the data. + + Parameters + ---------- + array : ndarray + The array to to create the encoding for. + + Returns + ------- + encoding : list of Encoding + The encoding for the data. + """ + if np.issubdtype(array.dtype, np.str_): + return [StringArrayEncoding()] + else: + return [ByteArrayEncoding()] + + +def encode_stepwise(data, encoding): + """ + Apply a list of encodings stepwise to the given data. + + Parameters + ---------- + data : ndarray + The data to be encoded. + encoding : list of Encoding + The encodings to be applied. + + Returns + ------- + encoded_data : ndarray or bytes + The encoded data. + """ + for encoding in encoding: + data = encoding.encode(data) + return data + + +def decode_stepwise(data, encoding): + """ + Apply a list of encodings stepwise to the given data. + + Parameters + ---------- + data : ndarray or bytes + The data to be decoded. + encoding : list of Encoding + The encodings to be applied. + + Returns + ------- + decoded_data : ndarray + The decoded data. + """ + for enc in reversed(encoding): + data = enc.decode(data) + return data + + +def _camel_to_snake_case(attribute_name): + return CAMEL_CASE_PATTERN.sub("_", attribute_name).lower() + + +def _snake_to_camel_case(attribute_name): + attribute_name = "".join( + word.capitalize() for word in attribute_name.split("_") + ) + return attribute_name[0].lower() + attribute_name[1:] + + +def _safe_cast(array, dtype): + dtype = np.dtype(dtype) + if dtype == array.dtype: + return array + if np.issubdtype(dtype, np.integer): + if not np.issubdtype(array.dtype, np.integer): + raise ValueError("Cannot cast floating point to integer") + dtype_info = np.iinfo(dtype) + if np.any(array < dtype_info.min) or np.any(array > dtype_info.max): + raise ValueError("Integer values do not fit into the given dtype") + return array.astype(dtype) + + +def _get_n_decimals(value, tolerance): + MAX_DECIMALS = 10 + for n in range(MAX_DECIMALS): + if abs(value - round(value, n)) < tolerance: + return n + return MAX_DECIMALS \ No newline at end of file diff --git a/src/bio_datasets/structure/protein/protein.py b/src/bio_datasets/structure/protein/protein.py index d78b90d..f233466 100644 --- a/src/bio_datasets/structure/protein/protein.py +++ b/src/bio_datasets/structure/protein/protein.py @@ -249,20 +249,14 @@ class ProteinChain(ProteinMixin, BiomoleculeChain): def __init__( self, atoms: bs.AtomArray, - residue_dictionary: Optional[ResidueDictionary] = None, + residue_dictionary: ResidueDictionary, verbose: bool = False, backbone_only: bool = False, keep_hydrogens: bool = False, - keep_oxt: bool = False, replace_unexpected_with_unknown: bool = False, raise_error_on_unexpected: bool = False, ): - if residue_dictionary is None: - residue_dictionary = ProteinDictionary.from_preset( - "protein", keep_oxt=keep_oxt - ) - else: - assert keep_oxt == getattr(residue_dictionary, "keep_oxt", False) + assert residue_dictionary is not None super().__init__( atoms, residue_dictionary=residue_dictionary, diff --git a/src/bio_datasets/structure/residue.py 
b/src/bio_datasets/structure/residue.py index 00a2e11..be31f65 100644 --- a/src/bio_datasets/structure/residue.py +++ b/src/bio_datasets/structure/residue.py @@ -312,7 +312,7 @@ def from_ccd( chem_component_3to1 = get_component_3to1() chem_component_categories = get_component_categories(get_component_types()) frequencies = get_residue_frequencies() - res_names = np.unique(ccd_data["chem_comp_atom"]["comp_id"].as_array(str)) + res_names = list(np.unique(ccd_data["chem_comp_atom"]["comp_id"].as_array(str))) def keep_res(res_name): res_filter = frequencies.get(res_name, 0) >= minimum_pdb_entries @@ -326,6 +326,10 @@ def keep_res(res_name): res_filter = res_filter and ( keep_hydrogens or res_name not in ["H", "D", "D8U"] ) + res_filter = res_filter and ( + res_name in chem_component_3to1 + and res_name in chem_component_categories + ) return res_filter res_names = [res for res in res_names if keep_res(res)] diff --git a/tests/proteins/test_proteins.py b/tests/proteins/test_proteins.py index 609ad9c..ed9beab 100644 --- a/tests/proteins/test_proteins.py +++ b/tests/proteins/test_proteins.py @@ -4,7 +4,7 @@ from biotite.structure.residues import residue_iter from bio_datasets.structure.parsing import load_structure -from bio_datasets.structure.protein import ProteinChain +from bio_datasets.structure.protein import ProteinChain, ProteinDictionary from bio_datasets.structure.protein import constants as protein_constants expected_residue_atoms = { @@ -140,7 +140,9 @@ def test_fill_missing_atoms(pdb_atoms_top7): """ pdb_atom_array = pdb_atoms_top7[filter_amino_acids(pdb_atoms_top7)] # 1qys has missing atoms - protein = ProteinChain(pdb_atom_array) + protein = ProteinChain( + pdb_atom_array, residue_dictionary=ProteinDictionary.from_preset("protein") + ) # todo check for nans for raw_residue, filled_residue in zip( residue_iter(pdb_atom_array[filter_amino_acids(pdb_atom_array)]),