Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds script to import cctbx data to rs #264

Merged
merged 39 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
a2a2343
adds script to import cctbx data to rs
dermen Aug 10, 2024
36b5873
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 10, 2024
592007b
removes cctbx dependency
dermen Aug 11, 2024
8db20f9
merged precommit
dermen Aug 11, 2024
cd94855
merged precommit 2
dermen Aug 11, 2024
1233625
rejig to io
dermen Aug 20, 2024
fa209fe
adds mpi support
dermen Aug 20, 2024
f19bae4
addresses review
dermen Aug 25, 2024
0106d9d
removes comment
dermen Aug 25, 2024
45b8ddd
addresses review pt2
dermen Aug 26, 2024
c52719f
uses better names
dermen Aug 26, 2024
d1c4424
cleans up io __init__
dermen Aug 27, 2024
c2d84fd
adds support for more columns
dermen Aug 27, 2024
b2afcb5
adds verbose flag for read_dials_stills
dermen Aug 28, 2024
f315274
more debug statements
dermen Aug 28, 2024
45707ca
Merge remote-tracking branch 'upstream/main'
dermen Sep 3, 2024
5f850e5
unit tests for refl table reader
dermen Sep 3, 2024
b28bd46
adds back in the comma
dermen Sep 3, 2024
c71dd06
cleanup
dermen Sep 3, 2024
d30d18f
more cleanup
dermen Sep 4, 2024
a1e07bb
use tempfile, remove __main__ for production
Sep 10, 2024
37ce079
refactor, add mpi test with dummy comms
Sep 12, 2024
36c3376
Merge pull request #1 from kmdalton/stills
dermen Sep 17, 2024
53f6021
get ray_context
dermen Sep 17, 2024
493630c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
cb7570b
Update common.py
dermen Sep 17, 2024
479104d
Update common.py
dermen Sep 17, 2024
7f7dd26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 17, 2024
7fd1208
allow nan inference for float64
Sep 18, 2024
a365d34
remove dtype inference from read_dials_stills
Sep 18, 2024
cc1117b
make cell/sg optional. improve docstring
Sep 18, 2024
73e1ed4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
8e32b56
fix docstring
Sep 19, 2024
af61760
make dtype inference optional
Sep 19, 2024
bce6f80
test dtype inference toggle
Sep 19, 2024
58b862d
test mtz_dtypes flag and mtz writing
Sep 19, 2024
1fc6f7a
no need for a list of files
Sep 19, 2024
6c248e4
separate test for mtz io
Sep 19, 2024
1d0b88d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions reciprocalspaceship/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from reciprocalspaceship.io.ccp4map import write_ccp4_map
from reciprocalspaceship.io.crystfel import read_crystfel
from reciprocalspaceship.io.csv import read_csv
from reciprocalspaceship.io.dials import print_refl_info, read_dials_stills
from reciprocalspaceship.io.mtz import (
from_gemmi,
read_cif,
Expand Down
49 changes: 49 additions & 0 deletions reciprocalspaceship/io/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging
import warnings
from contextlib import contextmanager
from importlib.util import find_spec


def set_ray_loglevel(level):
    """Apply *level* to the "ray" logger and every handler attached to it."""
    ray_logger = logging.getLogger("ray")
    ray_logger.setLevel(level)
    for hand in ray_logger.handlers:
        hand.setLevel(level)


def check_for_ray():
    """Return True if the optional ``ray`` package is importable.

    When ray is absent, an ImportWarning is emitted so callers know the
    code will fall back to the serial stream file parser, and False is
    returned.
    """
    # find_spec probes importability without actually importing ray.
    if find_spec("ray") is not None:
        return True
    warnings.warn(
        "ray (https://www.ray.io/) is not available..."
        "Falling back to serial stream file parser.",
        ImportWarning,
    )
    return False


def check_for_mpi():
    """Return True if ``from mpi4py import MPI`` succeeds.

    Any failure (missing package, broken MPI runtime, etc.) emits an
    ImportWarning and returns False so callers fall back to serial mode.
    """
    try:
        from mpi4py import MPI  # noqa: F401

    except Exception as err:
        warnings.warn(
            f"Failed `from mpi4py import MPI` with {err}. Falling back to serial mode.",
            ImportWarning,
        )
        return False
    return True


@contextmanager
def ray_context(log_level="DEBUG", **ray_kwargs):
    """Start a ray runtime and guarantee it is shut down on exit.

    Parameters
    ----------
    log_level: logging level applied to the "ray" logger (and its handlers)
        before the runtime starts
    ray_kwargs: forwarded verbatim to ray.init (e.g. num_cpus, log_to_driver)

    Yields
    ------
    the imported ``ray`` module, with the runtime initialized
    """
    # Imported lazily so this module can be used when ray is not installed.
    import ray

    set_ray_loglevel(log_level)
    ray.init(**ray_kwargs)
    try:
        yield ray
    finally:
        # Always tear the runtime down, even if the with-body raised.
        ray.shutdown()
24 changes: 2 additions & 22 deletions reciprocalspaceship/io/crystfel.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import mmap
import re
from contextlib import contextmanager
from importlib.util import find_spec
from typing import Union

import gemmi
import numpy as np

from reciprocalspaceship import DataSet, concat
from reciprocalspaceship.io.common import check_for_ray, ray_context
from reciprocalspaceship.utils import angle_between, eV2Angstroms

# See Rupp Table 5-2
Expand Down Expand Up @@ -60,17 +59,6 @@
}


@contextmanager
def ray_context(**ray_kwargs):
import ray

ray.init(**ray_kwargs)
try:
yield ray
finally:
ray.shutdown()


class StreamLoader(object):
"""
An object that loads stream files into rs.DataSet objects in parallel.
Expand Down Expand Up @@ -304,15 +292,7 @@ def read_crystfel(

# Check whether ray is available
if use_ray:
if find_spec("ray") is None:
use_ray = False
import warnings

message = (
"ray (https://www.ray.io/) is not available..."
"Falling back to serial stream file parser."
)
warnings.warn(message, ImportWarning)
use_ray = check_for_ray()

with open(self.filename, "r") as f:
memfile = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
Expand Down
286 changes: 286 additions & 0 deletions reciprocalspaceship/io/dials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
import logging
import sys

import msgpack
import numpy as np
import pandas

# Module-level logger for the DIALS readers, configured once at import time.
LOGGER = logging.getLogger("rs.io.dials")
if not LOGGER.handlers:
    # Attach a stdout handler only when none exists yet, so re-imports or
    # user-configured logging do not accumulate duplicate handlers.
    LOGGER.setLevel(logging.DEBUG)
    console = logging.StreamHandler(stream=sys.stdout)
    console.setLevel(logging.DEBUG)
    LOGGER.addHandler(console)

import reciprocalspaceship as rs
from reciprocalspaceship.decorators import cellify, spacegroupify
from reciprocalspaceship.io.common import check_for_ray, set_ray_loglevel

# Mapping from the type-name strings stored in DIALS msgpack reflection
# tables to the numpy dtypes used to decode the raw column buffers.
MSGPACK_DTYPES = {
    "double": np.float64,
    "float": np.float32,
    "int": np.int32,
    "cctbx::miller::index<>": np.int32,
    "vec3<double>": np.float64,
    "std::size_t": np.intp,
}

# Columns read from a reflection table when the caller requests no extras.
DEFAULT_COLS = [
    "miller_index",
    "intensity.sum.value",
    "intensity.sum.variance",
    "xyzcal.px",
    "s1",
    "delpsical.rad",
    "id",
]


def _set_logger(verbose):
level = logging.CRITICAL
if verbose:
level = logging.DEBUG

for log_name in ("rs.io.dials", "ray"):
logger = logging.getLogger(log_name)
logger.setLevel(level)
for handler in logger.handlers:
handler.setLevel(level)


def get_msgpack_data(data, name, dtype_map=None):
    """
    Decode one column of a DIALS msgpack reflection table.

    Parameters
    ----------
    data: msgpack data dict mapping column name -> (type-name, (nrows, buffer))
    name: msgpack data key of the column to decode
    dtype_map: optional mapping of type-name strings to numpy dtypes;
        defaults to the module-level MSGPACK_DTYPES

    Returns
    -------
    dict of 1-D numpy arrays. A vector column yields one entry per component
    ("name.0", "name.1", ...); a scalar column yields a single entry keyed
    by ``name``.
    """
    import warnings

    if dtype_map is None:
        dtype_map = MSGPACK_DTYPES
    dtype, (num, buff) = data[name]
    if dtype in dtype_map:
        dtype = dtype_map[dtype]
    else:
        # Fall back to numpy's default inference, but warn: a silently
        # misdecoded buffer is very hard to track down later.
        warnings.warn(
            f"Unknown dtype string {dtype!r} for column {name!r}; "
            "decoding with numpy's default dtype."
        )
        dtype = None
    # Each row may hold several components (e.g. vec3); reshape to rows x comps.
    vals = np.frombuffer(buff, dtype).reshape((num, -1))
    data_dict = {f"{name}.{i}": col_data for i, col_data in enumerate(vals.T)}

    # remove the .0 suffix if data is a scalar type
    if len(data_dict) == 1:
        data_dict[name] = data_dict.pop(f"{name}.0")

    return data_dict


def _concat(refl_data):
    """Combine the output of _get_refl_data into one formatted rs.DataSet.

    Parameters
    ----------
    refl_data: a single rs.DataSet, or an iterable of rs.DataSet / None
        (None entries, e.g. from empty MPI ranks, are dropped)

    Returns
    -------
    rs.DataSet indexed by (H, K, L) with an integer BATCH column and every
    "*variance*" column converted to a "*sigma*" column of MTZ type Q
    """
    LOGGER.debug("Combining and formatting tables!")
    if isinstance(refl_data, rs.DataSet):
        ds = refl_data
    else:
        refl_data = [ds for ds in refl_data if ds is not None]
        ds = rs.concat(refl_data)
    # Number experiments by first appearance rather than iterating a raw set,
    # so BATCH assignment is deterministic across runs (set iteration order
    # for string identifiers depends on PYTHONHASHSEED).
    expt_ids = dict.fromkeys(ds.BATCH)
    LOGGER.debug(f"Found {len(ds)} refls from {len(expt_ids)} expts.")
    LOGGER.debug("Mapping batch column.")
    expt_id_map = {name: i for i, name in enumerate(expt_ids)}
    ds.BATCH = [expt_id_map[eid] for eid in ds.BATCH]
    rename_map = {"miller_index.0": "H", "miller_index.1": "K", "miller_index.2": "L"}
    for name in list(ds):
        if "variance" in name:
            new_name = name.replace("variance", "sigma")
            rename_map[name] = new_name
            # sqrt converts a variance to a standard deviation; "Q" is the
            # MTZ column type for standard deviations.
            ds[name] = np.sqrt(ds[name]).astype("Q")
            LOGGER.debug(
                f"Converted column {name} to MTZ-Type Q, took sqrt of the values, and renamed to {new_name}."
            )
    ds.rename(columns=rename_map, inplace=True)

    LOGGER.debug("Inferring MTZ types...")
    ds = ds.infer_mtz_dtypes().set_index(["H", "K", "L"], drop=True)
    return ds


def _get_refl_data(fname, unitcell, spacegroup, extra_cols=None):
    """
    Read one DIALS reflection file into an rs.DataSet.

    Parameters
    ----------
    fname: integrated refl file
    unitcell: gemmi.UnitCell instance
    spacegroup: gemmi.SpaceGroup instance
    extra_cols: list of additional columns to read

    Returns
    -------
    RS dataset (pandas Dataframe)

    """
    LOGGER.debug(f"Loading {fname}")
    pack = _get_refl_pack(fname)
    refl_data = pack["data"]
    expt_id_map = pack["identifiers"]

    if "miller_index" not in refl_data:
        raise IOError("refl table must have a miller_index column")

    requested = list(DEFAULT_COLS)
    if extra_cols is not None:
        requested += extra_cols

    ds_data = {}
    for col_name in requested:
        if col_name not in refl_data:
            continue
        unpacked = get_msgpack_data(refl_data, col_name)
        LOGGER.debug(f"... Read in data for {col_name}")
        # Prepend the newly decoded columns, matching the original merge order.
        ds_data = {**unpacked, **ds_data}

    if "id" in ds_data:
        # Translate per-row experiment ids into identifier strings for BATCH.
        ds_data["BATCH"] = np.array([expt_id_map[li] for li in ds_data.pop("id")])
    ds = rs.DataSet(ds_data, cell=unitcell, spacegroup=spacegroup)
    ds["PARTIAL"] = True
    return ds


def _read_dials_stills_serial(fnames, unitcell, spacegroup, extra_cols=None, **kwargs):
    """run read_dials_stills without trying to import ray"""
    # Extra keyword arguments (e.g. a leftover ``comm``) are accepted and
    # ignored so this reader can be swapped in for the parallel backends.
    tables = []
    for fname in fnames:
        tables.append(_get_refl_data(fname, unitcell, spacegroup, extra_cols))
    return tables


def _read_dials_stills_ray(fnames, unitcell, spacegroup, numjobs=10, extra_cols=None):
    """
    Read reflection files in parallel using ray.

    Parameters
    ----------
    fnames: integration files
    unitcell: gemmi.UnitCell instance
    spacegroup: gemmi.SpaceGroup instance
    numjobs: number of jobs
    extra_cols: list of additional columns to read from refl tables

    Returns
    -------
    RS dataset (pandas Dataframe)
    """
    from reciprocalspaceship.io.common import ray_context

    # Mirror this module's logging configuration inside the ray runtime.
    ctx_kwargs = {
        "log_level": LOGGER.level,
        "num_cpus": numjobs,
        "log_to_driver": LOGGER.level == logging.DEBUG,
    }
    with ray_context(**ctx_kwargs) as ray:
        # Dispatch one remote task per reflection file, then gather results.
        remote_reader = ray.remote(_get_refl_data)
        futures = [
            remote_reader.remote(fname, unitcell, spacegroup, extra_cols)
            for fname in fnames
        ]
        refl_data = ray.get(futures)
    return refl_data


@cellify
@spacegroupify
def read_dials_stills(
    fnames,
    unitcell,
    spacegroup,
    numjobs=10,
    parallel_backend=None,
    extra_cols=None,
    verbose=False,
    comm=None,
):
    """
    Read one or more DIALS stills reflection files into a single rs dataset.

    Parameters
    ----------
    fnames: filenames
    unitcell: unit cell tuple, Gemmi unit cell obj
    spacegroup: space group symbol eg P4
    numjobs: if backend==ray, specify the number of jobs (ignored if backend==mpi)
    parallel_backend: ray, mpi, or None
    extra_cols: list of additional column names to extract from the refltables.
        By default, this method will search for miller_index, id, s1, xyzcal.px,
        intensity.sum.value, intensity.sum.variance, delpsical.rad
    verbose: whether to print stdout
    comm: optionally override the communicator used by backend='mpi'

    Returns
    -------
    rs dataset (pandas Dataframe)
    """
    _set_logger(verbose)

    if parallel_backend not in ("ray", "mpi", None):
        raise NotImplementedError("parallel_backend should be ray, mpi, or none")

    kwargs = dict(
        fnames=fnames,
        unitcell=unitcell,
        spacegroup=spacegroup,
        extra_cols=extra_cols,
    )
    # Default to the serial reader; upgrade only when a backend is usable.
    reader = _read_dials_stills_serial
    if parallel_backend == "ray":
        kwargs["numjobs"] = numjobs
        from reciprocalspaceship.io.common import check_for_ray

        if check_for_ray():
            reader = _read_dials_stills_ray
    elif parallel_backend == "mpi":
        from reciprocalspaceship.io.common import check_for_mpi

        if check_for_mpi():
            from reciprocalspaceship.io.dials_mpi import read_dials_stills_mpi as reader

        # The serial fallback silently ignores ``comm`` via its **kwargs.
        kwargs["comm"] = comm
    result = reader(**kwargs)
    # With MPI, non-root ranks return None and must skip formatting.
    if result is not None:
        result = _concat(result)
    return result


def _get_refl_pack(filename):
    """Load a DIALS reflection-table msgpack file and return its payload dict.

    Parameters
    ----------
    filename: path to a reflection table file saved by DIALS

    Returns
    -------
    dict payload of the reflection table (keys such as "identifiers",
    "data", "nrows" as seen in these files)

    Raises
    ------
    IOError
        If the file does not deserialize to the expected 3-element
        dials::af::reflection_table structure.
    """
    # Use a context manager so the file handle is closed promptly instead of
    # leaking until garbage collection.
    with open(filename, "rb") as fh:
        pack = msgpack.load(fh, strict_map_key=False)
    try:
        # The table serializes as a 3-tuple whose last element is the payload.
        assert len(pack) == 3
        _, _, pack = pack
    except (TypeError, AssertionError):
        raise IOError("File does not appear to be dials::af::reflection_table")
    return pack


def print_refl_info(reflfile):
    """print contents of `fname`, a reflection table file saved with DIALS"""
    pack = _get_refl_pack(reflfile)
    if "identifiers" in pack:
        idents = pack["identifiers"]
        print(f"\nFound {len(idents)} experiment identifiers in {reflfile}:")
        for i, ident in idents.items():
            print(f"\t{i}: {ident}")
    if "data" in pack:
        data = pack["data"]
        columns = []
        col_space = 0
        for name in data:
            # Each entry is (type-name, (num_rows, buffer)); only the type
            # name is needed here, so skip unpacking the raw buffer.
            dtype = data[name][0]
            columns.append((name, dtype))
            # Track the widest cell so the table prints with aligned columns.
            col_space = max(len(dtype), len(name), col_space)
        names, dtypes = zip(*columns)
        df = pandas.DataFrame({"names": names, "dtypes": dtypes})
        print(
            "\nReflection contents:\n"
            + df.to_string(index=False, col_space=col_space + 5, justify="center")
        )

    if "nrows" in pack:
        print(f"\nNumber of reflections: {pack['nrows']} \n")
Loading