Skip to content

Commit

Permalink
Added APThreshold and Burstiness tests; default stim params; auto-res…
Browse files Browse the repository at this point in the history
…et of LEMSModel run params within a suite; PEP-8 cleanup
  • Loading branch information
rgerkin committed May 20, 2016
1 parent e8fdb2a commit f146c2a
Show file tree
Hide file tree
Showing 5 changed files with 193 additions and 40 deletions.
7 changes: 6 additions & 1 deletion neuronunit/capabilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import sciunit
from sciunit import Capability
from .spike_functions import spikes2amplitudes,spikes2widths
from .spike_functions import spikes2amplitudes,spikes2widths,spikes2thresholds
from .channel import *

class ProducesMembranePotential(Capability):
Expand Down Expand Up @@ -84,6 +84,11 @@ def get_AP_amplitudes(self):
amplitudes = spikes2amplitudes(action_potentials)
return amplitudes

def get_AP_thresholds(self):
action_potentials = self.get_APs()
thresholds = spikes2thresholds(action_potentials)
return thresholds


class ReceivesCurrent(Capability):
"""Indicates that somatic current can be injected into the model."""
Expand Down
22 changes: 21 additions & 1 deletion neuronunit/capabilities/spike_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,24 @@ def spikes2widths(spike_waveforms):
if n_spikes:
widths *= s.sampling_period # Convert from samples to time.
#print("Spike widths are %s" % str(widths))
return widths
return widths

def spikes2thresholds(spike_waveforms):
"""
IN:
spike_waveforms: Spike waveforms, e.g. from get_spike_waveforms().
neo.core.AnalogSignalArray
OUT:
1D numpy array of spike thresholds, specifically the membrane potential
at which 1/10 the maximum slope is reached.
"""
n_spikes = len(spike_waveforms)
thresholds = []
for i,s in enumerate(spike_waveforms):
s = np.array(s)
dvdt = np.diff(s)
trigger = dvdt.max()/10
x_loc = np.where(dvdt >= trigger)[0][0]
thresh = (s[x_loc]+s[x_loc+1])/2
thresholds.append(thresh)
return thresholds * spike_waveforms.units
2 changes: 2 additions & 0 deletions neuronunit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ def run(self, rerun=None, **run_params):
verbose=self.run_params['v'])
self.last_run_params = deepcopy(self.run_params)
self.rerun = False
self.run_params = {} # Reset run parameters so the next test has to pass
# its own run parameters and not use the same ones

def update_run_params(self):
from lxml import etree
Expand Down
152 changes: 124 additions & 28 deletions neuronunit/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from neuronunit import neuroelectro
from .channel import *

AMPL = 0.0*pq.pA
DELAY = 100.0*pq.ms
DURATION = 1000.0*pq.ms

class VmTest(sciunit.Test):
"""Base class for tests involving the membrane potential of a model."""

Expand Down Expand Up @@ -44,7 +48,8 @@ def __init__(self,
# Observation values with units.
united_observation_keys = ['value','mean','std']

def validate_observation(self, observation, united_keys=['value','mean'], nonunited_keys=[]):
def validate_observation(self, observation,
united_keys=['value','mean'], nonunited_keys=[]):
try:
assert type(observation) is dict
assert any([key in observation for key in united_keys]) \
Expand All @@ -54,7 +59,8 @@ def validate_observation(self, observation, united_keys=['value','mean'], nonuni
assert type(observation[key]) is Quantity
for key in nonunited_keys:
if key in observation:
assert type(observation[key]) is not Quantity
assert type(observation[key]) is not Quantity \
or observation[key].units == pq.Dimensionless
except Exception as e:
key_str = 'and/or a '.join(['%s key' % key for key in united_keys])
msg = ("Observation must be a dictionary with a %s and each key "
Expand All @@ -76,13 +82,15 @@ def bind_score(self, score, model, observation, prediction):
score.related_data['vm'] = model.get_membrane_potential()
score.related_data['model_name'] = '%s_%s' % (model.name,self.name)

def plot_vm(self,ax=None,ylim=(-80,20)):
def plot_vm(self,ax=None,ylim=(None,None)):
"""A plot method the score can use for convenience."""
if ax is None:
ax = plt.gca()
vm = score.related_data['vm'].rescale('mV')
ax.plot(vm.times,vm)
ax.set_ylim(ylim)
y_min = float(vm.min()-5.0*pq.mV) if ylim[0] is None else ylim[0]
y_max = float(vm.max()+5.0*pq.mV) if ylim[1] is None else ylim[1]
ax.set_ylim(y_min,y_max)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Vm (mV)')
score.plot_vm = MethodType(plot_vm, score) # Bind to the score.
Expand All @@ -92,9 +100,11 @@ def plot_vm(self,ax=None,ylim=(-80,20)):
def neuroelectro_summary_observation(cls, neuron):
reference_data = neuroelectro.NeuroElectroSummary(
neuron = neuron, # Neuron type lookup using the NeuroLex ID.
ephysprop = {'name': cls.ephysprop_name} # Ephys property name in the NeuroElectro ontology.
ephysprop = {'name': cls.ephysprop_name} # Ephys property name in
# NeuroElectro ontology.
)
reference_data.get_values() # Get and verify summary data from neuroelectro.org.
reference_data.get_values() # Get and verify summary data
# from neuroelectro.org.
observation = {'mean': reference_data.mean*cls.units,
'std': reference_data.std*cls.units,
'n': reference_data.n}
Expand All @@ -114,9 +124,8 @@ class InputResistanceTest(VmTest):

units = pq.ohm*1e6

params = {'injected_current':
{'amplitude':-10.0*pq.pA, 'delay':100*pq.ms,
'duration':500*pq.ms}}
params = {'injected_current':
{'amplitude':-10.0*pq.pA, 'delay':DELAY, 'duration':DURATION}}

ephysprop_name = 'Input Resistance'

Expand Down Expand Up @@ -161,7 +170,8 @@ class APWidthTest(VmTest):

def generate_prediction(self, model, verbose=False):
"""Implementation of sciunit.Test.generate_prediction."""
# Method implementation guaranteed by ProducesActionPotentials capability.
# Method implementation guaranteed by
# ProducesActionPotentials capability.
model.rerun = True
widths = model.get_AP_widths()
# Put prediction in a form that compute_score() can use.
Expand All @@ -175,7 +185,8 @@ def compute_score(self, observation, prediction, verbose=False):
if prediction['n'] == 0:
score = scores.InsufficientDataScore(None)
else:
score = super(APWidthTest,self).compute_score(observation,prediction)
score = super(APWidthTest,self).compute_score(observation,
prediction)
return score


Expand All @@ -187,8 +198,9 @@ class InjectedCurrentAPWidthTest(APWidthTest):

required_capabilities = (cap.ReceivesCurrent,)

params = {'injected_current':{'amplitude':100.0*pq.pA}}

params = {'injected_current':
{'amplitude':100.0*pq.pA, 'delay':DELAY, 'duration':DURATION}}

name = "Injected current AP width test"

description = ("A test of the widths of action potentials "
Expand All @@ -197,7 +209,8 @@ class InjectedCurrentAPWidthTest(APWidthTest):

def generate_prediction(self, model, verbose=False):
model.inject_current(self.params['injected_current'])
return super(InjectedCurrentAPWidthTest,self).generate_prediction(model, verbose=verbose)
return super(InjectedCurrentAPWidthTest,self).\
generate_prediction(model, verbose=verbose)


class APAmplitudeTest(VmTest):
Expand All @@ -207,7 +220,7 @@ class APAmplitudeTest(VmTest):

name = "AP amplitude test"

description = ("A test of the heights (peak amplitude) of "
description = ("A test of the amplitude (peak minus threshold) of "
"action potentials.")

score_type = scores.ZScore
Expand All @@ -218,9 +231,10 @@ class APAmplitudeTest(VmTest):

def generate_prediction(self, model, verbose=False):
"""Implementation of sciunit.Test.generate_prediction."""
# Method implementation guaranteed by ProducesActionPotentials capability.
# Method implementation guaranteed by
# ProducesActionPotentials capability.
model.rerun = True
heights = model.get_AP_amplitudes()
heights = model.get_AP_amplitudes() - model.get_AP_thresholds()
# Put prediction in a form that compute_score() can use.
prediction = {'mean':np.mean(heights) if len(heights) else None,
'std':np.std(heights) if len(heights) else None,
Expand All @@ -236,6 +250,20 @@ def compute_score(self, observation, prediction, verbose=False):
prediction, verbose=verbose)
return score

@classmethod
def neuroelectro_summary_observation(cls, neuron):
reference_data = neuroelectro.NeuroElectroSummary(
neuron = neuron, # Neuron type lookup using the NeuroLex ID.
ephysprop = {'name': cls.ephysprop_name} # Ephys property name in
# NeuroElectro ontology.
)
reference_data.get_values() # Get and verify summary data
# from neuroelectro.org.
observation = {'mean': reference_data.mean*cls.units,
'std': reference_data.std*cls.units,
'n': reference_data.n}
return observation


class InjectedCurrentAPAmplitudeTest(APAmplitudeTest):
"""
Expand All @@ -245,8 +273,9 @@ class InjectedCurrentAPAmplitudeTest(APAmplitudeTest):

required_capabilities = (cap.ReceivesCurrent,)

params = {'injected_current':{'amplitude':100.0*pq.pA}}

params = {'injected_current':
{'amplitude':100.0*pq.pA, 'delay':DELAY, 'duration':DURATION}}

name = "Injected current AP amplitude test"

description = ("A test of the heights (peak amplitudes) of "
Expand All @@ -255,7 +284,68 @@ class InjectedCurrentAPAmplitudeTest(APAmplitudeTest):

def generate_prediction(self, model, verbose=False):
model.inject_current(self.params['injected_current'])
return super(InjectedCurrentAPAmplitudeTest,self).generate_prediction(model, verbose=verbose)
return super(InjectedCurrentAPAmplitudeTest,self).\
generate_prediction(model, verbose=verbose)


class APThresholdTest(VmTest):
"""Tests the full widths of action potentials at their half-maximum."""

required_capabilities = (cap.ProducesActionPotentials,)

name = "AP threshold test"

description = ("A test of the membrane potential threshold at which "
"action potentials are produced.")

score_type = scores.ZScore

units = pq.mV

ephysprop_name = 'Spike Threshold'

def generate_prediction(self, model, verbose=False):
"""Implementation of sciunit.Test.generate_prediction."""
# Method implementation guaranteed by
# ProducesActionPotentials capability.
model.rerun = True
threshes = model.get_AP_thresholds()
# Put prediction in a form that compute_score() can use.
prediction = {'mean':np.mean(threshes) if len(threshes) else None,
'std':np.std(threshes) if len(threshes) else None,
'n':len(threshes)}
return prediction

def compute_score(self, observation, prediction, verbose=False):
"""Implementation of sciunit.Test.score_prediction."""
if prediction['n'] == 0:
score = scores.InsufficientDataScore(None)
else:
score = super(APThresholdTest,self).compute_score(observation,
prediction)
return score


class InjectedCurrentAPThresholdTest(APThresholdTest):
"""
Tests the thresholds of action potentials
under current injection.
"""

required_capabilities = (cap.ReceivesCurrent,)

params = {'injected_current':
{'amplitude':100.0*pq.pA, 'delay':DELAY, 'duration':DURATION}}

name = "Injected current AP threshold test"

description = ("A test of the membrane potential threshold at which "
"action potentials are produced under current injection.")

def generate_prediction(self, model, verbose=False):
model.inject_current(self.params['injected_current'])
return super(InjectedCurrentAPThresholdTest,self).\
generate_prediction(model, verbose=verbose)


class RheobaseTest(VmTest):
Expand All @@ -267,8 +357,9 @@ class RheobaseTest(VmTest):
required_capabilities = (cap.ReceivesCurrent,
cap.ProducesSpikes)

params = {'injected_current':{'amplitude':0.0*pq.pA}}

params = {'injected_current':
{'amplitude':100.0*pq.pA, 'delay':DELAY, 'duration':DURATION}}

name = "Rheobase test"

description = ("A test of the rheobase, i.e. the minimum injected current "
Expand All @@ -280,12 +371,13 @@ class RheobaseTest(VmTest):

def generate_prediction(self, model, verbose=False):
"""Implementation of sciunit.Test.generate_prediction."""
# Method implementation guaranteed by ProducesActionPotentials capability.
# Method implementation guaranteed by
# ProducesActionPotentials capability.
prediction = {'value': None}
model.rerun = True
units = self.observation['value'].units

lookup = self.threshold_FI(model, units)
lookup = self.threshold_FI(model, units, verbose=verbose)
sub = np.array([x for x in lookup if lookup[x]==0])*units
supra = np.array([x for x in lookup if lookup[x]>0])*units

Expand All @@ -309,14 +401,16 @@ def generate_prediction(self, model, verbose=False):

return prediction

def threshold_FI(self, model, units, guess=None):
def threshold_FI(self, model, units, guess=None, verbose=False):
lookup = {} # A lookup table global to the function below.

def f(ampl):
if float(ampl) not in lookup:
model.inject_current({'amplitude':ampl})
n_spikes = model.get_spike_count()
print("Injected %s current and got %d spikes" % (ampl,n_spikes))
if verbose:
print("Injected %s current and got %d spikes" % \
(ampl,n_spikes))
lookup[float(ampl)] = n_spikes

max_iters = 10
Expand Down Expand Up @@ -354,7 +448,9 @@ def compute_score(self, observation, prediction, verbose=False):
if prediction['value'] is None:
score = scores.InsufficientDataScore(None)
else:
score = super(RheobaseTest,self).compute_score(observation, prediction, verbose=verbose)
score = super(RheobaseTest,self).\
compute_score(observation, prediction, verbose=verbose)
#self.bind_score(score,None,observation,prediction)
return score


Expand Down Expand Up @@ -386,7 +482,7 @@ def generate_prediction(self, model, verbose=False):
"""Implementation of sciunit.Test.generate_prediction."""
model.rerun = True
model.inject_current({'amplitude':0.0*pq.pA})
median = model.get_median_vm() # Use median instead of mean for robustness.
median = model.get_median_vm() # Use median for robustness.
std = model.get_std_vm()
prediction = {'mean':median, 'std':std}
return prediction
Loading

0 comments on commit f146c2a

Please sign in to comment.