Merge pull request #22 from neurolib-dev/feature/evolution_aln_brain
Feature/evolution aln brain
caglorithm authored Feb 4, 2020
2 parents cf3593e + 8aff335 commit e3e88b0
Showing 10 changed files with 180 additions and 162 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -9,6 +9,8 @@ An extensive analysis of this model can be found in our paper and its associated

**Reference:** Cakan, C., Obermayer, K. (2020). Biophysically grounded mean-field models of neural populations under electrical stimulation ([ArXiv](https://arxiv.org/abs/1906.00676)).

The figure below shows a schematic of how a brain network is constructed:

<p align="center">
<img src="resources/pipeline.png" width="700">
</p>
@@ -206,8 +208,8 @@ That's all! Now you can check the results!
evolution.loadResults()
evolution.info(plot=True)
```
This will give you a summary of the last generation and plot a distribution of the individuals (and their parameters). As you can see in the parameter space cross sections below, all remaining individuals lie on a circle.
This will give you a summary of the last generation and plot a distribution of the individuals (and their parameters). Below is an animation of 10 generations of the evolutionary process. As you can see, after a couple of generations, all remaining individuals lie on a circle.

<p align="center">
<img src="resources/evolution_minimal.png">
<img src="resources/evolution_animated.gif">
</p>
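
For context on the circle mentioned above: the minimal example optimizes a toy objective whose best individuals all end up lying on a circle in a two-dimensional parameter space. Below is a minimal sketch of such an evaluation function; the parameter names `x` and `y`, the radius, and the plain-function signature are illustrative assumptions (the exact signature neurolib expects for `evalFunction` is not shown in this diff), but the `(fitness_tuple, output_data)` return shape matches the assert added in `evalPopulationUsingPypet`, and `weightList = [-1.0]` from the example notebook makes this distance a quantity to be minimized.

```python
import numpy as np

def circle_distance_fitness(x, y, radius=1.0):
    """Toy objective: distance of the point (x, y) from a circle of given radius.

    With weightList=[-1.0], this distance is minimized, so the surviving
    individuals converge onto the circle, as in the animation above.
    """
    distance = np.abs(np.sqrt(x ** 2 + y ** 2) - radius)
    outputs = {}  # nothing to store for this toy objective
    # evaluation functions must return a tuple: (fitness_tuple, output_data)
    return (distance,), outputs
```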
2 changes: 1 addition & 1 deletion codecov.yml
@@ -1,7 +1,7 @@
coverage:
precision: 2
round: nearest
range: "40...70"
range: "40...100"
ignore:
- "neurolib/models/aln/timeIntegration.py"
- "neurolib/models/bold/timeIntegration.py"
80 changes: 37 additions & 43 deletions examples/example-0-aln-minimal.ipynb


112 changes: 44 additions & 68 deletions examples/example-2-evolutionary-optimization-minimal.ipynb


3 changes: 1 addition & 2 deletions examples/example-2.1-evolutionary-optimization-aln.ipynb
@@ -27,7 +27,6 @@
"# change into root directory to the project\n",
"import os\n",
"if os.getcwd().split(\"/\")[-1] == \"examples\":\n",
" print(\"chdir\")\n",
" os.chdir('..')"
]
},
@@ -200,7 +199,7 @@
"# range to start with might be 20-100.\n",
"\n",
"evolution = Evolution(evalFunction = evaluateSimulation, parameterSpace = pars, model = aln, \n",
" weightList = [-1.0], POP_INIT_SIZE=6, POP_SIZE = 4, NGEN=3)\n",
" weightList = [-1.0], POP_INIT_SIZE=4, POP_SIZE = 4, NGEN=2)\n",
"# info: chose POP_INIT_SIZE=50, POP_SIZE = 20, NGEN=20 for real exploration, \n",
"# values are lower here for testing\n",
"\n",
2 changes: 1 addition & 1 deletion neurolib/models/aln/model.py
@@ -82,7 +82,7 @@ def run(self):
rates_exc, rates_inh, t, mufe, mufi, IA, seem, seim, siem, siim, seev, seiv, siev, siiv = return_tuple
self.t_BOLD = t_BOLD
self.BOLD = BOLD
Model.setOutput(self, "BOLD.t", t_BOLD)
Model.setOutput(self, "BOLD.t_BOLD", t_BOLD)
Model.setOutput(self, "BOLD.BOLD", BOLD)
else:
(
75 changes: 45 additions & 30 deletions neurolib/models/model.py
@@ -10,21 +10,21 @@ class dotdict(dict):
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__

# now pickleable!!!
def __getstate__(self):
return dict(self)

def __setstate__(self, state):
self.update(state)


class Model:
"""The Model superclass manages inputs and outputs of all models.
"""

# working
defaultOutput = None
outputs = {}

# possibly deprecated
inputNames = []
inputs = []
nInputs = 0
outputNames = None
xrs = {}
outputs = dotdict({})

def __init__(self, name, description=None):
assert isinstance(name, str), f"Model name is not a string."
@@ -39,28 +39,34 @@ def setOutput(self, name, data):
"""
assert not isinstance(data, dict), "Output data cannot be a dictionary."
assert isinstance(name, str), "Output name must be a string."
# set output as an attribute
setattr(self, name, data)

# build results dictionary and write into self.outputs

keys = name.split(".")
level = self.outputs
for i, k in enumerate(keys):
# if it's the last iteration, it's data
if i == len(keys) - 1:
level[k] = data
# if it's a known key, then go deeper
elif k in level:
level = level[k]
# if it's a new key, create new nested dictionary and go deeper
else:
level[k] = dotdict({})
setattr(self, k, level[k])
level = level[k]
# if the output is a single name (not dot.separated)
if "." not in name:
# save into output dict
self.outputs[name] = data
# set output as an attribute
setattr(self, name, self.outputs[name])
else:
# build results dictionary and write into self.outputs
# dot.notation iteration
keys = name.split(".")
level = self.outputs # not copy, reference!
for i, k in enumerate(keys):
# if it's the last iteration, store data
if i == len(keys) - 1:
level[k] = data
# if key is in outputs, then go deeper
elif k in level:
level = level[k]
setattr(self, k, level)
# if it's a new key, create new nested dictionary, set attribute, then go deeper
else:
level[k] = dotdict({})
setattr(self, k, level[k])
level = level[k]

def getOutput(self, name):
"""Get an output.
"""Get an output of a given name (dot.separated)
:param name: A key, grouped outputs in the form group.subgroup.variable
:type name: str
@@ -136,9 +142,18 @@ def xr(self, group=""):
# take all outputs of one group: disregard all dictionaries because they are subgroups
outputDict = self.getOutputs(group)
# make sure that there is a time array
assert "t" in outputDict, f"There is no time array (called t) in the output group."
t = outputDict["t"].copy()
del outputDict["t"]
timeDictKey = ""
if "t" in outputDict:
timeDictKey = "t"
else:
for k in outputDict:
if k.startswith("t"):
timeDictKey = k
logging.info(f"Assuming {k} to be the time axis.")
break
assert len(timeDictKey) > 0, f"No time array found (starting with t) in output group {group}."
t = outputDict[timeDictKey].copy()
del outputDict[timeDictKey]
outputs = []
outputNames = []
for key, value in outputDict.items():
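
To make the new output handling above concrete, here is a small, self-contained sketch of the dot-notation nesting that `setOutput` now performs (the `setattr` calls on the model instance are left out). The rename of `BOLD.t` to `BOLD.t_BOLD` in `aln/model.py` fits the new `xr()` behaviour, which accepts any key starting with `t` as the time axis.

```python
class dotdict(dict):
    """dict with attribute access, as defined in neurolib/models/model.py."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

def set_output(outputs, name, data):
    """Sketch of the nesting logic in Model.setOutput (attribute setting omitted)."""
    if "." not in name:
        outputs[name] = data
        return
    keys = name.split(".")
    level = outputs  # reference into the outputs dict, not a copy
    for i, k in enumerate(keys):
        if i == len(keys) - 1:   # last key: store the data
            level[k] = data
        elif k in level:         # existing subgroup: descend
            level = level[k]
        else:                    # new subgroup: create a nested dotdict, then descend
            level[k] = dotdict({})
            level = level[k]

outputs = dotdict({})
set_output(outputs, "BOLD.t_BOLD", [0.0, 2.0, 4.0])
set_output(outputs, "BOLD.BOLD", [[0.1, 0.2, 0.3]])
print(outputs.BOLD.t_BOLD)  # [0.0, 2.0, 4.0] -- xr() would pick this key as the time axis
```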
28 changes: 20 additions & 8 deletions neurolib/optimize/evolution/evolution.py
@@ -26,7 +26,7 @@ def __init__(
self,
evalFunction,
parameterSpace,
weightList,
weightList=None,
model=None,
hdf_filename="evolution.hdf",
ncores=None,
@@ -49,6 +49,10 @@
:param NGEN: Numbers of generations to evaluate
:param CXPB: Crossover probability of each individual gene
"""
if weightList is None:
logging.info("weightList not set, assuming single fitness value to be maximized.")
weightList = [1.0]

trajectoryName = "results" + datetime.datetime.now().strftime("-%Y-%m-%d-%HH-%MM-%SS")
self.HDF_FILE = os.path.join(paths.HDF_DIR, hdf_filename)
trajectoryFileName = self.HDF_FILE
@@ -72,7 +76,11 @@

# Get the trajectory from the environment
traj = env.traj
trajectoryName = traj.v_name
# Sanity check if everything went ok
assert (
trajectoryName == traj.v_name
), f"Pypet trajectory has a different name than trajectoryName {trajectoryName}"
# trajectoryName = traj.v_name

self.model = model
self.evalFunction = evalFunction
@@ -238,10 +246,16 @@ def evalPopulationUsingPypet(self, traj, toolbox, pop, gIdx):

for idx, result in enumerate(evolutionResult):
runIndex, packedReturnFromEvalFunction = result
# this is the return from the evaluation function
fitnessesResult, outputs = packedReturnFromEvalFunction

# packedReturnFromEvalFunction is the return from the evaluation function
# it has length two, the first is the fitness, second is the model output
assert (
len(packedReturnFromEvalFunction) == 2
), "Evaluation function must return tuple with shape (fitness, output_data)"

fitnessesResult, returnedOutputs = packedReturnFromEvalFunction
pop[idx].outputs = returnedOutputs
# store outputs of simulations in population
pop[idx].outputs = outputs
pop[idx].fitness.values = fitnessesResult
# mean fitness value
pop[idx].fitness.score = np.nansum(pop[idx].fitness.wvalues) / (len(pop[idx].fitness.wvalues))
@@ -341,7 +355,7 @@ def runEvolution(self):
if self.verbose:
eu.printParamDist(self.pop, self.paramInterval, self.gIdx)
eu.printPopFitnessStats(
self.pop, self.paramInterval, self.gIdx, draw_scattermatrix=True, save_plots="evo"
self.pop, self.paramInterval, self.gIdx, draw_scattermatrix=True, save_plots=self.trajectoryName
)

# save all simulation data to pypet
@@ -390,8 +404,6 @@ def getScoresDuringEvolution(self, traj=None, drop_first=True, reverse=False):
if reverse:
generation_names = generation_names[::-1]
if drop_first:
# drop first (initial) generation 0
# generation_names = generation_names[1:]
generation_names.remove("gen_000000")

npop = len(traj.results.evolution[generation_names[0]].scores)
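
As a small illustration of the new default and the return-value contract enforced above: if `weightList` is omitted it now falls back to `[1.0]` (a single fitness value to be maximized), the evaluation function must return `(fitness_tuple, output_data)`, and the per-individual score is the mean of the DEAP-weighted fitness values. The numbers below are made up, and `rates_exc` is only used as an illustrative output key.

```python
import numpy as np

# weightList from the example notebook: one fitness value that is minimized;
# if weightList were omitted, the new default [1.0] would maximize it instead.
weightList = [-1.0]

# what an evaluation function hands back: (fitness_tuple, output_data)
packedReturnFromEvalFunction = ((0.25,), {"rates_exc": []})
assert len(packedReturnFromEvalFunction) == 2
fitnessesResult, returnedOutputs = packedReturnFromEvalFunction

# DEAP multiplies each fitness value by its weight to obtain fitness.wvalues
wvalues = tuple(w * v for w, v in zip(weightList, fitnessesResult))

# mean fitness score, as computed in evalPopulationUsingPypet
score = np.nansum(wvalues) / len(wvalues)
print(score)  # -0.25; with weight -1.0, smaller distances yield higher scores
```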
34 changes: 27 additions & 7 deletions neurolib/optimize/evolution/evolutionaryUtils.py
@@ -22,7 +22,11 @@ def printParamDist(pop=None, paramInterval=None, gIdx=None):

print("Parameter distribution (Generation {}):".format(gIdx))
for idx, k in enumerate(paramInterval._fields):
print("{}: \t mean: {:.4},\t std: {:.4}".format(k, np.mean([indiv[idx] for indiv in pop]), np.std([indiv[idx] for indiv in pop]),))
print(
"{}: \t mean: {:.4},\t std: {:.4}".format(
k, np.mean([indiv[idx] for indiv in pop]), np.std([indiv[idx] for indiv in pop]),
)
)


def printIndividuals(pop, paramInterval, stats=False):
Expand All @@ -35,11 +39,21 @@ def printIndividuals(pop, paramInterval, stats=False):
thesepars["fit"] = np.mean(ind.fitness.values)
pars.append(thesepars)
print(
"Individual", i, "pars", ", ".join([" ".join([k, "{0:.4}".format(ind[ki])]) for ki, k in enumerate(paramInterval._fields)]),
"Individual",
i,
"pars",
", ".join([" ".join([k, "{0:.4}".format(ind[ki])]) for ki, k in enumerate(paramInterval._fields)]),
)
print("\tFitness values: ", *np.round(ind.fitness.values, 4))
if stats:
print("\t > mean {0:.4}, std {1:.4}, min {2:.4} max {3:.4}".format(np.mean(ind.fitness.values), np.std(ind.fitness.values), np.min(ind.fitness.values), np.max(ind.fitness.values),))
print(
"\t > mean {0:.4}, std {1:.4}, min {2:.4} max {3:.4}".format(
np.mean(ind.fitness.values),
np.std(ind.fitness.values),
np.min(ind.fitness.values),
np.max(ind.fitness.values),
)
)


def printPopFitnessStats(
Expand All @@ -52,6 +66,10 @@ def printPopFitnessStats(
"""
Print some stats of a population fitness
"""
if save_plots:
if not os.path.exists(paths.FIGURES_DIR):
os.makedirs(paths.FIGURES_DIR)

# Gather all the fitnesses in one list and print the stats
# selectPop = [p for p in pop if not np.isnan(p.fitness.score)]
selectPop = [p for p in pop if not np.any(np.isnan(p.fitness.values))]
@@ -67,8 +85,10 @@
plt.xlabel("Score")
plt.ylabel("Count")
if save_plots is not None:
logging.info("Saving plot to {}".format(os.path.join(paths.FIGURES_DIR, "%s_hist_%i.jpg" % (save_plots, gIdx))))
plt.savefig(os.path.join(paths.FIGURES_DIR, "%s_hist_%i.jpg" % (save_plots, gIdx)))
logging.info(
"Saving plot to {}".format(os.path.join(paths.FIGURES_DIR, "%s_hist_%i.png" % (save_plots, gIdx)))
)
plt.savefig(os.path.join(paths.FIGURES_DIR, "%s_hist_%i.png" % (save_plots, gIdx)))
plt.show()

if draw_scattermatrix:
Expand All @@ -80,7 +100,7 @@ def printPopFitnessStats(
plt.figure()
sm = sns.pairplot(pcandidates, diag_kind="kde", kind="reg")
if save_plots is not None:
plt.savefig(os.path.join(paths.FIGURES_DIR, "{}_sns_params_{}.jpg".format(save_plots, gIdx)))
plt.savefig(os.path.join(paths.FIGURES_DIR, "{}_sns_params_{}.png".format(save_plots, gIdx)))
plt.show()

# Seaborn Plotting
@@ -92,7 +112,7 @@
grid = grid.map_diag(plt.hist, bins=10, color="darkred", edgecolor="k")
grid = grid.map_lower(sns.kdeplot, cmap="Reds")
if save_plots is not None:
plt.savefig(os.path.join(paths.FIGURES_DIR, "{}_sns_params_red_{}.jpg".format(save_plots, gIdx)))
plt.savefig(os.path.join(paths.FIGURES_DIR, "{}_sns_params_red_{}.png".format(save_plots, gIdx)))
plt.show()
except:
pass
Binary file added resources/evolution_animated.gif