Skip to content

Commit

Permalink
Merge pull request #178 from neurolib-dev/fix/exploration_custom_eval…
Browse files Browse the repository at this point in the history
…_multimodel

BoxSearch with custom evaluate function
  • Loading branch information
caglorithm authored Aug 5, 2021
2 parents f937746 + a5577c7 commit 696b099
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
2 changes: 2 additions & 0 deletions neurolib/optimize/exploration/exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ def getModelFromTraj(self, traj):
"""
model = self.model
runParams = self.getParametersFromTraj(traj)
if self.parameterSpace.star:
runParams = flatten_nested_dict(flat_dict_to_nested(runParams)["parameters"])

model.params.update(runParams)
return model
Expand Down
13 changes: 12 additions & 1 deletion neurolib/utils/stimulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,20 @@ def update_params(self, params_dict):
:param params_dict: New parameters for this input
:type params_dict: dict
"""

def _sanitize(value):
"""
Change string `None` to actual None - can happen with Exploration or
Evolution, since `pypet` does None -> "None".
"""
if value == "None":
return None
else:
return value

for param, value in params_dict.items():
if hasattr(self, param):
setattr(self, param, value)
setattr(self, param, _sanitize(value))

def _get_times(self, duration, dt):
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


def randomString(stringLength=10):
"""Generate a random string of fixed length """
"""Generate a random string of fixed length"""
letters = string.ascii_lowercase
return "".join(random.choice(letters) for i in range(stringLength))

Expand Down
8 changes: 4 additions & 4 deletions tests/test_stimulus.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_set_params(self):
n=2,
seed=42,
)
UPDATE = {"amplitude": 43.0, "seed": 12}
UPDATE = {"amplitude": 43.0, "seed": 12, "start": "None"}
sin.update_params(UPDATE)
params = sin.get_params()
params.pop("type")
Expand All @@ -249,10 +249,10 @@ def test_set_params(self):
"seed": 42,
"frequency": self.FREQUENCY,
"amplitude": self.AMPLITUDE,
"start": STIM_START,
"dc_bias": False,
"end": STIM_END,
**UPDATE,
"start": None,
},
)

Expand Down Expand Up @@ -318,7 +318,7 @@ def test_set_params(self):
n=2,
seed=42,
)
UPDATE = {"amplitude": 43.0, "seed": 12}
UPDATE = {"amplitude": 43.0, "seed": 12, "start": "None"}
sq.update_params(UPDATE)
params = sq.get_params()
params.pop("type")
Expand All @@ -329,10 +329,10 @@ def test_set_params(self):
"seed": 42,
"frequency": self.FREQUENCY,
"amplitude": self.AMPLITUDE,
"start": STIM_START,
"end": STIM_END,
"dc_bias": False,
**UPDATE,
"start": None,
},
)

Expand Down

0 comments on commit 696b099

Please sign in to comment.