From f8aeac61a28f4a42404d7cbd54aba3224d1bcb85 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Tue, 23 Jan 2024 13:26:56 +0000 Subject: [PATCH 1/5] use builtins for type annotation --- src/kartograf/atom_mapper.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/kartograf/atom_mapper.py b/src/kartograf/atom_mapper.py index 2d6b195..958951e 100644 --- a/src/kartograf/atom_mapper.py +++ b/src/kartograf/atom_mapper.py @@ -15,7 +15,7 @@ from scipy.optimize import linear_sum_assignment from scipy.sparse.csgraph import connected_components -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Callable, Iterable, Optional, Union from gufe import SmallMoleculeComponent from gufe import LigandAtomMapping @@ -210,8 +210,8 @@ def _filter_mapping_for_max_overlapping_connected_atom_set( cls, moleculeA: Chem.Mol, moleculeB: Chem.Mol, - atom_mapping: Dict[int, int], - ) -> Dict[int, int]: + atom_mapping: dict[int, int], + ) -> dict[int, int]: """ Find connected core region from raw mapping This algorithm finds the maximal overlapping connected set of two molecules and a given mapping. In order to accomplish this @@ -265,8 +265,8 @@ def _filter_mapping_for_max_overlapping_connected_atom_set( @staticmethod def _get_connected_atom_subsets( - mol: Chem.Mol, to_be_searched: List[int] - ) -> List[Set[int]]: + mol: Chem.Mol, to_be_searched: list[int] + ) -> list[set[int]]: """ find connected sets in mappings Get the connected sets of all to_be_searched atom indices in mol. 
Connected means the atoms in a resulting connected set are connected @@ -357,9 +357,9 @@ def _get_connected_atom_subsets( @staticmethod def _get_maximal_mapping_set_overlap( - sets_a: Iterable[Set], sets_b: Iterable[Set], - mapping: Dict[int, int] - ) -> Tuple[Set, Set]: + sets_a: Iterable[set], sets_b: Iterable[set], + mapping: dict[int, int] + ) -> tuple[set, set]: """get the largest set overlaps in the mapping of set_a and set_b. Parameters @@ -411,8 +411,8 @@ def _get_maximal_mapping_set_overlap( @staticmethod def _filter_mapping_for_set_overlap( - set_a: Set[int], set_b: Set[int], mapping: Dict[int, int] - ) -> Dict[int, int]: + set_a: set[int], set_b: set[int], mapping: dict[int, int] + ) -> dict[int, int]: """This filter reduces the mapping dict to only in the sets contained atom IDs @@ -480,8 +480,8 @@ def _get_full_distance_matrix( @staticmethod def _mask_atoms( - mol, mol_pos, map_hydrogens: bool = False, masked_atoms: List = [], - ) -> Tuple[Dict, List]: + mol, mol_pos, map_hydrogens: bool = False, masked_atoms: list = [], + ) -> tuple[dict, list]: """Mask atoms such they are not considered during the mapping. Parameters @@ -516,7 +516,7 @@ def _mask_atoms( def _minimalSpanningTree_map( self, distance_matrix: NDArray, max_dist: float - ) -> Dict[int, int]: + ) -> dict[int, int]: """MST Mapping This function is a numpy graph based implementation to build up an Atom Mapping purely on 3D criteria. @@ -562,7 +562,7 @@ def _minimalSpanningTree_map( @staticmethod def _linearSumAlgorithm_map( distance_matrix: NDArray, max_dist: float - ) -> Dict[int, int]: + ) -> dict[int, int]: """ LSA mapping This function is a LSA based implementation to build up an Atom Mapping purely on 3D criteria. @@ -590,7 +590,7 @@ def _linearSumAlgorithm_map( def _additional_filter_rules( self, molA: Chem.Mol, molB: Chem.Mol, mapping: dict[int, int] - ) -> Dict[int, int]: + ) -> dict[int, int]: """apply additional filter rules to the given mapping. 
Parameters From 3d820b89c2768d42e0a37c74260a73c78c93d4d2 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Tue, 23 Jan 2024 13:28:42 +0000 Subject: [PATCH 2/5] clean up signature of _mask_atoms default arguments were never used, and the default list was potentially problematic --- src/kartograf/atom_mapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kartograf/atom_mapper.py b/src/kartograf/atom_mapper.py index 958951e..650f413 100644 --- a/src/kartograf/atom_mapper.py +++ b/src/kartograf/atom_mapper.py @@ -480,7 +480,7 @@ def _get_full_distance_matrix( @staticmethod def _mask_atoms( - mol, mol_pos, map_hydrogens: bool = False, masked_atoms: list = [], + mol, mol_pos, map_hydrogens: bool, masked_atoms: list[int], ) -> tuple[dict, list]: """Mask atoms such they are not considered during the mapping. From a53c66d704319ba6f9307e33fa0e035fcc150eb7 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Tue, 23 Jan 2024 13:36:39 +0000 Subject: [PATCH 3/5] fix up mapping algorithm type annotation --- src/kartograf/atom_mapper.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/kartograf/atom_mapper.py b/src/kartograf/atom_mapper.py index 650f413..10dc210 100644 --- a/src/kartograf/atom_mapper.py +++ b/src/kartograf/atom_mapper.py @@ -46,6 +46,9 @@ class mapping_algorithm(Enum): minimal_spanning_tree = "MST" +_mapping_alg_type = Callable[[NDArray, float], dict[int, int]] + + # Helper: vector_eucledean_dist = calculate_edge_weight = lambda x, y: np.sqrt( np.sum(np.square(y - x), axis=1) @@ -57,7 +60,7 @@ class KartografAtomMapper(AtomMapper): atom_max_distance: float _map_exact_ring_matches_only: bool atom_map_hydrogens: bool - mapping_algorithm: mapping_algorithm + mapping_algorithm: _mapping_alg_type _filter_funcs: list[ Callable[[Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]] @@ -72,7 +75,7 @@ def __init__( map_exact_ring_matches_only: bool = True, additional_mapping_filter_functions: 
Optional[Iterable[Callable[[ Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]]]] = None, - _mapping_algorithm: mapping_algorithm = + _mapping_algorithm: str = mapping_algorithm.linear_sum_assignment, ): """ Geometry Based Atom Mapper @@ -120,12 +123,10 @@ def __init__( if additional_mapping_filter_functions is not None: self._filter_funcs.extend(additional_mapping_filter_functions) - if _mapping_algorithm is not None and _mapping_algorithm == \ - _mapping_algorithm.linear_sum_assignment: + if _mapping_algorithm == mapping_algorithm.linear_sum_assignment: self._map_hydrogens_on_hydrogens_only = True self.mapping_algorithm = self._linearSumAlgorithm_map - elif _mapping_algorithm is not None and _mapping_algorithm == \ - _mapping_algorithm.minimal_spanning_tree: + elif _mapping_algorithm == mapping_algorithm.minimal_spanning_tree: self.mapping_algorithm = self._minimalSpanningTree_map else: raise ValueError( @@ -133,9 +134,22 @@ def __init__( f"or LSA). got key: {_mapping_algorithm}" ) - """ - Properties - """ + @classmethod + def _defaults(cls): + return {} + + def _to_dict(self) -> dict: + # currently only serialise some filter functions + + + return { + 'atom_max_distance': self.atom_max_distance, + 'atom_map_hydrogens': self.atom_map_hydrogens, + } + + @classmethod + def _from_dict(cls, d: dict): + return cls(**d) @property def map_hydrogens_on_hydrogens_only(self) -> bool: From e0c598050465b353f208fc8681667171b59034a7 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Tue, 23 Jan 2024 16:11:10 +0000 Subject: [PATCH 4/5] serialisation for KartografAtomMapper --- src/kartograf/atom_mapper.py | 60 +++++++++++++++--------------------- 1 file changed, 24 insertions(+), 36 deletions(-) diff --git a/src/kartograf/atom_mapper.py b/src/kartograf/atom_mapper.py index 10dc210..72105af 100644 --- a/src/kartograf/atom_mapper.py +++ b/src/kartograf/atom_mapper.py @@ -139,12 +139,34 @@ def _defaults(cls): return {} def _to_dict(self) -> dict: - # currently only serialise 
some filter functions - + # currently only serialise built-in filter functions + # so check that we don't have any custom filters + allowed = { + filter_atoms_h_only_h_mapped, + filter_whole_rings_only, + filter_ringsize_changes, + filter_ringbreak_changes, + } + present = set(self._filter_funcs) + if remaining := present - allowed: + raise NotImplementedError("Can't (yet) serialise arbitrary functions, " + f"got: {remaining}") + # rather than serialise _filter_funcs, we serialise the arguments + # that lead to the correct _filter_funcs being added + # + # then reverse engineer the _mapping_algorithm argument + # this avoids serialising the function directly + map_arg = { + self._linearSumAlgorithm_map: mapping_algorithm.linear_sum_assignment, + self._minimalSpanningTree_map: mapping_algorithm.minimal_spanning_tree, + }[self.mapping_algorithm] return { 'atom_max_distance': self.atom_max_distance, 'atom_map_hydrogens': self.atom_map_hydrogens, + 'map_hydrogens_on_hydrogens_only': self._map_hydrogens_on_hydrogens_only, + 'map_exact_ring_matches_only': self._map_exact_ring_matches_only, + '_mapping_algorithm': map_arg, } @classmethod @@ -181,40 +203,6 @@ def map_exact_ring_matches_only(self, s: bool): elif f in self._filter_funcs: self._filter_funcs.remove(f) - """ - Privat - Serialize - """ - - @classmethod - def _from_dict(cls, d: dict): - """Deserialize from dict representation""" - if any(k not in cls._defaults() for k in d): - keys = list(filter(lambda k: k in cls._defaults(), d.keys())) - raise ValueError(f"I don't know about all the keys here: {keys}") - return cls(**d) - - def _to_dict(self) -> dict: - d = {} - for key in self._defaults(): - if hasattr(self, key): - d[key] = getattr(self, key) - return d - - @classmethod - def _defaults(cls): - """This method should be overridden to provide the dict of defaults - appropriate for the `GufeTokenizable` subclass. 
- """ - sig = inspect.signature(cls.__init__) - - defaults = { - param.name: param.default - for param in sig.parameters.values() - if param.default is not inspect.Parameter.empty - } - - return defaults - """ Private - Set Operations """ From 4a67c2c84a3ad0f200a82df4870422fb205a5b57 Mon Sep 17 00:00:00 2001 From: richard gowers Date: Wed, 24 Jan 2024 16:55:10 +0000 Subject: [PATCH 5/5] tests for serialising karto mapper --- environment.yml | 1 + src/kartograf/atom_mapper.py | 36 ++++++++++----- src/kartograf/tests/test_atom_mapper.py | 58 +++++++++++++------------ 3 files changed, 57 insertions(+), 38 deletions(-) diff --git a/environment.yml b/environment.yml index fd63395..5642b8d 100644 --- a/environment.yml +++ b/environment.yml @@ -2,6 +2,7 @@ name: kartograf channels: - conda-forge dependencies: + - dill - python - pip #Test diff --git a/src/kartograf/atom_mapper.py b/src/kartograf/atom_mapper.py index 72105af..1c35de5 100644 --- a/src/kartograf/atom_mapper.py +++ b/src/kartograf/atom_mapper.py @@ -2,6 +2,7 @@ # For details, see https://github.com/OpenFreeEnergy/kartograf import copy +import dill import inspect import numpy as np from enum import Enum @@ -75,8 +76,7 @@ def __init__( map_exact_ring_matches_only: bool = True, additional_mapping_filter_functions: Optional[Iterable[Callable[[ Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]]]] = None, - _mapping_algorithm: str = - mapping_algorithm.linear_sum_assignment, + _mapping_algorithm: str = mapping_algorithm.linear_sum_assignment, ): """ Geometry Based Atom Mapper This mapper is a homebrew, that utilises rdkit in order @@ -139,26 +139,25 @@ def _defaults(cls): return {} def _to_dict(self) -> dict: - # currently only serialise built-in filter functions - # so check that we don't have any custom filters - allowed = { + built_in_filters = { filter_atoms_h_only_h_mapped, filter_whole_rings_only, filter_ringsize_changes, filter_ringbreak_changes, } - present = set(self._filter_funcs) - if remaining := 
present - allowed: - raise NotImplementedError("Can't (yet) serialise arbitrary functions, " - f"got: {remaining}") + additional_filters = [ + dill.dumps(f) for f in self._filter_funcs + if f not in built_in_filters + ] + # rather than serialise _filter_funcs, we serialise the arguments # that lead to the correct _filter_funcs being added # # then reverse engineer the _mapping_algorithm argument # this avoids serialising the function directly map_arg = { - self._linearSumAlgorithm_map: mapping_algorithm.linear_sum_assignment, - self._minimalSpanningTree_map: mapping_algorithm.minimal_spanning_tree, + self._linearSumAlgorithm_map: 'LSA', + self._minimalSpanningTree_map: 'MST', }[self.mapping_algorithm] return { @@ -167,10 +166,25 @@ def _to_dict(self) -> dict: 'map_hydrogens_on_hydrogens_only': self._map_hydrogens_on_hydrogens_only, 'map_exact_ring_matches_only': self._map_exact_ring_matches_only, '_mapping_algorithm': map_arg, + 'filters': additional_filters, } @classmethod def _from_dict(cls, d: dict): + # replace _mapping_algorithm key to enum + map_arg = d.pop('_mapping_algorithm', 'LSA') + + map_alg = { + 'LSA': mapping_algorithm.linear_sum_assignment, + 'MST': mapping_algorithm.minimal_spanning_tree, + }[map_arg] + + d['_mapping_algorithm'] = map_alg + + d['additional_mapping_filter_functions'] = [ + dill.loads(f) for f in d.pop('filters', []) + ] + return cls(**d) @property diff --git a/src/kartograf/tests/test_atom_mapper.py b/src/kartograf/tests/test_atom_mapper.py index 26e79d8..f58dbe4 100644 --- a/src/kartograf/tests/test_atom_mapper.py +++ b/src/kartograf/tests/test_atom_mapper.py @@ -169,39 +169,43 @@ def test_stereo_mapping(stereco_chem_molecules, stereo_chem_mapping): check_mapping_vs_expected(geom_mapping, expected_mapping) -# Test Serialization -def test_to_from_dict(): - mapper = KartografAtomMapper() - d1 = mapper._to_dict() - mapper2 = KartografAtomMapper._from_dict(d1) - d2 = mapper2._to_dict() - - for key, val1 in d1.items(): - val2 = d2[key] 
- - if(val1 != val2): - raise ValueError("they need to be identical.") +class TestSerialisation: + def test_to_from_dict_cycle(self): + m = KartografAtomMapper() + + m_dict = m.to_dict() + + m2 = KartografAtomMapper.from_dict(m_dict) + + assert m == m2 + + @pytest.mark.parametrize('mhoho,mermo', [ + (True, True), + (True, False), + (False, True), + (False, False), + ]) + def test_check_filters(self, mhoho, mermo): + m = KartografAtomMapper( + map_hydrogens_on_hydrogens_only=mhoho, + map_exact_ring_matches_only=mermo, + ) - mapper2.atom_max_distance = 10 - d3 = mapper2._to_dict() + m2 = KartografAtomMapper.from_dict(m.to_dict()) - for key, val1 in d1.items(): - val2 = d3[key] + assert m._filter_funcs == m2._filter_funcs - if(val1 != val2 and key != "atom_max_distance"): - raise ValueError("they need to be identical.") - if(key == "atom_max_distance" and val1 == val2 ): - raise ValueError("they must not be identical.") + def test_custom_filters(self): + def nop_filter(a, b, c): + return c + m = KartografAtomMapper( + additional_mapping_filter_functions=[nop_filter], + ) -def test_to_from_dict_wrong(): - mapper = KartografAtomMapper() - d1 = mapper._to_dict() - d1.update({"FLEEEEE": "You FOOLS"}) + m2 = KartografAtomMapper.from_dict(m.to_dict()) - with pytest.raises(ValueError) as exc: - mapper2 = KartografAtomMapper._from_dict(d1) - assert "I don't know about all the keys here" in str(exc.value) + assert m._filter_funcs == m2._filter_funcs def test_filter_property():