Skip to content

Commit

Permalink
Distribute synapse bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed Sep 13, 2023
1 parent faabc65 commit b5c3607
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 63 deletions.
50 changes: 31 additions & 19 deletions src/miv_simulator/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
BaseModel as _BaseModel,
Field,
conlist,
GetCoreSchemaHandler,
)
from typing import Literal, Dict, Any, List, Tuple, Optional, Union, Callable
from enum import IntEnum
from collections import defaultdict
import numpy as np
from typing_extensions import Annotated
from pydantic.functional_validators import AfterValidator, BeforeValidator
from pydantic_core import CoreSchema, core_schema

# Definitions

Expand All @@ -27,12 +25,20 @@ class SWCTypesDef(IntEnum):
hillock = 8


SWCTypesLiteral = Literal[
"soma", "axon", "basal", "apical", "trunk", "tuft", "ais", "hillock"
]


class SynapseTypesDef(IntEnum):
excitatory = 0
inhibitory = 1
modulatory = 2


SynapseTypesLiteral = Literal["excitatory", "inhibitory", "modulatory"]


class SynapseMechanismsDef(IntEnum):
AMPA = 0
GABA_A = 1
Expand All @@ -54,6 +60,11 @@ class LayersDef(IntEnum):
SLM = 9 # Lacunosum-moleculare


LayersLiteral = Literal[
"default", "Hilus", "GCL", "IML", "MML", "OML", "SO", "SP", "SL", "SR", "SLM"
]


class InputSelectivityTypesDef(IntEnum):
random = 0
constant = 1
Expand All @@ -66,28 +77,23 @@ class PopulationsDef(IntEnum):
OLM = 102 # GABAergic oriens-lacunosum/moleculare


class AllowStringsFrom:
"""For convenience, allows users to specify enum values using their string name"""
PopulationsLiteral = Literal["STIM", "PYR", "PVBC", "OLM"]

def __init__(self, enum):
self.enum = enum

def cast(self, v):
def AllowStringsFrom(enum):
"""For convenience, allows users to specify enum values using their string name"""

def _cast(v) -> int:
if isinstance(v, str):
try:
return self.enum.__members__[v]
return enum.__members__[v]
except KeyError:
raise ValueError(
f"'{v}'. Must be one of {tuple(self.enum.__members__.keys())}"
f"'{v}'. Must be one of {tuple(enum.__members__.keys())}"
)
return v

def __get_pydantic_core_schema__(
self, source_type: Any, handler: GetCoreSchemaHandler
) -> CoreSchema:
return core_schema.no_info_before_validator_function(
self.cast, handler(source_type)
)
return BeforeValidator(_cast)


# Population
Expand Down Expand Up @@ -183,15 +189,21 @@ class ParametricSurface(BaseModel):
class CellType(BaseModel):
template: str
synapses: Dict[
SWCTypesDefOrStr,
Literal["density"],
Dict[
Union[LayersDefOrStr, Literal["default"]],
Dict[Literal["mean", "variance"], float],
SWCTypesLiteral,
Dict[
SynapseTypesLiteral,
Dict[
LayersLiteral,
Dict[Literal["mean", "variance"], float],
],
],
],
]


CellTypes = Dict[PopulationsDefOrStr, CellType]
CellTypes = Dict[PopulationsLiteral, CellType]


class AxonExtent(BaseModel):
Expand Down
8 changes: 1 addition & 7 deletions src/miv_simulator/interface/create_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,7 @@ def synapse_forest(self, version: VersionType = None) -> "Component":
def distribute_synapses(self, version: VersionType = None):
return self.derive(
"miv_simulator.interface.distribute_synapses",
[
{
"blueprint": self.config.blueprint,
"coordinates": self.output_filepath,
}
]
+ normversion(version),
[] + normversion(version),
uses=self,
)

Expand Down
22 changes: 15 additions & 7 deletions src/miv_simulator/interface/distribute_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,35 @@
from miv_simulator import config
from miv_simulator import simulator
from pydantic import BaseModel, Field
from typing import Optional

from typing import Optional, Dict
from miv_simulator.utils import from_yaml

class DistributeSynapses(Component):
class Config(BaseModel):
forest_filepath: str = Field("???")
cell_types: config.CellTypes = {}
cell_types: config.CellTypes = Field("???")
population: str = Field("???")
distribution: str = "uniform"
mechanisms_path: str = "./mechanisms"
template_path: str = "./templates"
dt: float = 0.025
tstop: float = (0.0,)
tstop: float = 0.0
celsius: Optional[float] = 35.0
io_size: int = -1
write_size: int = 1
chunk_size: int = 1000
value_chunk_size: int = 1000
use_coreneuron: bool = False
ranks_: int = 8
nodes_: int = 1

def config_from_file(self, filename: str) -> Dict:
return from_yaml(filename)

@property
def output_filepath(self) -> str:
return self.local_directory("synapses.h5")

def __call__(self):
logging.basicConfig(level=logging.INFO)
simulator.distribute_synapses(
Expand All @@ -36,13 +44,13 @@ def __call__(self):
template_path=self.config.template_path,
dt=self.config.dt,
tstop=self.config.tstop,
celcius=self.config.celsius,
output_filepath=None, # modify in-place
celsius=self.config.celsius,
output_filepath=self.output_filepath,
io_size=self.config.io_size,
write_size=self.config.write_size,
chunk_size=self.config.chunk_size,
value_chunk_size=self.config.value_chunk_size,
use_coreneuron=self.config.use_coreneuron,
seed=self.seed,
io_size=self.config.io_size,
dry_run=False,
)
35 changes: 12 additions & 23 deletions src/miv_simulator/simulator/distribute_synapse_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def global_syn_summary(comm, syn_stats, gid_count, root):
)
total_syn_stats_dict = pop_syn_stats["total"]
for syn_type in total_syn_stats_dict:
global_syn_count = comm.gather(
total_syn_stats_dict[syn_type], root=root
)
global_syn_count = comm.gather(total_syn_stats_dict[syn_type], root=root)
if comm.rank == root:
res.append(
f"{population}: mean {syn_type} synapses per cell: {np.sum(global_syn_count) / global_count:.2f}"
Expand Down Expand Up @@ -222,6 +220,7 @@ def distribute_synapse_locations(
dt=env.dt,
tstop=env.tstop,
celsius=env.globals.get("celsius", None),
io_size=io_size,
output_filepath=output_path,
write_size=write_size,
chunk_size=chunk_size,
Expand All @@ -242,6 +241,7 @@ def distribute_synapses(
tstop: float,
celsius: Optional[float],
output_filepath: Optional[str],
io_size: int,
write_size: int,
chunk_size: int,
value_chunk_size: int,
Expand Down Expand Up @@ -299,7 +299,9 @@ def distribute_synapses(
(population_start, _) = pop_ranges[population]
template_class = load_template(
population_name=population,
cell_types=cell_types,
template_name=cell_types[population][
"template"
],
template_path=template_path,
)

Expand All @@ -308,13 +310,10 @@ def distribute_synapses(
swc_set_dict = defaultdict(set)
for sec_name, sec_dict in density_dict.items():
for syn_type, syn_dict in sec_dict.items():
swc_set_dict[syn_type].add(
config.SWCTypesDef.__members__[sec_name]
)
swc_set_dict[syn_type].add(sec_name)
for layer_name in syn_dict:
if layer_name != "default":
layer = config.LayersDef.__members__[layer_name]
layer_set_dict[syn_type].add(layer)
layer_set_dict[syn_type].add(layer_name)

syn_stats_dict = {
"section": defaultdict(lambda: {"excitatory": 0, "inhibitory": 0}),
Expand Down Expand Up @@ -384,14 +383,10 @@ def distribute_synapses(
cell_secidx_dict,
)
else:
raise Exception(
f"Unknown distribution type: {distribution}"
)
raise Exception(f"Unknown distribution type: {distribution}")

synapse_dict[gid] = syn_dict
this_syn_stats = update_synapse_statistics(
syn_stats_dict, syn_dict
)
this_syn_stats = update_synapse_statistics(syn_dict, syn_stats_dict)
check_synapses(
gid,
morph_dict,
Expand All @@ -412,11 +407,7 @@ def distribute_synapses(
else:
logger.info(f"Rank {rank} gid is None")
gc.collect()
if (
(not dry_run)
and (write_size > 0)
and (gid_count % write_size == 0)
):
if (not dry_run) and (write_size > 0) and (gid_count % write_size == 0):
append_cell_attributes(
output_filepath,
population,
Expand All @@ -443,9 +434,7 @@ def distribute_synapses(
value_chunk_size=value_chunk_size,
)

global_count, summary = global_syn_summary(
comm, syn_stats, gid_count, root=0
)
global_count, summary = global_syn_summary(comm, syn_stats, gid_count, root=0)
if rank == 0:
logger.info(
f"Population: {population}, {comm.size} ranks took {time.time() - start_time:.2f} s "
Expand Down
8 changes: 1 addition & 7 deletions src/miv_simulator/utils/neuron.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,18 +302,12 @@ def load_cell_template(

def load_template(
population_name: str,
cell_types: config.CellTypes,
template_name: str,
template_path: str,
):
if population_name in _loaded_templates:
return _loaded_templates[population_name]

if population_name not in cell_types:
raise KeyError(
f"load_cell_templates: unrecognized cell population: {population_name}"
)

template_name = cell_types[population_name]["template"]
template_file = os.path.join(template_path, f"{template_name}.hoc")
if not hasattr(h, template_name):
h.load_file(template_file)
Expand Down

0 comments on commit b5c3607

Please sign in to comment.