diff --git a/neurolib/utils/collections.py b/neurolib/utils/collections.py index 59d8fe95..9cff3a70 100644 --- a/neurolib/utils/collections.py +++ b/neurolib/utils/collections.py @@ -1,13 +1,16 @@ """ Collections of custom data structures and types. """ -from dpath.util import search, delete +import random +import string from collections import MutableMapping +from dpath.util import delete, search + DEFAULT_STAR_SEPARATOR = "." -FORWARD_REPLACE = {"*": "STAR"} -BACKWARD_REPLACE = {"STAR": "*"} +FORWARD_REPLACE = {"*": "STAR", "|": "MINUS"} +BACKWARD_REPLACE = {v: k for k, v in FORWARD_REPLACE.items()} class dotdict(dict): @@ -26,9 +29,43 @@ def __setstate__(self, state): class star_dotdict(dotdict): - """Support star notation in dotdict. Nested dicts are now treated as glob""" + """ + Supports star notation in dotdict. Nested dicts are now treated as glob. + Also supports minus as a pipe ("|") for filtering out strings after |. + Example: + Wilson-Cowan node has in total four parameters named tau (time constants + for both excitatory and inhibitory populations, and Ornstein-Uhlenbeck + time constants for both populations), hence: + + > model.params["*tau"] + # returns + {'WCnode_0.WCmassEXC_0.tau': 2.5, + 'WCnode_0.WCmassEXC_0.noise_0.tau': 5.0, + 'WCnode_0.WCmassINH_1.tau': 3.75, + 'WCnode_0.WCmassINH_1.noise_0.tau': 5.0} + + Now imagine you want to make exploration over population time constants, + but keep O-U as is, you can: + + > model.params["*tau|noise"] + # returns + {'WCnode_0.WCmassEXC_0.tau': 2.5, 'WCnode_0.WCmassINH_1.tau': 3.75} + + In other words, the string after "|" is filtered out from all the keys. + This works with setting, getting, and deleting an item and also + parameters defined with minus can be used in `Evolution` and + `Exploration` classes + """ def __getitem__(self, attr): + # if using minus and star notation: split attribute -> search -> filter + if ("|" in attr) and ("*" in attr): + assert attr.count("|") == 1, f"Only one filter allowed: {attr}" + search_attr, filter_out = attr.split("|") + # recursive call __getitem__ without filtering substring + searched = self.__getitem__(search_attr) + # filter + return {k: v for k, v in searched.items() if filter_out not in k} # if using star notation -> return dict of all keys that match if "*" in attr: return search(self, attr, separator=DEFAULT_STAR_SEPARATOR) @@ -37,17 +74,32 @@ def __getitem__(self, attr): return dict.get(self, attr) def __setitem__(self, attr, val): + # if using minus and star notation: split attribute -> search -> filter + if ("|" in attr) and ("*" in attr): + assert attr.count("|") == 1, f"Only one filter allowed: {attr}" + search_attr, filter_out = attr.split("|") + attr = search_attr + else: + # if not, filter is long random string + filter_out = "".join(random.choice(string.ascii_lowercase) for _ in range(30)) # if using star notation -> search and set all keys matching if "*" in attr: for k, _ in search(self, attr, yielded=True, separator=DEFAULT_STAR_SEPARATOR): - setattr(self, k, val) + if filter_out not in k: + setattr(self, k, val) # otherwise -> just __setitem__ else: dict.__setitem__(self, attr, val) def __delitem__(self, attr): + # if using minus and star notation: split attribute -> search -> filter + if ("|" in attr) and ("*" in attr): + assert attr.count("|") == 1, f"Only one filter allowed: {attr}" + key_to_del = self.__getitem__(attr).keys() + for key in key_to_del: + dict.__delitem__(self, key) # if using star notation -> use dpath's delete - if "*" in attr: + elif "*" in attr: delete(self, attr, separator=DEFAULT_STAR_SEPARATOR) # otherwise -> just __delitem__ else: diff --git a/tests/test_collections.py b/tests/test_collections.py index 4c63c49c..7df9d606 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -18,8 +18,14 @@ class TestCollections(unittest.TestCase): NESTED_DICT = {"a": {"b": "c", "d": "e"}} FLAT_DICT_DOT = {"a.b": "c", "a.d": "e"} - PARAM_DICT = {"mass0": {"a": 0.4, "b": 1.2, "c": "float"}, "mass1": {"a": 0.4, "b": 1.2, "c": "int"}} + PARAM_DICT = { + "mass0": {"a": 0.4, "b": 1.2, "c": "float", "noise": {"b": 12.0}}, + "mass1": {"a": 0.4, "b": 1.2, "c": "int"}, + } PARAMS_ALL_A = {"mass0.a": 0.4, "mass1.a": 0.4} + PARAMS_ALL_B = {"mass0.b": 1.2, "mass0.noise.b": 12.0, "mass1.b": 1.2} + PARAMS_ALL_B_MINUS = {"mass0.b": 1.2, "mass1.b": 1.2} + PARAMS_ALL_B_MINUS_CHANGED = {"mass0.b": 2.7, "mass1.b": 2.7} PARAMS_ALL_A_CHANGED = {"mass0.a": 0.7, "mass1.a": 0.7} def test_flatten_nested_dict(self): @@ -41,12 +47,28 @@ def test_star_dotdict(self): self.assertDictEqual(params["*a"], self.PARAMS_ALL_A_CHANGED) # delete params del params["*a"] - self.assertFalse("a" in params) + self.assertFalse(params["*a"]) + + def test_star_dotdict_minus(self): + params = star_dotdict(flatten_nested_dict(self.PARAM_DICT), sep=".") + self.assertTrue(isinstance(params, star_dotdict)) + # get params by star + self.assertDictEqual(params["*b"], self.PARAMS_ALL_B) + # get params by star and minus + self.assertDictEqual(params["*b|noise"], self.PARAMS_ALL_B_MINUS) + # change params by star and minus + params["*b|noise"] = 2.7 + self.assertDictEqual(params["*b|noise"], self.PARAMS_ALL_B_MINUS_CHANGED) + # delete params by star and minus + del params["*b|noise"] + self.assertFalse(params["*b|noise"]) + # check whether the `b` with noise stayed + self.assertEqual(len(params["*b"]), 1) def test_sanitize_keys(self): - k = "mass1.tau*" + k = "mass1.tau*|noise" k_san = _sanitize_keys(k, FORWARD_REPLACE) - self.assertEqual(k_san, k.replace("*", "STAR")) + self.assertEqual(k_san, k.replace("*", "STAR").replace("|", "MINUS")) k_back = _sanitize_keys(k_san, BACKWARD_REPLACE) self.assertEqual(k, k_back)