diff --git a/environment.yml b/environment.yml index ac48fff..f04f0e4 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,7 @@ name: kartograf channels: - conda-forge dependencies: + - dill - python - pip #Test diff --git a/src/kartograf/atom_mapper.py b/src/kartograf/atom_mapper.py index 57cea1e..ebc7361 100644 --- a/src/kartograf/atom_mapper.py +++ b/src/kartograf/atom_mapper.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/kartograf import copy +import dill import inspect import numpy as np from enum import Enum @@ -15,7 +16,7 @@ from scipy.optimize import linear_sum_assignment from scipy.sparse.csgraph import connected_components -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Callable, Iterable, Optional, Union from gufe import SmallMoleculeComponent from gufe import AtomMapping, AtomMapper, LigandAtomMapping @@ -45,6 +46,9 @@ class mapping_algorithm(Enum): minimal_spanning_tree = "MST" +_mapping_alg_type = Callable[[NDArray, float], dict[int, int]] + + # Helper: vector_eucledean_dist = calculate_edge_weight = lambda x, y: np.sqrt( np.sum(np.square(y - x), axis=1) @@ -56,7 +60,7 @@ class KartografAtomMapper(AtomMapper): atom_max_distance: float _map_exact_ring_matches_only: bool atom_map_hydrogens: bool - mapping_algorithm: mapping_algorithm + mapping_algorithm: _mapping_alg_type _filter_funcs: list[ Callable[[Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]] @@ -71,8 +75,7 @@ def __init__( map_exact_ring_matches_only: bool = True, additional_mapping_filter_functions: Optional[Iterable[Callable[[ Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]]]] = None, - _mapping_algorithm: mapping_algorithm = - mapping_algorithm.linear_sum_assignment, + _mapping_algorithm: str = mapping_algorithm.linear_sum_assignment, ): """ Geometry Based Atom Mapper This mapper is a homebrew, that utilises rdkit in order @@ -119,12 +122,10 @@ def __init__( if additional_mapping_filter_functions is not None: self._filter_funcs.extend(additional_mapping_filter_functions) - if _mapping_algorithm is not None and _mapping_algorithm == \ - _mapping_algorithm.linear_sum_assignment: + if _mapping_algorithm == mapping_algorithm.linear_sum_assignment: self._map_hydrogens_on_hydrogens_only = True self.mapping_algorithm = self._linearSumAlgorithm_map - elif _mapping_algorithm is not None and _mapping_algorithm == \ - _mapping_algorithm.minimal_spanning_tree: + elif _mapping_algorithm == mapping_algorithm.minimal_spanning_tree: self.mapping_algorithm = self._minimalSpanningTree_map else: raise ValueError( @@ -132,9 +133,58 @@ def __init__( f"or LSA). got key: {_mapping_algorithm}" ) - """ - Properties - """ + @classmethod + def _defaults(cls): + return {} + + def _to_dict(self) -> dict: + built_in_filters = { + filter_atoms_h_only_h_mapped, + filter_whole_rings_only, + filter_ringsize_changes, + filter_ringbreak_changes, + } + additional_filters = [ + dill.dumps(f) for f in self._filter_funcs + if f not in built_in_filters + ] + + # rather than serialise _filter_funcs, we serialise the arguments + # that lead to the correct _filter_funcs being added + # + # then reverse engineer the _mapping_algorithm argument + # this avoids serialising the function directly + map_arg = { + self._linearSumAlgorithm_map: 'LSA', + self._minimalSpanningTree_map: 'MST', + }[self.mapping_algorithm] + + return { + 'atom_max_distance': self.atom_max_distance, + 'atom_map_hydrogens': self.atom_map_hydrogens, + 'map_hydrogens_on_hydrogens_only': self._map_hydrogens_on_hydrogens_only, + 'map_exact_ring_matches_only': self._map_exact_ring_matches_only, + '_mapping_algorithm': map_arg, + 'filters': additional_filters, + } + + @classmethod + def _from_dict(cls, d: dict): + # replace _mapping_algorithm key to enum + map_arg = d.pop('_mapping_algorithm', 'LSA') + + map_alg = { + 'LSA': mapping_algorithm.linear_sum_assignment, + 'MSR': mapping_algorithm.minimal_spanning_tree, + }[map_arg] + + d['_mapping_algorithm'] = map_alg + + d['additional_mapping_filter_functions'] = [ + dill.loads(f) for f in d.pop('filters', []) + ] + + return cls(**d) @property def map_hydrogens_on_hydrogens_only(self) -> bool: @@ -166,40 +216,6 @@ def map_exact_ring_matches_only(self, s: bool): elif f in self._filter_funcs: self._filter_funcs.remove(f) - """ - Privat - Serialize - """ - - @classmethod - def _from_dict(cls, d: dict): - """Deserialize from dict representation""" - if any(k not in cls._defaults() for k in d): - keys = list(filter(lambda k: k in cls._defaults(), d.keys())) - raise ValueError(f"I don't know about all the keys here: {keys}") - return cls(**d) - - def _to_dict(self) -> dict: - d = {} - for key in self._defaults(): - if hasattr(self, key): - d[key] = getattr(self, key) - return d - - @classmethod - def _defaults(cls): - """This method should be overridden to provide the dict of defaults - appropriate for the `GufeTokenizable` subclass. - """ - sig = inspect.signature(cls.__init__) - - defaults = { - param.name: param.default - for param in sig.parameters.values() - if param.default is not inspect.Parameter.empty - } - - return defaults - """ Private - Set Operations """ @@ -209,8 +225,8 @@ def _filter_mapping_for_max_overlapping_connected_atom_set( cls, moleculeA: Chem.Mol, moleculeB: Chem.Mol, - atom_mapping: Dict[int, int], - ) -> Dict[int, int]: + atom_mapping: dict[int, int], + ) -> dict[int, int]: """ Find connected core region from raw mapping This algorithm finds the maximal overlapping connected set of two molecules and a given mapping. In order to accomplish this @@ -264,8 +280,8 @@ def _filter_mapping_for_max_overlapping_connected_atom_set( @staticmethod def _get_connected_atom_subsets( - mol: Chem.Mol, to_be_searched: List[int] - ) -> List[Set[int]]: + mol: Chem.Mol, to_be_searched: list[int] + ) -> list[set[int]]: """ find connected sets in mappings Get the connected sets of all to_be_searched atom indices in mol. Connected means the atoms in a resulting connected set are connected @@ -356,9 +372,9 @@ def _get_connected_atom_subsets( @staticmethod def _get_maximal_mapping_set_overlap( - sets_a: Iterable[Set], sets_b: Iterable[Set], - mapping: Dict[int, int] - ) -> Tuple[Set, Set]: + sets_a: Iterable[set], sets_b: Iterable[set], + mapping: dict[int, int] + ) -> tuple[set, set]: """get the largest set overlaps in the mapping of set_a and set_b. Parameters @@ -410,8 +426,8 @@ def _get_maximal_mapping_set_overlap( @staticmethod def _filter_mapping_for_set_overlap( - set_a: Set[int], set_b: Set[int], mapping: Dict[int, int] - ) -> Dict[int, int]: + set_a: set[int], set_b: set[int], mapping: dict[int, int] + ) -> dict[int, int]: """This filter reduces the mapping dict to only in the sets contained atom IDs @@ -479,8 +495,8 @@ def _get_full_distance_matrix( @staticmethod def _mask_atoms( - mol, mol_pos, map_hydrogens: bool = False, masked_atoms: List = [], - ) -> Tuple[Dict, List]: + mol, mol_pos, map_hydrogens: bool, masked_atoms: list[int], + ) -> tuple[dict, list]: """Mask atoms such they are not considered during the mapping. Parameters @@ -515,7 +531,7 @@ def _mask_atoms( def _minimalSpanningTree_map( self, distance_matrix: NDArray, max_dist: float - ) -> Dict[int, int]: + ) -> dict[int, int]: """MST Mapping This function is a numpy graph based implementation to build up an Atom Mapping purely on 3D criteria. @@ -561,7 +577,7 @@ def _minimalSpanningTree_map( @staticmethod def _linearSumAlgorithm_map( distance_matrix: NDArray, max_dist: float - ) -> Dict[int, int]: + ) -> dict[int, int]: """ LSA mapping This function is a LSA based implementation to build up an Atom Mapping purely on 3D criteria. @@ -589,7 +605,7 @@ def _linearSumAlgorithm_map( def _additional_filter_rules( self, molA: Chem.Mol, molB: Chem.Mol, mapping: dict[int, int] - ) -> Dict[int, int]: + ) -> dict[int, int]: """apply additional filter rules to the given mapping. Parameters diff --git a/src/kartograf/tests/test_atom_mapper.py b/src/kartograf/tests/test_atom_mapper.py index 26e79d8..f58dbe4 100644 --- a/src/kartograf/tests/test_atom_mapper.py +++ b/src/kartograf/tests/test_atom_mapper.py @@ -169,39 +169,43 @@ def test_stereo_mapping(stereco_chem_molecules, stereo_chem_mapping): check_mapping_vs_expected(geom_mapping, expected_mapping) -# Test Serialization -def test_to_from_dict(): - mapper = KartografAtomMapper() - d1 = mapper._to_dict() - mapper2 = KartografAtomMapper._from_dict(d1) - d2 = mapper2._to_dict() - - for key, val1 in d1.items(): - val2 = d2[key] - - if(val1 != val2): - raise ValueError("they need to be identical.") +class TestSerialisation: + def test_to_from_dict_cycle(self): + m = KartografAtomMapper() + + m_dict = m.to_dict() + + m2 = KartografAtomMapper.from_dict(m_dict) + + assert m == m2 + + @pytest.mark.parametrize('mhoho,mermo', [ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_check_filters(self, mhoho, mermo): + m = KartografAtomMapper( + map_hydrogens_on_hydrogens_only=mhoho, + map_exact_ring_matches_only=mermo, + ) - mapper2.atom_max_distance = 10 - d3 = mapper2._to_dict() + m2 = KartografAtomMapper.from_dict(m.to_dict()) - for key, val1 in d1.items(): - val2 = d3[key] + assert m._filter_funcs == m2._filter_funcs - if(val1 != val2 and key != "atom_max_distance"): - raise ValueError("they need to be identical.") - if(key == "atom_max_distance" and val1 == val2 ): - raise ValueError("they must not be identical.") + def test_custom_filters(self): + def nop_filter(a, b, c): + return c + m = KartografAtomMapper( + additional_mapping_filter_functions=[nop_filter], + ) -def test_to_from_dict_wrong(): - mapper = KartografAtomMapper() - d1 = mapper._to_dict() - d1.update({"FLEEEEE": "You FOOLS"}) + m2 = KartografAtomMapper.from_dict(m.to_dict()) - with pytest.raises(ValueError) as exc: - mapper2 = KartografAtomMapper._from_dict(d1) - assert "I don't know about all the keys here" in str(exc.value) + assert m._filter_funcs == m2._filter_funcs def test_filter_property():