Skip to content

Commit

Permalink
Merge pull request #121 from neurolib-dev/feature/star_dotdict_minus
Browse files Browse the repository at this point in the history
Support filtering from MultiModel parameters
  • Loading branch information
caglorithm authored Jan 20, 2021
2 parents 1120f20 + 35dae7f commit 4382081
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 10 deletions.
64 changes: 58 additions & 6 deletions neurolib/utils/collections.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
"""
Collections of custom data structures and types.
"""
from dpath.util import search, delete
import random
import string
from collections import MutableMapping

from dpath.util import delete, search

DEFAULT_STAR_SEPARATOR = "."

FORWARD_REPLACE = {"*": "STAR"}
BACKWARD_REPLACE = {"STAR": "*"}
FORWARD_REPLACE = {"*": "STAR", "|": "MINUS"}
BACKWARD_REPLACE = {v: k for k, v in FORWARD_REPLACE.items()}


class dotdict(dict):
Expand All @@ -26,9 +29,43 @@ def __setstate__(self, state):


class star_dotdict(dotdict):
"""Support star notation in dotdict. Nested dicts are now treated as glob"""
"""
Supports star notation in dotdict. Nested dicts are now treated as glob.
Also supports minus as a pipe ("|") for filtering out strings after |.
Example:
Wilson-Cowan node has in total four parameters named tau (time constants
for both excitatory and inhibitory populations, and Ornstein-Uhlenbeck
time constants for both populations), hence:
> model.params["*tau"]
# returns
{'WCnode_0.WCmassEXC_0.tau': 2.5,
'WCnode_0.WCmassEXC_0.noise_0.tau': 5.0,
'WCnode_0.WCmassINH_1.tau': 3.75,
'WCnode_0.WCmassINH_1.noise_0.tau': 5.0}
Now imagine you want to make exploration over population time constants,
but keep O-U as is, you can:
> model.params["*tau|noise"]
# returns
{'WCnode_0.WCmassEXC_0.tau': 2.5, 'WCnode_0.WCmassINH_1.tau': 3.75}
In other words, the string after "|" is filtered out from all the keys.
This works with setting, getting, and deleting an item and also
parameters defined with minus can be used in `Evolution` and
`Exploration` classes
"""

def __getitem__(self, attr):
# if using minus and star notation: split attribute -> search -> filter
if ("|" in attr) and ("*" in attr):
assert attr.count("|") == 1, f"Only one filter allowed: {attr}"
search_attr, filter_out = attr.split("|")
# recursive call __getitem__ without filtering substring
searched = self.__getitem__(search_attr)
# filter
return {k: v for k, v in searched.items() if filter_out not in k}
# if using star notation -> return dict of all keys that match
if "*" in attr:
return search(self, attr, separator=DEFAULT_STAR_SEPARATOR)
Expand All @@ -37,17 +74,32 @@ def __getitem__(self, attr):
return dict.get(self, attr)

def __setitem__(self, attr, val):
# if using minus and star notation: split attribute -> search -> filter
if ("|" in attr) and ("*" in attr):
assert attr.count("|") == 1, f"Only one filter allowed: {attr}"
search_attr, filter_out = attr.split("|")
attr = search_attr
else:
# if not, filter is long random string
filter_out = "".join(random.choice(string.ascii_lowercase) for _ in range(30))
# if using star notation -> search and set all keys matching
if "*" in attr:
for k, _ in search(self, attr, yielded=True, separator=DEFAULT_STAR_SEPARATOR):
setattr(self, k, val)
if filter_out not in k:
setattr(self, k, val)
# otherwise -> just __setitem__
else:
dict.__setitem__(self, attr, val)

def __delitem__(self, attr):
# if using minus and star notation: split attribute -> search -> filter
if ("|" in attr) and ("*" in attr):
assert attr.count("|") == 1, f"Only one filter allowed: {attr}"
key_to_del = self.__getitem__(attr).keys()
for key in key_to_del:
dict.__delitem__(self, key)
# if using star notation -> use dpath's delete
if "*" in attr:
elif "*" in attr:
delete(self, attr, separator=DEFAULT_STAR_SEPARATOR)
# otherwise -> just __delitem__
else:
Expand Down
30 changes: 26 additions & 4 deletions tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,14 @@
class TestCollections(unittest.TestCase):
NESTED_DICT = {"a": {"b": "c", "d": "e"}}
FLAT_DICT_DOT = {"a.b": "c", "a.d": "e"}
PARAM_DICT = {"mass0": {"a": 0.4, "b": 1.2, "c": "float"}, "mass1": {"a": 0.4, "b": 1.2, "c": "int"}}
PARAM_DICT = {
"mass0": {"a": 0.4, "b": 1.2, "c": "float", "noise": {"b": 12.0}},
"mass1": {"a": 0.4, "b": 1.2, "c": "int"},
}
PARAMS_ALL_A = {"mass0.a": 0.4, "mass1.a": 0.4}
PARAMS_ALL_B = {"mass0.b": 1.2, "mass0.noise.b": 12.0, "mass1.b": 1.2}
PARAMS_ALL_B_MINUS = {"mass0.b": 1.2, "mass1.b": 1.2}
PARAMS_ALL_B_MINUS_CHANGED = {"mass0.b": 2.7, "mass1.b": 2.7}
PARAMS_ALL_A_CHANGED = {"mass0.a": 0.7, "mass1.a": 0.7}

def test_flatten_nested_dict(self):
Expand All @@ -41,12 +47,28 @@ def test_star_dotdict(self):
self.assertDictEqual(params["*a"], self.PARAMS_ALL_A_CHANGED)
# delete params
del params["*a"]
self.assertFalse("a" in params)
self.assertFalse(params["*a"])

def test_star_dotdict_minus(self):
params = star_dotdict(flatten_nested_dict(self.PARAM_DICT), sep=".")
self.assertTrue(isinstance(params, star_dotdict))
# get params by star
self.assertDictEqual(params["*b"], self.PARAMS_ALL_B)
# get params by star and minus
self.assertDictEqual(params["*b|noise"], self.PARAMS_ALL_B_MINUS)
# change params by star and minus
params["*b|noise"] = 2.7
self.assertDictEqual(params["*b|noise"], self.PARAMS_ALL_B_MINUS_CHANGED)
# delete params by star and minus
del params["*b|noise"]
self.assertFalse(params["*b|noise"])
# check whether the `b` with noise stayed
self.assertEqual(len(params["*b"]), 1)

def test_sanitize_keys(self):
k = "mass1.tau*"
k = "mass1.tau*|noise"
k_san = _sanitize_keys(k, FORWARD_REPLACE)
self.assertEqual(k_san, k.replace("*", "STAR"))
self.assertEqual(k_san, k.replace("*", "STAR").replace("|", "MINUS"))
k_back = _sanitize_keys(k_san, BACKWARD_REPLACE)
self.assertEqual(k, k_back)

Expand Down

0 comments on commit 4382081

Please sign in to comment.