Skip to content

Commit

Permalink
Adding post_init_modifications
Browse files Browse the repository at this point in the history
  • Loading branch information
Hjorthmedh committed Nov 28, 2024
1 parent 49c01d0 commit 2bb5650
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions snudda/simulate/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,8 @@ def __init__(self,
self.node_id = int(self.pc.id())
self.total_nodes = int(self.pc.nhost())

self.post_init_mods = dict()

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
Expand Down Expand Up @@ -240,6 +242,9 @@ def __init__(self,
self.current_injection_info |= self.sim_info["current_injection_info"]
self.write_log(f"Updating current_injection_info from config file")

if "post_init_modifications" in self.sim_info:
self.post_init_mods = self.sim_info["post_init_modifications"]

else:
self.sim_info = None

Expand Down Expand Up @@ -837,6 +842,10 @@ def setup_neurons(self):
and self.neurons[neuron_id].modulation is not None:
self.neurons[neuron_id].modulation.build_node_cache()

# This allows us to modify ion channel conductance on the fly before runnigg simulation
# Can be useful to e.g. increase KIR channel conductance
self.post_initialisation_modifications()

############################################################################

def connect_network(self):
Expand Down Expand Up @@ -2144,6 +2153,29 @@ def sanity_check_play_vectors(self, sim_end_time):
raise ValueError(f"Simulation duration {sim_end_time} is "
f"longer than time vector for bath application of {species_name}")

def post_initialisation_modifications(self):

if len(self.post_init_mods) == 0:
return

print(f"Applying post_initialisation_modifications")
# This allows us to modify ion channel conductance on the fly before runnigg simulation
# Can be useful to e.g. increase KIR channel conductance

for n in self.neurons.values():
try:
if n.type in self.post_init_mods:
for ion_channel, channel_mod_factor in self.post_init_mods[n.type].items():
for sec in n.icell.all:
for seg in sec:
channel = getattr(seg, ion_channel, None)
setattr(channel, "gbar", getattr(channel, "gbar") * channel_mod_factor)
except:
import traceback
self.write_log(traceback.format_exc())
import pdb
pdb.set_trace()

def run(self, t=None, hold_v=None):

""" Run simulation. """
Expand Down

0 comments on commit 2bb5650

Please sign in to comment.