Skip to content

Commit

Permalink
add pickling (#258)
Browse files Browse the repository at this point in the history
* Delete pickler.py

* add `__reduce__` to ensure pickling works for all attributes

Without `__reduce__`, the input params of `__init__` might not be pickled, e.g. `id`.

* add `save_links` method

* Update test_nplinker_local.py

* update `save_data` method to ensure consistent return

* fix error caused by NumPy 2.0

AttributeError: `np.NINF` was removed in the NumPy 2.0 release. Use `-np.inf` instead.
  • Loading branch information
CunliangGeng authored Jun 17, 2024
1 parent 6a6f170 commit bedebde
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 123 deletions.
4 changes: 4 additions & 0 deletions src/nplinker/genomics/bgc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
return hash((self.id, self.product_prediction))

def __reduce__(self) -> tuple:
"""Reduce function for pickling."""
return (self.__class__, (self.id, *self.product_prediction), self.__dict__)

def add_parent(self, gcf: GCF) -> None:
"""Add a parent GCF to the BGC.
Expand Down
4 changes: 4 additions & 0 deletions src/nplinker/genomics/gcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def __hash__(self) -> int:
"""
return hash(self.id)

def __reduce__(self) -> tuple:
"""Reduce function for pickling."""
return (self.__class__, (self.id,), self.__dict__)

@property
def bgcs(self) -> set[BGC]:
"""Get the BGC objects."""
Expand Down
4 changes: 4 additions & 0 deletions src/nplinker/metabolomics/molecular_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
return hash(self.id)

def __reduce__(self) -> tuple:
"""Reduce function for pickling."""
return (self.__class__, (self.id,), self.__dict__)

@property
def spectra(self) -> set[Spectrum]:
"""Get Spectrum objects in the molecular family."""
Expand Down
8 changes: 8 additions & 0 deletions src/nplinker/metabolomics/spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ def __eq__(self, other) -> bool:
def __hash__(self) -> int:
return hash((self.id, self.precursor_mz))

def __reduce__(self) -> tuple:
"""Reduce function for pickling."""
return (
self.__class__,
(self.id, self.mz, self.intensity, self.precursor_mz, self.rt, self.metadata),
self.__dict__,
)

@cached_property
def peaks(self) -> np.ndarray:
"""Get the peaks, a 2D array with each row containing the values of (m/z, intensity)."""
Expand Down
20 changes: 20 additions & 0 deletions src/nplinker/nplinker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations
import logging
import pickle
from os import PathLike
from pprint import pformat
from typing import Sequence
Expand Down Expand Up @@ -295,3 +296,22 @@ def lookup_mf(self, id: str) -> MolecularFamily | None:
The MolecularFamily object with the given ID, or None if no such object exists.
"""
return self._mf_dict.get(id, None)

def save_data(
self,
file: str | PathLike,
links: LinkGraph | None = None,
) -> None:
"""Pickle data to a file.
The data to be pickled is a tuple containing the BGCs, GCFs, Spectra, MolecularFamilies,
StrainCollection and links, i.e. `(bgcs, gcfs, spectra, mfs, strains, links)`. If the links
are not provided, `None` will be used.
Args:
file: The path to the pickle file to save the data to.
links: The LinkGraph object to save.
"""
data = (self.bgcs, self.gcfs, self.spectra, self.mfs, self.strains, links)
with open(file, "wb") as f:
pickle.dump(data, f)
116 changes: 0 additions & 116 deletions src/nplinker/pickler.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/nplinker/scoring/metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_links(self, *objects, **parameters):
"MetcalfScoring.metcalf_mean and metcalf_std are not set. Run MetcalfScoring.setup first."
)
# use negative infinity as the score cutoff to ensure we get all links
scores_list = self._get_links(*objects, obj_type=obj_type, score_cutoff=np.NINF)
scores_list = self._get_links(*objects, obj_type=obj_type, score_cutoff=-np.inf)
scores_list = self._calc_standardised_score(scores_list)

links = LinkGraph()
Expand Down
36 changes: 36 additions & 0 deletions tests/integration/test_nplinker_local.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import pickle
import pytest
from nplinker.genomics import GCF
from nplinker.metabolomics import MolecularFamily
from nplinker.metabolomics import Spectrum
from nplinker.nplinker import NPLinker
from . import DATA_DIR

Expand Down Expand Up @@ -70,3 +74,35 @@ def test_get_links(npl):
for _, _, scores in lg.links:
score = scores[scoring_method]
assert score.value >= 0


def test_save_data(npl):
scoring_method = "metcalf"
links = npl.get_links(npl.gcfs[:3], scoring_method)

pickle_file = os.path.join(npl.output_dir, "npl.pkl")
npl.save_data(pickle_file, links)

with open(pickle_file, "rb") as f:
bgcs, gcfs, spectra, mfs, strains, lg = pickle.load(f)

# tests from `test_load_data`
assert len(bgcs) == 390
assert len(gcfs) == 64
assert len(spectra) == 24652
assert len(mfs) == 29
assert len(strains) == 46

# tests from `test_get_links`
for obj1, obj2, scores in lg.links:
score = scores[scoring_method]
assert score.value >= 0

if isinstance(obj1, GCF):
assert obj1 in gcfs
elif isinstance(obj1, Spectrum):
assert obj1 in spectra
elif isinstance(obj1, MolecularFamily):
assert obj1 in mfs
else:
assert False
12 changes: 6 additions & 6 deletions tests/unit/scoring/test_metcalf_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_get_links_invalid_mixed_types(mc, spectra, mfs):
def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs):
"""Test `get_links` method when input is GCF objects and `standardised` is False."""
# when cutoff is negative infinity, i.e. taking all scores
lg = mc.get_links(*gcfs, cutoff=np.NINF, standardised=False)
lg = mc.get_links(*gcfs, cutoff=-np.inf, standardised=False)
assert lg[gcfs[0]][spectra[0]][mc.name].value == 12
assert lg[gcfs[1]][spectra[0]][mc.name].value == -9
assert lg[gcfs[2]][spectra[0]][mc.name].value == 11
Expand All @@ -121,7 +121,7 @@ def test_get_links_gcf_standardised_false(mc, gcfs, spectra, mfs):

def test_get_links_gcf_standardised_true(mc, gcfs):
"""Test `get_links` method when input is GCF objects and `standardised` is True."""
lg = mc.get_links(*gcfs, cutoff=np.NINF, standardised=True)
lg = mc.get_links(*gcfs, cutoff=-np.inf, standardised=True)
assert len(lg.links) == 18

lg = mc.get_links(*gcfs, cutoff=0, standardised=True)
Expand All @@ -130,7 +130,7 @@ def test_get_links_gcf_standardised_true(mc, gcfs):

def test_get_links_spec_standardised_false(mc, gcfs, spectra):
"""Test `get_links` method when input is Spectrum objects and `standardised` is False."""
lg = mc.get_links(*spectra, cutoff=np.NINF, standardised=False)
lg = mc.get_links(*spectra, cutoff=-np.inf, standardised=False)
assert lg[spectra[0]][gcfs[0]][mc.name].value == 12
assert lg[spectra[0]][gcfs[1]][mc.name].value == -9
assert lg[spectra[0]][gcfs[2]][mc.name].value == 11
Expand All @@ -143,7 +143,7 @@ def test_get_links_spec_standardised_false(mc, gcfs, spectra):

def test_get_links_spec_standardised_true(mc, gcfs, spectra):
"""Test `get_links` method when input is Spectrum objects and `standardised` is True."""
lg = mc.get_links(*spectra, cutoff=np.NINF, standardised=True)
lg = mc.get_links(*spectra, cutoff=-np.inf, standardised=True)
assert len(lg.links) == 9

lg = mc.get_links(*spectra, cutoff=0, standardised=True)
Expand All @@ -152,7 +152,7 @@ def test_get_links_spec_standardised_true(mc, gcfs, spectra):

def test_get_links_mf_standardised_false(mc, gcfs, mfs):
"""Test `get_links` method when input is MolecularFamily objects and `standardised` is False."""
lg = mc.get_links(*mfs, cutoff=np.NINF, standardised=False)
lg = mc.get_links(*mfs, cutoff=-np.inf, standardised=False)
assert lg[mfs[0]][gcfs[0]][mc.name].value == 12
assert lg[mfs[0]][gcfs[1]][mc.name].value == -9
assert lg[mfs[0]][gcfs[2]][mc.name].value == 11
Expand All @@ -165,7 +165,7 @@ def test_get_links_mf_standardised_false(mc, gcfs, mfs):

def test_get_links_mf_standardised_true(mc, gcfs, mfs):
"""Test `get_links` method when input is MolecularFamily objects and `standardised` is True."""
lg = mc.get_links(*mfs, cutoff=np.NINF, standardised=True)
lg = mc.get_links(*mfs, cutoff=-np.inf, standardised=True)
assert len(lg.links) == 9

lg = mc.get_links(*mfs, cutoff=0, standardised=True)
Expand Down

0 comments on commit bedebde

Please sign in to comment.