Skip to content

Commit

Permalink
Merge pull request #33 from OpenFreeEnergy/make_mapper_tokenizable
Browse files Browse the repository at this point in the history
Make Mapper tokenizable
  • Loading branch information
richardjgowers authored Jan 25, 2024
2 parents cf44f64 + 5c58053 commit bd63ed4
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 86 deletions.
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ name: kartograf
channels:
- conda-forge
dependencies:
- dill
- python
- pip
#Test
Expand Down
134 changes: 75 additions & 59 deletions src/kartograf/atom_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# For details, see https://github.com/OpenFreeEnergy/kartograf

import copy
import dill
import inspect
import numpy as np
from enum import Enum
Expand All @@ -15,7 +16,7 @@
from scipy.optimize import linear_sum_assignment
from scipy.sparse.csgraph import connected_components

from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Callable, Iterable, Optional, Union

from gufe import SmallMoleculeComponent
from gufe import AtomMapping, AtomMapper, LigandAtomMapping
Expand Down Expand Up @@ -45,6 +46,9 @@ class mapping_algorithm(Enum):
minimal_spanning_tree = "MST"


# Type of a mapping back-end: a callable that takes an NDArray and a float
# (named distance_matrix and max_dist in the mapper's own implementations)
# and returns a dict[int, int]. Used to annotate
# KartografAtomMapper.mapping_algorithm.
_mapping_alg_type = Callable[[NDArray, float], dict[int, int]]


# Helper:
vector_eucledean_dist = calculate_edge_weight = lambda x, y: np.sqrt(
np.sum(np.square(y - x), axis=1)
Expand All @@ -56,7 +60,7 @@ class KartografAtomMapper(AtomMapper):
atom_max_distance: float
_map_exact_ring_matches_only: bool
atom_map_hydrogens: bool
mapping_algorithm: mapping_algorithm
mapping_algorithm: _mapping_alg_type

_filter_funcs: list[
Callable[[Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]]
Expand All @@ -71,8 +75,7 @@ def __init__(
map_exact_ring_matches_only: bool = True,
additional_mapping_filter_functions: Optional[Iterable[Callable[[
Chem.Mol, Chem.Mol, dict[int, int]], dict[int, int]]]] = None,
_mapping_algorithm: mapping_algorithm =
mapping_algorithm.linear_sum_assignment,
_mapping_algorithm: str = mapping_algorithm.linear_sum_assignment,
):
""" Geometry Based Atom Mapper
This mapper is a homebrew, that utilises rdkit in order
Expand Down Expand Up @@ -119,22 +122,69 @@ def __init__(
if additional_mapping_filter_functions is not None:
self._filter_funcs.extend(additional_mapping_filter_functions)

if _mapping_algorithm is not None and _mapping_algorithm == \
_mapping_algorithm.linear_sum_assignment:
if _mapping_algorithm == mapping_algorithm.linear_sum_assignment:
self._map_hydrogens_on_hydrogens_only = True
self.mapping_algorithm = self._linearSumAlgorithm_map
elif _mapping_algorithm is not None and _mapping_algorithm == \
_mapping_algorithm.minimal_spanning_tree:
elif _mapping_algorithm == mapping_algorithm.minimal_spanning_tree:
self.mapping_algorithm = self._minimalSpanningTree_map
else:
raise ValueError(
f"Mapping algorithm not implemented or unknown (options: MST "
f"or LSA). got key: {_mapping_algorithm}"
)

"""
Properties
"""
@classmethod
def _defaults(cls):
    """Return the dict of defaults for this tokenizable class.

    No defaults are declared here, so the dict is always empty.
    """
    return dict()

def _to_dict(self) -> dict:
    """Build the dict representation of this mapper.

    Instead of storing ``_filter_funcs`` and the mapping function directly,
    the constructor arguments that lead to them are stored.

    Returns
    -------
    dict
        The keys ``atom_max_distance``, ``atom_map_hydrogens``,
        ``map_hydrogens_on_hydrogens_only``,
        ``map_exact_ring_matches_only``, ``_mapping_algorithm`` (``'LSA'``
        or ``'MST'``) and ``filters`` (the ``dill.dumps`` output of every
        filter function that is not one of the built-in filters).

    Raises
    ------
    KeyError
        If ``self.mapping_algorithm`` is neither of the two mapping
        methods of this class.
    """
    standard_filters = frozenset((
        filter_atoms_h_only_h_mapped,
        filter_whole_rings_only,
        filter_ringsize_changes,
        filter_ringbreak_changes,
    ))

    # Built-in filters are re-added from the boolean settings stored below,
    # so only the remaining (custom) filter functions are pickled.
    pickled_custom_filters = []
    for func in self._filter_funcs:
        if func in standard_filters:
            continue
        pickled_custom_filters.append(dill.dumps(func))

    # Translate the stored mapping function back into its short key; this
    # avoids serialising the function itself.
    algorithm_codes = {
        self._linearSumAlgorithm_map: 'LSA',
        self._minimalSpanningTree_map: 'MST',
    }
    algorithm_code = algorithm_codes[self.mapping_algorithm]

    state = {}
    state['atom_max_distance'] = self.atom_max_distance
    state['atom_map_hydrogens'] = self.atom_map_hydrogens
    state['map_hydrogens_on_hydrogens_only'] = \
        self._map_hydrogens_on_hydrogens_only
    state['map_exact_ring_matches_only'] = self._map_exact_ring_matches_only
    state['_mapping_algorithm'] = algorithm_code
    state['filters'] = pickled_custom_filters
    return state

@classmethod
def _from_dict(cls, d: dict):
    """Rebuild a mapper from its dict representation.

    Parameters
    ----------
    d : dict
        Dict with the keys written by ``_to_dict``. A missing
        ``_mapping_algorithm`` falls back to ``'LSA'`` and a missing
        ``filters`` entry to no additional filters. ``d`` itself is left
        unmodified.

    Returns
    -------
    An instance of ``cls`` constructed from the translated arguments.

    Raises
    ------
    KeyError
        If the ``_mapping_algorithm`` entry is not a known key.
    """
    # Work on a shallow copy so the caller's dict keeps all of its keys.
    kwargs = dict(d)

    algorithm_by_code = {
        'LSA': mapping_algorithm.linear_sum_assignment,
        # 'MST' is the key _to_dict writes and the enum's own value.
        'MST': mapping_algorithm.minimal_spanning_tree,
        # Former misspelling of 'MST'; still accepted so that any dict
        # written by hand against the old table keeps loading.
        'MSR': mapping_algorithm.minimal_spanning_tree,
    }
    # replace the short key by the enum member the constructor expects
    kwargs['_mapping_algorithm'] = algorithm_by_code[
        kwargs.pop('_mapping_algorithm', 'LSA')
    ]

    # NOTE(security): dill.loads can execute arbitrary code while
    # unpickling; only feed this method dicts from trusted sources.
    kwargs['additional_mapping_filter_functions'] = [
        dill.loads(f) for f in kwargs.pop('filters', [])
    ]

    return cls(**kwargs)

@property
def map_hydrogens_on_hydrogens_only(self) -> bool:
Expand Down Expand Up @@ -166,40 +216,6 @@ def map_exact_ring_matches_only(self, s: bool):
elif f in self._filter_funcs:
self._filter_funcs.remove(f)

"""
Private - Serialize
"""

@classmethod
def _from_dict(cls, d: dict):
    """Deserialize from dict representation.

    Parameters
    ----------
    d : dict
        Keyword arguments for the constructor; every key has to be one of
        the keys of ``cls._defaults()``.

    Returns
    -------
    An instance of ``cls`` built as ``cls(**d)``.

    Raises
    ------
    ValueError
        If ``d`` contains keys that are not in ``cls._defaults()``; the
        message lists exactly those unrecognised keys.
    """
    # look the defaults up once instead of once per key
    known_keys = cls._defaults()
    # collect the keys that are NOT recognised, so the error message names
    # the offending entries rather than the valid ones
    unknown_keys = [k for k in d if k not in known_keys]
    if unknown_keys:
        raise ValueError(
            f"I don't know about all the keys here: {unknown_keys}"
        )
    return cls(**d)

def _to_dict(self) -> dict:
    """Collect the current values of the default-backed attributes.

    Returns
    -------
    dict
        One entry for each key of ``self._defaults()`` that exists as an
        attribute on this instance, mapped to that attribute's value.
    """
    return {
        name: getattr(self, name)
        for name in self._defaults()
        if hasattr(self, name)
    }

@classmethod
def _defaults(cls):
    """Return the defaults declared in the signature of ``cls.__init__``.

    This method should be overridden to provide the dict of defaults
    appropriate for the `GufeTokenizable` subclass.

    Returns
    -------
    dict
        Parameter name -> default value, for every ``__init__`` parameter
        that declares a default.
    """
    collected = {}
    for parameter in inspect.signature(cls.__init__).parameters.values():
        # parameters without a default carry the `empty` sentinel
        if parameter.default is inspect.Parameter.empty:
            continue
        collected[parameter.name] = parameter.default
    return collected

"""
Private - Set Operations
"""
Expand All @@ -209,8 +225,8 @@ def _filter_mapping_for_max_overlapping_connected_atom_set(
cls,
moleculeA: Chem.Mol,
moleculeB: Chem.Mol,
atom_mapping: Dict[int, int],
) -> Dict[int, int]:
atom_mapping: dict[int, int],
) -> dict[int, int]:
""" Find connected core region from raw mapping
This algorithm finds the maximal overlapping connected set of
two molecules and a given mapping. In order to accomplish this
Expand Down Expand Up @@ -264,8 +280,8 @@ def _filter_mapping_for_max_overlapping_connected_atom_set(

@staticmethod
def _get_connected_atom_subsets(
mol: Chem.Mol, to_be_searched: List[int]
) -> List[Set[int]]:
mol: Chem.Mol, to_be_searched: list[int]
) -> list[set[int]]:
""" find connected sets in mappings
Get the connected sets of all to_be_searched atom indices in mol.
Connected means the atoms in a resulting connected set are connected
Expand Down Expand Up @@ -356,9 +372,9 @@ def _get_connected_atom_subsets(

@staticmethod
def _get_maximal_mapping_set_overlap(
sets_a: Iterable[Set], sets_b: Iterable[Set],
mapping: Dict[int, int]
) -> Tuple[Set, Set]:
sets_a: Iterable[set], sets_b: Iterable[set],
mapping: dict[int, int]
) -> tuple[set, set]:
"""get the largest set overlaps in the mapping of set_a and set_b.
Parameters
Expand Down Expand Up @@ -410,8 +426,8 @@ def _get_maximal_mapping_set_overlap(

@staticmethod
def _filter_mapping_for_set_overlap(
set_a: Set[int], set_b: Set[int], mapping: Dict[int, int]
) -> Dict[int, int]:
set_a: set[int], set_b: set[int], mapping: dict[int, int]
) -> dict[int, int]:
"""This filter reduces the mapping dict to only in the sets contained
atom IDs
Expand Down Expand Up @@ -479,8 +495,8 @@ def _get_full_distance_matrix(

@staticmethod
def _mask_atoms(
mol, mol_pos, map_hydrogens: bool = False, masked_atoms: List = [],
) -> Tuple[Dict, List]:
mol, mol_pos, map_hydrogens: bool, masked_atoms: list[int],
) -> tuple[dict, list]:
"""Mask atoms such they are not considered during the mapping.
Parameters
Expand Down Expand Up @@ -515,7 +531,7 @@ def _mask_atoms(

def _minimalSpanningTree_map(
self, distance_matrix: NDArray, max_dist: float
) -> Dict[int, int]:
) -> dict[int, int]:
"""MST Mapping
This function is a numpy graph based implementation to build up an
Atom Mapping purely on 3D criteria.
Expand Down Expand Up @@ -561,7 +577,7 @@ def _minimalSpanningTree_map(
@staticmethod
def _linearSumAlgorithm_map(
distance_matrix: NDArray, max_dist: float
) -> Dict[int, int]:
) -> dict[int, int]:
""" LSA mapping
This function is a LSA based implementation to build up an Atom
Mapping purely on 3D criteria.
Expand Down Expand Up @@ -589,7 +605,7 @@ def _linearSumAlgorithm_map(

def _additional_filter_rules(
self, molA: Chem.Mol, molB: Chem.Mol, mapping: dict[int, int]
) -> Dict[int, int]:
) -> dict[int, int]:
"""apply additional filter rules to the given mapping.
Parameters
Expand Down
58 changes: 31 additions & 27 deletions src/kartograf/tests/test_atom_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,39 +169,43 @@ def test_stereo_mapping(stereco_chem_molecules, stereo_chem_mapping):
check_mapping_vs_expected(geom_mapping, expected_mapping)


# Test Serialization
def test_to_from_dict():
mapper = KartografAtomMapper()
d1 = mapper._to_dict()
mapper2 = KartografAtomMapper._from_dict(d1)
d2 = mapper2._to_dict()

for key, val1 in d1.items():
val2 = d2[key]

if(val1 != val2):
raise ValueError("they need to be identical.")
class TestSerialisation:
def test_to_from_dict_cycle(self):
    """A default mapper equals its to_dict/from_dict round-trip copy."""
    original = KartografAtomMapper()
    rebuilt = KartografAtomMapper.from_dict(original.to_dict())

    assert original == rebuilt

@pytest.mark.parametrize('mhoho,mermo', [
(True, True),
(True, False),
(False, True),
(False, False),
])
def test_check_filters(self, mhoho, mermo):
m = KartografAtomMapper(
map_hydrogens_on_hydrogens_only=mhoho,
map_exact_ring_matches_only=mermo,
)

mapper2.atom_max_distance = 10
d3 = mapper2._to_dict()
m2 = KartografAtomMapper.from_dict(m.to_dict())

for key, val1 in d1.items():
val2 = d3[key]
assert m._filter_funcs == m2._filter_funcs

if(val1 != val2 and key != "atom_max_distance"):
raise ValueError("they need to be identical.")
if(key == "atom_max_distance" and val1 == val2 ):
raise ValueError("they must not be identical.")
def test_custom_filters(self):
def nop_filter(a, b, c):
return c

m = KartografAtomMapper(
additional_mapping_filter_functions=[nop_filter],
)

def test_to_from_dict_wrong():
mapper = KartografAtomMapper()
d1 = mapper._to_dict()
d1.update({"FLEEEEE": "You FOOLS"})
m2 = KartografAtomMapper.from_dict(m.to_dict())

with pytest.raises(ValueError) as exc:
mapper2 = KartografAtomMapper._from_dict(d1)
assert "I don't know about all the keys here" in str(exc.value)
assert m._filter_funcs == m2._filter_funcs


def test_filter_property():
Expand Down

0 comments on commit bd63ed4

Please sign in to comment.