diff --git a/src/miv_simulator/config/__init__.py b/src/miv_simulator/config/__init__.py index 3a6b400..1a09bb3 100644 --- a/src/miv_simulator/config/__init__.py +++ b/src/miv_simulator/config/__init__.py @@ -3,7 +3,6 @@ BaseModel as _BaseModel, Field, conlist, - GetCoreSchemaHandler, ) from typing import Literal, Dict, Any, List, Tuple, Optional, Union, Callable from enum import IntEnum @@ -11,7 +10,6 @@ import numpy as np from typing_extensions import Annotated from pydantic.functional_validators import AfterValidator, BeforeValidator -from pydantic_core import CoreSchema, core_schema # Definitions @@ -27,12 +25,20 @@ class SWCTypesDef(IntEnum): hillock = 8 +SWCTypesLiteral = Literal[ + "soma", "axon", "basal", "apical", "trunk", "tuft", "ais", "hillock" +] + + class SynapseTypesDef(IntEnum): excitatory = 0 inhibitory = 1 modulatory = 2 +SynapseTypesLiteral = Literal["excitatory", "inhibitory", "modulatory"] + + class SynapseMechanismsDef(IntEnum): AMPA = 0 GABA_A = 1 @@ -54,6 +60,11 @@ class LayersDef(IntEnum): SLM = 9 # Lacunosum-moleculare +LayersLiteral = Literal[ + "default", "Hilus", "GCL", "IML", "MML", "OML", "SO", "SP", "SL", "SR", "SLM" +] + + class InputSelectivityTypesDef(IntEnum): random = 0 constant = 1 @@ -66,28 +77,23 @@ class PopulationsDef(IntEnum): OLM = 102 # GABAergic oriens-lacunosum/moleculare -class AllowStringsFrom: - """For convenience, allows users to specify enum values using their string name""" +PopulationsLiteral = Literal["STIM", "PYR", "PVBC", "OLM"] - def __init__(self, enum): - self.enum = enum - def cast(self, v): +def AllowStringsFrom(enum): + """For convenience, allows users to specify enum values using their string name""" + + def _cast(v) -> int: if isinstance(v, str): try: - return self.enum.__members__[v] + return enum.__members__[v] except KeyError: raise ValueError( - f"'{v}'. Must be one of {tuple(self.enum.__members__.keys())}" + f"'{v}'. Must be one of {tuple(enum.__members__.keys())}" ) return v - def __get_pydantic_core_schema__( - self, source_type: Any, handler: GetCoreSchemaHandler - ) -> CoreSchema: - return core_schema.no_info_before_validator_function( - self.cast, handler(source_type) - ) + return BeforeValidator(_cast) # Population @@ -183,15 +189,21 @@ class ParametricSurface(BaseModel): class CellType(BaseModel): template: str synapses: Dict[ - SWCTypesDefOrStr, + Literal["density"], Dict[ - Union[LayersDefOrStr, Literal["default"]], - Dict[Literal["mean", "variance"], float], + SWCTypesLiteral, + Dict[ + SynapseTypesLiteral, + Dict[ + LayersLiteral, + Dict[Literal["mean", "variance"], float], + ], + ], ], ] -CellTypes = Dict[PopulationsDefOrStr, CellType] +CellTypes = Dict[PopulationsLiteral, CellType] class AxonExtent(BaseModel): diff --git a/src/miv_simulator/interface/create_network.py b/src/miv_simulator/interface/create_network.py index d565a28..fef31b0 100644 --- a/src/miv_simulator/interface/create_network.py +++ b/src/miv_simulator/interface/create_network.py @@ -114,13 +114,7 @@ def synapse_forest(self, version: VersionType = None) -> "Component": def distribute_synapses(self, version: VersionType = None): return self.derive( "miv_simulator.interface.distribute_synapses", - [ - { - "blueprint": self.config.blueprint, - "coordinates": self.output_filepath, - } - ] - + normversion(version), + [] + normversion(version), uses=self, ) diff --git a/src/miv_simulator/interface/distribute_synapses.py b/src/miv_simulator/interface/distribute_synapses.py index 0f04a9f..52a1c19 100644 --- a/src/miv_simulator/interface/distribute_synapses.py +++ b/src/miv_simulator/interface/distribute_synapses.py @@ -4,20 +4,21 @@ from miv_simulator import config from miv_simulator import simulator from pydantic import BaseModel, Field -from typing import Optional - +from typing import Optional, Dict +from miv_simulator.utils import from_yaml class DistributeSynapses(Component): class Config(BaseModel): forest_filepath: str = Field("???") - cell_types: config.CellTypes = {} + cell_types: config.CellTypes = Field("???") population: str = Field("???") distribution: str = "uniform" mechanisms_path: str = "./mechanisms" template_path: str = "./templates" dt: float = 0.025 - tstop: float = (0.0,) + tstop: float = 0.0 celsius: Optional[float] = 35.0 + io_size: int = -1 write_size: int = 1 chunk_size: int = 1000 value_chunk_size: int = 1000 @@ -25,6 +26,13 @@ class Config(BaseModel): ranks_: int = 8 nodes_: int = 1 + def config_from_file(self, filename: str) -> Dict: + return from_yaml(filename) + + @property + def output_filepath(self) -> str: + return self.local_directory("synapses.h5") + def __call__(self): logging.basicConfig(level=logging.INFO) simulator.distribute_synapses( @@ -36,13 +44,13 @@ def __call__(self): template_path=self.config.template_path, dt=self.config.dt, tstop=self.config.tstop, - celcius=self.config.celsius, - output_filepath=None, # modify in-place + celsius=self.config.celsius, + output_filepath=self.output_filepath, + io_size=self.config.io_size, write_size=self.config.write_size, chunk_size=self.config.chunk_size, value_chunk_size=self.config.value_chunk_size, use_coreneuron=self.config.use_coreneuron, seed=self.seed, - io_size=self.config.io_size, dry_run=False, ) diff --git a/src/miv_simulator/simulator/distribute_synapse_locations.py b/src/miv_simulator/simulator/distribute_synapse_locations.py index 60e91fd..700efbc 100644 --- a/src/miv_simulator/simulator/distribute_synapse_locations.py +++ b/src/miv_simulator/simulator/distribute_synapse_locations.py @@ -92,9 +92,7 @@ def global_syn_summary(comm, syn_stats, gid_count, root): ) total_syn_stats_dict = pop_syn_stats["total"] for syn_type in total_syn_stats_dict: - global_syn_count = comm.gather( - total_syn_stats_dict[syn_type], root=root - ) + global_syn_count = comm.gather(total_syn_stats_dict[syn_type], root=root) if comm.rank == root: res.append( f"{population}: mean {syn_type} synapses per cell: {np.sum(global_syn_count) / global_count:.2f}" @@ -222,6 +220,7 @@ def distribute_synapse_locations( dt=env.dt, tstop=env.tstop, celsius=env.globals.get("celsius", None), + io_size=io_size, output_filepath=output_path, write_size=write_size, chunk_size=chunk_size, @@ -242,6 +241,7 @@ def distribute_synapses( tstop: float, celsius: Optional[float], output_filepath: Optional[str], + io_size: int, write_size: int, chunk_size: int, value_chunk_size: int, @@ -299,7 +299,9 @@ def distribute_synapses( (population_start, _) = pop_ranges[population] template_class = load_template( population_name=population, - cell_types=cell_types, + template_name=cell_types[population][ + "template" + ], template_path=template_path, ) @@ -308,13 +310,10 @@ def distribute_synapses( swc_set_dict = defaultdict(set) for sec_name, sec_dict in density_dict.items(): for syn_type, syn_dict in sec_dict.items(): - swc_set_dict[syn_type].add( - config.SWCTypesDef.__members__[sec_name] - ) + swc_set_dict[syn_type].add(sec_name) for layer_name in syn_dict: if layer_name != "default": - layer = config.LayersDef.__members__[layer_name] - layer_set_dict[syn_type].add(layer) + layer_set_dict[syn_type].add(layer_name) syn_stats_dict = { "section": defaultdict(lambda: {"excitatory": 0, "inhibitory": 0}), @@ -384,14 +383,10 @@ def distribute_synapses( cell_secidx_dict, ) else: - raise Exception( - f"Unknown distribution type: {distribution}" - ) + raise Exception(f"Unknown distribution type: {distribution}") synapse_dict[gid] = syn_dict - this_syn_stats = update_synapse_statistics( - syn_stats_dict, syn_dict - ) + this_syn_stats = update_synapse_statistics(syn_dict, syn_stats_dict) check_synapses( gid, morph_dict, @@ -412,11 +407,7 @@ def distribute_synapses( else: logger.info(f"Rank {rank} gid is None") gc.collect() - if ( - (not dry_run) - and (write_size > 0) - and (gid_count % write_size == 0) - ): + if (not dry_run) and (write_size > 0) and (gid_count % write_size == 0): append_cell_attributes( output_filepath, population, @@ -443,9 +434,7 @@ def distribute_synapses( value_chunk_size=value_chunk_size, ) - global_count, summary = global_syn_summary( - comm, syn_stats, gid_count, root=0 - ) + global_count, summary = global_syn_summary(comm, syn_stats, gid_count, root=0) if rank == 0: logger.info( f"Population: {population}, {comm.size} ranks took {time.time() - start_time:.2f} s " diff --git a/src/miv_simulator/utils/neuron.py b/src/miv_simulator/utils/neuron.py index 822fc8f..c519bdf 100644 --- a/src/miv_simulator/utils/neuron.py +++ b/src/miv_simulator/utils/neuron.py @@ -302,18 +302,12 @@ def load_cell_template( def load_template( population_name: str, - cell_types: config.CellTypes, + template_name: str, template_path: str, ): if population_name in _loaded_templates: return _loaded_templates[population_name] - if population_name not in cell_types: - raise KeyError( - f"load_cell_templates: unrecognized cell population: {population_name}" - ) - - template_name = cell_types[population_name]["template"] template_file = os.path.join(template_path, f"{template_name}.hoc") if not hasattr(h, template_name): h.load_file(template_file)