diff --git a/src/miv_simulator/config.py b/src/miv_simulator/config.py index 39764b5..6c1af3d 100644 --- a/src/miv_simulator/config.py +++ b/src/miv_simulator/config.py @@ -1,12 +1,9 @@ import copy from pydantic import ( BaseModel as _BaseModel, - Field, - conlist, AfterValidator, - BeforeValidator, ) -from typing import Literal, Dict, Any, List, Tuple, Optional, Union, Callable +from typing import Literal, Dict, List, Tuple, Optional, Union, Callable from enum import IntEnum from collections import defaultdict import numpy as np @@ -41,64 +38,12 @@ class SynapseTypesDef(IntEnum): SynapseTypesLiteral = Literal["excitatory", "inhibitory", "modulatory"] +# Synapses -class SynapseMechanismsDef(IntEnum): - AMPA = 0 - GABA_A = 1 - GABA_B = 2 - NMDA = 30 - - -SynapseMechanismsLiteral = Literal["AMPA", "GABA_A", "GABA_B", "NMDA"] - - -class LayersDef(IntEnum): - default = -1 - Hilus = 0 - GCL = 1 # Granule cell - IML = 2 # Inner molecular - MML = 3 # Middle molecular - OML = 4 # Outer molecular - SO = 5 # Oriens - SP = 6 # Pyramidale - SL = 7 # Lucidum - SR = 8 # Radiatum - SLM = 9 # Lacunosum-moleculare - - -class InputSelectivityTypesDef(IntEnum): - random = 0 - constant = 1 - - -def AllowStringsFrom(enum): - """For convenience, allows users to specify enum values using their string name""" - - def _cast(v) -> int: - if isinstance(v, str): - try: - return enum.__members__[v] - except KeyError: - raise ValueError( - f"'{v}'. Must be one of {tuple(enum.__members__.keys())}" - ) - return v - - return BeforeValidator(_cast) - +SynapseMechanismName = str # Population -SynapseTypesDefOrStr = Annotated[ - SynapseTypesDef, AllowStringsFrom(SynapseTypesDef) -] -SWCTypesDefOrStr = Annotated[SWCTypesDef, AllowStringsFrom(SWCTypesDef)] -SynapseMechanismsDefOrStr = Annotated[ - SynapseMechanismsDef, AllowStringsFrom(SynapseMechanismsDef) -] -LayersDefOrStr = Annotated[LayersDef, AllowStringsFrom(LayersDef)] - - PopulationName = str PostSynapticPopulationName = PopulationName PreSynapticPopulationName = PopulationName @@ -182,13 +127,29 @@ class Mechanism(BaseModel): class Synapse(BaseModel): - type: SynapseTypesDefOrStr - sections: conlist(SWCTypesDefOrStr) - layers: conlist(LayersDefOrStr) - proportions: conlist(float) - mechanisms: Dict[SynapseMechanismsLiteral, Mechanism] + type: SynapseTypesLiteral + sections: List[SWCTypesLiteral] + layers: List[LayerName] + proportions: list[float] + mechanisms: Dict[SynapseMechanismName, Mechanism] contacts: int = 1 + def to_config(self, layer_definitions: Dict[LayerName, int]): + return type( + "SynapseConfig", + (), + { + "type": SynapseTypesDef.__members__[self.type], + "sections": list( + map(SWCTypesDef.__members__.get, self.sections) + ), + "layers": list(map(layer_definitions.get, self.layers)), + "proportions": self.proportions, + "mechanisms": self.mechanisms, + "contacts": self.contacts, + }, + ) + def _origin_value_to_callable(value: Union[str, float]) -> Callable: if isinstance(value, float): @@ -283,9 +244,19 @@ def probabilities_sum_to_one(x): sentinel = object() +class Definitions(BaseModel): + swc_types: Dict[str, int] + synapse_types: Dict[str, int] + synapse_mechanisms: Dict[str, int] + layers: Dict[str, int] + populations: Dict[str, int] + input_selectivity_types: Dict[str, int] + + class Config: def __init__(self, data: Dict) -> None: self._data = copy.deepcopy(data) + self._definitions = None # compatibility self.get("Cell Types.STIM", {}).setdefault("synapses", {}) @@ -356,3 +327,20 @@ def cell_types(self) -> CellTypes: @property def clamp(self) -> Optional[Dict]: return self.get("Network Clamp", None) + + @property + def definitions(self) -> Definitions: + if self._definitions is None: + self._definitions = Definitions( + swc_types=self.get("Definitions.SWC Types", {}), + synapse_types=self.get("Definitions.Synapse Types", {}), + synapse_mechanisms=self.get( + "Definitions.Synapse Mechanisms", {} + ), + layers=self.get("Definitions.Layers", {}), + populations=self.get("Definitions.Populations", {}), + input_selectivity_types=self.get( + "Definitions.Input Selectivity Types", {} + ), + ) + return self._definitions diff --git a/src/miv_simulator/interface/connections.py b/src/miv_simulator/interface/connections.py index 051f1f4..013858d 100644 --- a/src/miv_simulator/interface/connections.py +++ b/src/miv_simulator/interface/connections.py @@ -18,6 +18,8 @@ class Config(BaseModel): forest_filepath: str = Field("???") axon_extents: config.AxonExtents = Field("???") synapses: config.Synapses = Field("???") + population_definitions: Dict[str, int] = Field("???") + layer_definitions: Dict[str, int] = Field("???") include_forest_populations: Optional[list] = None connectivity_namespace: str = "Connections" coordinates_namespace: str = "Coordinates" @@ -45,13 +47,22 @@ def __call__(self): filepath=self.config.filepath, forest_filepath=self.config.forest_filepath, include_forest_populations=self.config.include_forest_populations, - synapses=self.config.synapses, + synapses={ + post: { + pre: config.Synapse(**syn).to_config( + self.config.layer_definitions + ) + for pre, syn in v.items() + } + for post, v in self.config.synapses.items() + }, axon_extents=self.config.axon_extents, output_filepath=self.output_filepath, connectivity_namespace=self.config.connectivity_namespace, coordinates_namespace=self.config.coordinates_namespace, synapses_namespace=self.config.synapses_namespace, distances_namespace=self.config.distances_namespace, + populations_dict=self.config.population_definitions, io_size=self.config.io_size, chunk_size=self.config.chunk_size, value_chunk_size=self.config.value_chunk_size, diff --git a/src/miv_simulator/interface/h5_types.py b/src/miv_simulator/interface/h5_types.py index 846e4ef..c64fac2 100644 --- a/src/miv_simulator/interface/h5_types.py +++ b/src/miv_simulator/interface/h5_types.py @@ -12,7 +12,8 @@ class Config(BaseModel): cell_distributions: config.CellDistributions = Field("???") projections: config.SynapticProjections = Field("???") - mpi_args: str = "-n 1" + population_definitions: Dict[str, int] = Field("???") + ranks: int = 1 nodes: str = "1" def config_from_file(self, filename: str) -> Dict: @@ -31,11 +32,12 @@ def __call__(self) -> None: post: {pre: True for pre in v} for post, v in self.config.projections.items() }, + population_definitions=self.config.population_definitions, ) MPI.COMM_WORLD.barrier() def compute_context(self): context = super().compute_context() - del context["config"]["mpi_args"] + del context["config"]["ranks"] del context["config"]["nodes"] return context diff --git a/src/miv_simulator/interface/network.py b/src/miv_simulator/interface/network.py index 44b0ef6..a288b1d 100644 --- a/src/miv_simulator/interface/network.py +++ b/src/miv_simulator/interface/network.py @@ -25,6 +25,7 @@ def launch(self): { "projections": config.projections, "cell_distributions": config.cell_distributions, + "population_definitions": config.definitions.populations, }, ], ) @@ -67,6 +68,7 @@ def launch(self): ].output_filepath, "cell_types": config.cell_types, "population": population, + "layer_definitions": config.definitions.layers, "distribution": "poisson", "mechanisms_path": self.config.mechanisms_path, "template_path": self.config.template_path, @@ -86,6 +88,8 @@ def launch(self): population ].output_filepath, "axon_extents": config.axon_extents, + "population_definitions": config.definitions.populations, + "layer_definitions": config.definitions.layers, "io_size": 1, "cache_size": 20, "write_size": 100, diff --git a/src/miv_simulator/interface/neuroh5_graph.py b/src/miv_simulator/interface/neuroh5_graph.py index cb8b12a..66185cb 100644 --- a/src/miv_simulator/interface/neuroh5_graph.py +++ b/src/miv_simulator/interface/neuroh5_graph.py @@ -6,7 +6,7 @@ class NeuroH5Graph(Component): class Config: - mpi_args: str = "-n 1" + ranks: int = 1 nodes: str = "1" def __init__(self, *args, **kwargs): @@ -102,7 +102,7 @@ def files(self) -> Dict[str, str]: def compute_context(self): context = super().compute_context() - del context["config"]["mpi_args"] + del context["config"]["ranks"] del context["config"]["nodes"] context["predicate"]["uses"] = sorted([u.hash for u in self.uses]) return context diff --git a/src/miv_simulator/interface/synapse_forest.py b/src/miv_simulator/interface/synapse_forest.py index 5d20170..7fb0d4c 100644 --- a/src/miv_simulator/interface/synapse_forest.py +++ b/src/miv_simulator/interface/synapse_forest.py @@ -12,7 +12,8 @@ class Config(BaseModel): filepath: str = Field("???") population: config.PopulationName = Field("???") morphology: config.SWCFilePath = Field("???") - mpi_args: Optional[str] = "-n 1" + # ranks: int = 1 + mpi: Optional[str] = None nodes: str = "1" @property @@ -62,7 +63,7 @@ def dispatch_code( def compute_context(self): context = super().compute_context() - del context["config"]["mpi_args"] + del context["config"]["mpi"] del context["config"]["nodes"] del context["config"]["filepath"] context["config"]["morphology"] = file_hash( diff --git a/src/miv_simulator/interface/synapses.py b/src/miv_simulator/interface/synapses.py index 3bf940a..ad03573 100644 --- a/src/miv_simulator/interface/synapses.py +++ b/src/miv_simulator/interface/synapses.py @@ -17,6 +17,7 @@ class Config(BaseModel): forest_filepath: str = Field("???") cell_types: config.CellTypes = Field("???") population: str = Field("???") + layer_definitions: Dict[str, int] = Field("???") distribution: str = "uniform" mechanisms_path: str = "./mechanisms/compiled" template_path: str = "./templates" @@ -51,6 +52,9 @@ def __call__(self): simulator.distribute_synapses( forest_filepath=self.config.forest_filepath, cell_types=self.config.cell_types, + swc_defs=config.SWCTypesDef.__members__, + synapse_defs=config.SynapseTypesDef.__members__, + layer_defs=self.config.layer_definitions, populations=(self.config.population,), distribution=self.config.distribution, template_path=self.config.template_path, diff --git a/src/miv_simulator/simulator/distribute_synapses.py b/src/miv_simulator/simulator/distribute_synapses.py index b6c8679..22fad71 100644 --- a/src/miv_simulator/simulator/distribute_synapses.py +++ b/src/miv_simulator/simulator/distribute_synapses.py @@ -233,6 +233,7 @@ def distribute_synapse_locations( forest_filepath=forest_path, cell_types=env.celltypes, swc_defs=env.SWC_Types, + synapse_defs=env.Synapse_Types, layer_defs=env.layers, populations=populations, distribution=distribution, @@ -251,6 +252,7 @@ def distribute_synapses( forest_filepath: str, cell_types: config.CellTypes, swc_defs: Dict[str, int], + synapse_defs: Dict[str, int], layer_defs: Dict[str, int], populations: Tuple[str, ...], distribution: Literal["uniform", "poisson"], @@ -363,9 +365,9 @@ def distribute_synapses( seg_density_per_sec, ) = synapses.distribute_uniform_synapses( random_seed, - config.SynapseTypesDef.__members__, - config.SWCTypesDef.__members__, - config.LayersDef.__members__, + synapse_defs, + swc_defs, + layer_defs, density_dict, morph_dict, cell_sec_dict, @@ -378,9 +380,9 @@ def distribute_synapses( seg_density_per_sec, ) = synapses.distribute_poisson_synapses( random_seed, - config.SynapseTypesDef.__members__, - config.SWCTypesDef.__members__, - config.LayersDef.__members__, + synapse_defs, + swc_defs, + layer_defs, density_dict, morph_dict, cell_sec_dict, diff --git a/src/miv_simulator/simulator/generate_connections.py b/src/miv_simulator/simulator/generate_connections.py index beb1c41..59d8034 100644 --- a/src/miv_simulator/simulator/generate_connections.py +++ b/src/miv_simulator/simulator/generate_connections.py @@ -18,7 +18,7 @@ read_population_names, read_population_ranges, ) -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Dict sys_excepthook = sys.excepthook diff --git a/src/miv_simulator/simulator/generate_synapse_forest.py b/src/miv_simulator/simulator/generate_synapse_forest.py index e81e122..0ff0d2a 100644 --- a/src/miv_simulator/simulator/generate_synapse_forest.py +++ b/src/miv_simulator/simulator/generate_synapse_forest.py @@ -12,23 +12,25 @@ def _bin_check(bin: str) -> None: raise FileNotFoundError(f"{bin} not found. Did you add it to the PATH?") -def _sh(cmd, spawn_process=True): - if not spawn_process: - return os.system(" ".join([shlex.quote(c) for c in cmd])) - - try: - subprocess.check_output( - cmd, - stderr=subprocess.STDOUT, - ) - except subprocess.CalledProcessError as e: - error_message = e.output.decode() - print(f"{os.getcwd()}$:") - print(" ".join(cmd)) - print("Error:", error_message) - raise subprocess.CalledProcessError( - e.returncode, e.cmd, output=error_message - ) +def _sh(cmd, spawn_process=False): + if spawn_process: + try: + subprocess.check_output( + cmd, + stderr=subprocess.STDOUT, + ) + except subprocess.CalledProcessError as e: + error_message = e.output.decode() + print(f"{os.getcwd()}$:") + print(" ".join(cmd)) + print("Error:", error_message) + raise subprocess.CalledProcessError( + e.returncode, e.cmd, output=error_message + ) + else: + cmdq = " ".join([shlex.quote(c) for c in cmd]) + if os.system(cmdq) != 0: + raise RuntimeError(f"Error running {cmdq}") def generate_synapse_forest( diff --git a/src/miv_simulator/utils/io.py b/src/miv_simulator/utils/io.py index a97791a..69a6e4c 100644 --- a/src/miv_simulator/utils/io.py +++ b/src/miv_simulator/utils/io.py @@ -269,32 +269,34 @@ def create_neural_h5( output_filepath: str, cell_distributions: config.CellDistributions, synapses: config.Synapses, + population_definitions: Dict[str, int], gap_junctions: Optional[Dict] = None, - populations: Optional[Dict[str, config.PopulationsDef]] = None, ) -> None: - if populations is None: - populations = config.PopulationsDef.__members__ - _populations = [] - for pop_name, pop_idx in populations.items(): + populations = [] + for pop_name, pop_idx in population_definitions.items(): layer_counts = cell_distributions[pop_name] pop_count = 0 for layer_name, layer_count in layer_counts.items(): pop_count += layer_count - _populations.append((pop_name, pop_idx, pop_count)) - _populations.sort(key=lambda x: x[1]) - min_pop_idx = _populations[0][1] + populations.append((pop_name, pop_idx, pop_count)) + populations.sort(key=lambda x: x[1]) + min_pop_idx = populations[0][1] projections = [] if gap_junctions: for (post, pre), connection_dict in gap_junctions.items(): - projections.append((populations[pre], populations[post])) + projections.append( + (population_definitions[pre], population_definitions[post]) + ) else: for post, connection_dict in synapses.items(): for pre, _ in connection_dict.items(): - projections.append((populations[pre], populations[post])) + projections.append( + (population_definitions[pre], population_definitions[post]) + ) # create an HDF5 enumerated type for the population label - mapping = {name: idx for name, idx in populations.items()} + mapping = {name: idx for name, idx in population_definitions.items()} dt_population_labels = h5py.special_dtype(enum=(np.uint16, mapping)) with h5py.File(output_filepath, "a") as h5: @@ -315,12 +317,12 @@ def create_neural_h5( g = h5_get_group(h5, grp_h5types) dset = h5_get_dataset( - g, grp_populations, maxshape=(len(_populations),), dtype=dt + g, grp_populations, maxshape=(len(populations),), dtype=dt ) - dset.resize((len(_populations),)) - a = np.zeros(len(_populations), dtype=dt) + dset.resize((len(populations),)) + a = np.zeros(len(populations), dtype=dt) start = 0 - for enum_id, (name, idx, count) in enumerate(_populations): + for enum_id, (name, idx, count) in enumerate(populations): a[enum_id]["Start"] = start a[enum_id]["Count"] = count a[enum_id]["Population"] = idx diff --git a/src/scripts/make_h5types.py b/src/scripts/make_h5types.py index d775d6c..1acca9d 100644 --- a/src/scripts/make_h5types.py +++ b/src/scripts/make_h5types.py @@ -16,8 +16,8 @@ def make_h5types( output_file, env.geometry["Cell Distribution"], env.connection_config, - env.gapjunctions if gap_junctions else None, env.Populations, + env.gapjunctions if gap_junctions else None, ) diff --git a/src/scripts/tools/reposition_trees.py b/src/scripts/tools/reposition_trees.py index 1dc1186..75333c9 100644 --- a/src/scripts/tools/reposition_trees.py +++ b/src/scripts/tools/reposition_trees.py @@ -228,28 +228,34 @@ def main( iter_count += 1 if ( - (not dry_run) - and (write_size > 0) - and (iter_count % write_size == 0) - ): + (not dry_run) + and (write_size > 0) + and (iter_count % write_size == 0) + ): if rank == 0: logger.info(f"Appending repositioned trees to {output_path}...") append_cell_trees( - output_path, population, new_trees_dict, io_size=io_size, - chunk_size=chunk_size, value_chunk_size=value_chunk_size, - comm=comm + output_path, + population, + new_trees_dict, + io_size=io_size, + chunk_size=chunk_size, + value_chunk_size=value_chunk_size, + comm=comm, ) new_trees_dict = {} - - if not dry_run: - if (rank == 0): + if rank == 0: logger.info(f"Appending repositioned trees to {output_path}...") append_cell_trees( - output_path, population, new_trees_dict, io_size=io_size, - chunk_size=chunk_size, value_chunk_size=value_chunk_size, - comm=comm + output_path, + population, + new_trees_dict, + io_size=io_size, + chunk_size=chunk_size, + value_chunk_size=value_chunk_size, + comm=comm, ) comm.barrier() @@ -259,6 +265,7 @@ def main( ) MPI.Finalize() + if __name__ == "__main__": main( args=sys.argv[