Skip to content

Commit

Permalink
Merge pull request #401 from Hjorthmedh/region_mesh_redux
Browse files Browse the repository at this point in the history
Updated input.py
  • Loading branch information
Hjorthmedh authored Jan 15, 2024
2 parents 68da066 + 752aa70 commit a7c251d
Show file tree
Hide file tree
Showing 6 changed files with 279 additions and 51 deletions.
249 changes: 206 additions & 43 deletions examples/parallel/KTH_PDC/input_tuning/Verify dSPN input tuning.ipynb

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion snudda/input/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def write_hdf5(self):
it_group.attrs["parameterFile"] = neuron_in["parameterFile"]

# We need to convert this to string to be able to save it
if "parameterList" in neuron_in:
if "parameterList" in neuron_in and neuron_in["parameterList"] is not None:
# We only need to save the synapse parameters in the file
syn_par_list = [x["synapse"] for x in neuron_in["parameterList"] if "synapse" in x]
if len(syn_par_list) > 0:
Expand Down Expand Up @@ -719,6 +719,9 @@ def make_neuron_input_parallel(self):

# Let input.json info override meta.json input parameters if given
for key, data in old_info.items():
if key == "parameterList" and data is None:
continue

inp_data_copy[key] = data

input_info[inp_name] = inp_data_copy
Expand Down
58 changes: 54 additions & 4 deletions snudda/input/input_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,12 +276,12 @@ def setup_input_verification(self, input_type="cortical", neuron_type="dSPN",
input_end = input_duration * np.arange(1, n_steps + 1)

input_target = neuron_type

if input_target not in self.input_type:
self.input_info[input_target] = collections.OrderedDict()
self.input_info = collections.OrderedDict()
self.input_info[input_target] = collections.OrderedDict()

self.input_info[input_target][input_type] = collections.OrderedDict()

self.input_info[input_target][input_type]["generator"] = "poisson"
self.input_info[input_target][input_type]["start"] = input_start
self.input_info[input_target][input_type]["end"] = input_end
self.input_info[input_target][input_type]["frequency"] = self.frequency_range
Expand Down Expand Up @@ -988,6 +988,44 @@ def plot_frequency_data_alt(self, frequency_data, show_plots=True, input_type_na
if not show_plots:
plt.close()

def plot_verify_frequency_distribution(self, input_type="cortical"):

network_info, input_config, input_data, neuron_id_lookup, neuron_name_list, \
spike_data, volt, time = self.load_data_helper()

# First find out what time ranges we need to look at, and what the input frequency is for those
neuron_type = list(input_config.keys())
assert len(neuron_type) == 1, f"Plot only supports one neuron type at a time, neuron_type = {neuron_type}"
neuron_type = neuron_type[0]

assert input_type in input_config[neuron_type], f"{input_type} not in input_config: {input_config}"
start_times = np.array(input_config[neuron_type][input_type]["start"])
end_times = np.array(input_config[neuron_type][input_type]["end"])
input_freq = np.array(input_config[neuron_type][input_type]["frequency"])
output_freq_list = []

for neuron_id in sorted(spike_data.keys()):
out_freq = []
for start_t, end_t in zip(start_times, end_times):
n_spikes = np.sum(np.logical_and(start_t <= spike_data[neuron_id], spike_data[neuron_id] < end_t))
out_freq.append(n_spikes / (end_t - start_t))

output_freq_list.append(out_freq)

output_freq = np.array(output_freq_list).T

plt.figure()
plt.plot(input_freq, output_freq, 'k')
plt.xlabel("Input frequency")
plt.ylabel("Output frequency")
plt.ion()
plt.show()

import pdb
pdb.set_trace()



def get_neuron_info(self, neuron_path):

neuron_path = snudda_parse_path(neuron_path, self.snudda_data)
Expand Down Expand Up @@ -1176,6 +1214,9 @@ def create_network_config(self,
else:
neuron_def = self.gather_all_neurons(neuron_types=neuron_types, all_combinations=all_combinations)

# Just generate a set of points
volume_def[vol_name]["n_putative_points"] = max(len(neuron_def.keys())*5, 10000)

fake_axon_density = ["r", "1", 10e-6]

for n in neuron_def.keys():
Expand Down Expand Up @@ -1501,12 +1542,16 @@ def write_tuning_info(self):
help="Optional, if only we want to simulate one neuron subtype, eg. FS_1")
parser.add_argument("--meta_input", action="store_true", default=False)
parser.add_argument("--seed_list", type=str, default=None)
parser.add_argument("--no_downsampling", action="store_true")

args = parser.parse_args()

# TODO: Let the user choose input type, duration for each "run", frequency range, number of input range

if args.seed_list is not None:
seed_list = ast.literal_eval(args.seed_list)
else:
seed_list = None

input_scaling = InputTuning(args.networkPath, input_seed_list=seed_list)

Expand Down Expand Up @@ -1559,7 +1604,12 @@ def write_tuning_info(self):
print("Tip, to run in parallel on your local machine use: "
"mpiexec -n 4 python3 tuning/input_tuning.py simulate <yournetworkhere>")

input_scaling.simulate(mech_dir=args.mechDir)
if args.no_downsampling:
sample_dt = None
else:
sample_dt = 0.01

input_scaling.simulate(mech_dir=args.mechDir, sample_dt=sample_dt)

elif args.action == "analyse":
# input_scaling.plot_generated_input()
Expand Down
2 changes: 1 addition & 1 deletion snudda/place/region_mesh_redux.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def __init__(self):

if __name__ == "__main__":

mesh_path="../data/mesh/Striatum-d-right.obj"
mesh_path = "../data/mesh/Striatum-d-right.obj"

# nep = NeuronPlacer(mesh_path=mesh_path, d_min=10e-6, n_putative_points=10000000)
nep = NeuronPlacer(mesh_path=mesh_path, d_min=10e-6, n_putative_points=None, putative_density=100e3)
Expand Down
8 changes: 7 additions & 1 deletion snudda/plotting/plot_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ def neuron_name(self, neuron_type):
def plot_traces(self, trace_id=None, offset=150e-3, colours=None, skip_time=None, time_range=None,
line_width=1, fig_size=None,
mark_current=None, mark_current_y=None,
title=None, fig_name=None, mark_depolarisation_block=True):
title=None, fig_name=None,
mark_depolarisation_block=True,
mark_spikes=True):

"""
Plot the traces of neuron trace_id
Expand Down Expand Up @@ -212,6 +214,10 @@ def plot_traces(self, trace_id=None, offset=150e-3, colours=None, skip_time=None
plt.plot([depol_start_t-skip_time, depol_end_t-skip_time],
[ofs, ofs], color="red", linewidth=line_width)

if mark_spikes:
spike_times = self.output_load.get_spikes(neuron_id=r).flatten()
plt.plot(spike_times, np.full(spike_times.shape, fill_value=offset), 'r*')

if offset:
ofs += offset

Expand Down
8 changes: 7 additions & 1 deletion snudda/simulate/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ def setup_neurons(self):

self.pc.cell(ID, nc, 1) # The 1 means broadcast spikes to other machines

# print(f"Spike threshold for neuron {ID} is {self.pc.threshold(ID)}")
# self.pc.threshold(ID, new_threshold) # if we want to update threshold, obs in mV

# Record all spikes
t_spikes = h.Vector()
id_spikes = h.Vector()
Expand Down Expand Up @@ -1185,7 +1188,10 @@ def add_external_input(self, input_file=None):
neuron_input = self.input_data["input"][str(neuron_id)][input_type]
sections = self.neurons[neuron_id].map_id_to_compartment(neuron_input.attrs["sectionID"])
mod_file = SnuddaLoad.to_str(neuron_input.attrs["modFile"])
param_list = json.loads(neuron_input.attrs["parameterList"], object_pairs_hook=OrderedDict)
if "parameterList" in neuron_input.attrs:
param_list = json.loads(neuron_input.attrs["parameterList"], object_pairs_hook=OrderedDict)
else:
param_list = None

# TODO: Sanity check mod_file string
eval_str = f"self.sim.neuron.h.{mod_file}"
Expand Down

0 comments on commit a7c251d

Please sign in to comment.