Skip to content

Commit

Permalink
Rollback hardcoded types
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed Apr 6, 2024
1 parent 1fd8aaa commit f2402b5
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 123 deletions.
114 changes: 51 additions & 63 deletions src/miv_simulator/config.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import copy
from pydantic import (
BaseModel as _BaseModel,
Field,
conlist,
AfterValidator,
BeforeValidator,
)
from typing import Literal, Dict, Any, List, Tuple, Optional, Union, Callable
from typing import Literal, Dict, List, Tuple, Optional, Union, Callable
from enum import IntEnum
from collections import defaultdict
import numpy as np
Expand Down Expand Up @@ -41,64 +38,12 @@ class SynapseTypesDef(IntEnum):

SynapseTypesLiteral = Literal["excitatory", "inhibitory", "modulatory"]

# Synapses

class SynapseMechanismsDef(IntEnum):
AMPA = 0
GABA_A = 1
GABA_B = 2
NMDA = 30


SynapseMechanismsLiteral = Literal["AMPA", "GABA_A", "GABA_B", "NMDA"]


class LayersDef(IntEnum):
default = -1
Hilus = 0
GCL = 1 # Granule cell
IML = 2 # Inner molecular
MML = 3 # Middle molecular
OML = 4 # Outer molecular
SO = 5 # Oriens
SP = 6 # Pyramidale
SL = 7 # Lucidum
SR = 8 # Radiatum
SLM = 9 # Lacunosum-moleculare


class InputSelectivityTypesDef(IntEnum):
random = 0
constant = 1


def AllowStringsFrom(enum):
"""For convenience, allows users to specify enum values using their string name"""

def _cast(v) -> int:
if isinstance(v, str):
try:
return enum.__members__[v]
except KeyError:
raise ValueError(
f"'{v}'. Must be one of {tuple(enum.__members__.keys())}"
)
return v

return BeforeValidator(_cast)

SynapseMechanismName = str

# Population

SynapseTypesDefOrStr = Annotated[
SynapseTypesDef, AllowStringsFrom(SynapseTypesDef)
]
SWCTypesDefOrStr = Annotated[SWCTypesDef, AllowStringsFrom(SWCTypesDef)]
SynapseMechanismsDefOrStr = Annotated[
SynapseMechanismsDef, AllowStringsFrom(SynapseMechanismsDef)
]
LayersDefOrStr = Annotated[LayersDef, AllowStringsFrom(LayersDef)]


PopulationName = str
PostSynapticPopulationName = PopulationName
PreSynapticPopulationName = PopulationName
Expand Down Expand Up @@ -182,13 +127,29 @@ class Mechanism(BaseModel):


class Synapse(BaseModel):
type: SynapseTypesDefOrStr
sections: conlist(SWCTypesDefOrStr)
layers: conlist(LayersDefOrStr)
proportions: conlist(float)
mechanisms: Dict[SynapseMechanismsLiteral, Mechanism]
type: SynapseTypesLiteral
sections: List[SWCTypesLiteral]
layers: List[LayerName]
proportions: list[float]
mechanisms: Dict[SynapseMechanismName, Mechanism]
contacts: int = 1

def to_config(self, layer_definitions: Dict[LayerName, int]):
return type(
"SynapseConfig",
(),
{
"type": SynapseTypesDef.__members__[self.type],
"sections": list(
map(SWCTypesDef.__members__.get, self.sections)
),
"layers": list(map(layer_definitions.get, self.layers)),
"proportions": self.proportions,
"mechanisms": self.mechanisms,
"contacts": self.contacts,
},
)


def _origin_value_to_callable(value: Union[str, float]) -> Callable:
if isinstance(value, float):
Expand Down Expand Up @@ -283,9 +244,19 @@ def probabilities_sum_to_one(x):
sentinel = object()


class Definitions(BaseModel):
swc_types: Dict[str, int]
synapse_types: Dict[str, int]
synapse_mechanisms: Dict[str, int]
layers: Dict[str, int]
populations: Dict[str, int]
input_selectivity_types: Dict[str, int]


class Config:
def __init__(self, data: Dict) -> None:
self._data = copy.deepcopy(data)
self._definitions = None

# compatibility
self.get("Cell Types.STIM", {}).setdefault("synapses", {})
Expand Down Expand Up @@ -356,3 +327,20 @@ def cell_types(self) -> CellTypes:
@property
def clamp(self) -> Optional[Dict]:
return self.get("Network Clamp", None)

@property
def definitions(self) -> Definitions:
if self._definitions is None:
self._definitions = Definitions(
swc_types=self.get("Definitions.SWC Types", {}),
synapse_types=self.get("Definitions.Synapse Types", {}),
synapse_mechanisms=self.get(
"Definitions.Synapse Mechanisms", {}
),
layers=self.get("Definitions.Layers", {}),
populations=self.get("Definitions.Populations", {}),
input_selectivity_types=self.get(
"Definitions.Input Selectivity Types", {}
),
)
return self._definitions
13 changes: 12 additions & 1 deletion src/miv_simulator/interface/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class Config(BaseModel):
forest_filepath: str = Field("???")
axon_extents: config.AxonExtents = Field("???")
synapses: config.Synapses = Field("???")
population_definitions: Dict[str, int] = Field("???")
layer_definitions: Dict[str, int] = Field("???")
include_forest_populations: Optional[list] = None
connectivity_namespace: str = "Connections"
coordinates_namespace: str = "Coordinates"
Expand Down Expand Up @@ -45,13 +47,22 @@ def __call__(self):
filepath=self.config.filepath,
forest_filepath=self.config.forest_filepath,
include_forest_populations=self.config.include_forest_populations,
synapses=self.config.synapses,
synapses={
post: {
pre: config.Synapse(**syn).to_config(
self.config.layer_definitions
)
for pre, syn in v.items()
}
for post, v in self.config.synapses.items()
},
axon_extents=self.config.axon_extents,
output_filepath=self.output_filepath,
connectivity_namespace=self.config.connectivity_namespace,
coordinates_namespace=self.config.coordinates_namespace,
synapses_namespace=self.config.synapses_namespace,
distances_namespace=self.config.distances_namespace,
populations_dict=self.config.population_definitions,
io_size=self.config.io_size,
chunk_size=self.config.chunk_size,
value_chunk_size=self.config.value_chunk_size,
Expand Down
6 changes: 4 additions & 2 deletions src/miv_simulator/interface/h5_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class Config(BaseModel):

cell_distributions: config.CellDistributions = Field("???")
projections: config.SynapticProjections = Field("???")
mpi_args: str = "-n 1"
population_definitions: Dict[str, int] = Field("???")
ranks: int = 1
nodes: str = "1"

def config_from_file(self, filename: str) -> Dict:
Expand All @@ -31,11 +32,12 @@ def __call__(self) -> None:
post: {pre: True for pre in v}
for post, v in self.config.projections.items()
},
population_definitions=self.config.population_definitions,
)
MPI.COMM_WORLD.barrier()

def compute_context(self):
context = super().compute_context()
del context["config"]["mpi_args"]
del context["config"]["ranks"]
del context["config"]["nodes"]
return context
4 changes: 4 additions & 0 deletions src/miv_simulator/interface/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def launch(self):
{
"projections": config.projections,
"cell_distributions": config.cell_distributions,
"population_definitions": config.definitions.populations,
},
],
)
Expand Down Expand Up @@ -67,6 +68,7 @@ def launch(self):
].output_filepath,
"cell_types": config.cell_types,
"population": population,
"layer_definitions": config.definitions.layers,
"distribution": "poisson",
"mechanisms_path": self.config.mechanisms_path,
"template_path": self.config.template_path,
Expand All @@ -86,6 +88,8 @@ def launch(self):
population
].output_filepath,
"axon_extents": config.axon_extents,
"population_definitions": config.definitions.populations,
"layer_definitions": config.definitions.layers,
"io_size": 1,
"cache_size": 20,
"write_size": 100,
Expand Down
4 changes: 2 additions & 2 deletions src/miv_simulator/interface/neuroh5_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class NeuroH5Graph(Component):
class Config:
mpi_args: str = "-n 1"
ranks: int = 1
nodes: str = "1"

def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -102,7 +102,7 @@ def files(self) -> Dict[str, str]:

def compute_context(self):
context = super().compute_context()
del context["config"]["mpi_args"]
del context["config"]["ranks"]
del context["config"]["nodes"]
context["predicate"]["uses"] = sorted([u.hash for u in self.uses])
return context
5 changes: 3 additions & 2 deletions src/miv_simulator/interface/synapse_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ class Config(BaseModel):
filepath: str = Field("???")
population: config.PopulationName = Field("???")
morphology: config.SWCFilePath = Field("???")
mpi_args: Optional[str] = "-n 1"
# ranks: int = 1
mpi: Optional[str] = None
nodes: str = "1"

@property
Expand Down Expand Up @@ -62,7 +63,7 @@ def dispatch_code(

def compute_context(self):
context = super().compute_context()
del context["config"]["mpi_args"]
del context["config"]["mpi"]
del context["config"]["nodes"]
del context["config"]["filepath"]
context["config"]["morphology"] = file_hash(
Expand Down
4 changes: 4 additions & 0 deletions src/miv_simulator/interface/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class Config(BaseModel):
forest_filepath: str = Field("???")
cell_types: config.CellTypes = Field("???")
population: str = Field("???")
layer_definitions: Dict[str, int] = Field("???")
distribution: str = "uniform"
mechanisms_path: str = "./mechanisms/compiled"
template_path: str = "./templates"
Expand Down Expand Up @@ -51,6 +52,9 @@ def __call__(self):
simulator.distribute_synapses(
forest_filepath=self.config.forest_filepath,
cell_types=self.config.cell_types,
swc_defs=config.SWCTypesDef.__members__,
synapse_defs=config.SynapseTypesDef.__members__,
layer_defs=self.config.layer_definitions,
populations=(self.config.population,),
distribution=self.config.distribution,
template_path=self.config.template_path,
Expand Down
14 changes: 8 additions & 6 deletions src/miv_simulator/simulator/distribute_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def distribute_synapse_locations(
forest_filepath=forest_path,
cell_types=env.celltypes,
swc_defs=env.SWC_Types,
synapse_defs=env.Synapse_Types,
layer_defs=env.layers,
populations=populations,
distribution=distribution,
Expand All @@ -251,6 +252,7 @@ def distribute_synapses(
forest_filepath: str,
cell_types: config.CellTypes,
swc_defs: Dict[str, int],
synapse_defs: Dict[str, int],
layer_defs: Dict[str, int],
populations: Tuple[str, ...],
distribution: Literal["uniform", "poisson"],
Expand Down Expand Up @@ -363,9 +365,9 @@ def distribute_synapses(
seg_density_per_sec,
) = synapses.distribute_uniform_synapses(
random_seed,
config.SynapseTypesDef.__members__,
config.SWCTypesDef.__members__,
config.LayersDef.__members__,
synapse_defs,
swc_defs,
layer_defs,
density_dict,
morph_dict,
cell_sec_dict,
Expand All @@ -378,9 +380,9 @@ def distribute_synapses(
seg_density_per_sec,
) = synapses.distribute_poisson_synapses(
random_seed,
config.SynapseTypesDef.__members__,
config.SWCTypesDef.__members__,
config.LayersDef.__members__,
synapse_defs,
swc_defs,
layer_defs,
density_dict,
morph_dict,
cell_sec_dict,
Expand Down
2 changes: 1 addition & 1 deletion src/miv_simulator/simulator/generate_connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
read_population_names,
read_population_ranges,
)
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Dict

sys_excepthook = sys.excepthook

Expand Down
36 changes: 19 additions & 17 deletions src/miv_simulator/simulator/generate_synapse_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,25 @@ def _bin_check(bin: str) -> None:
raise FileNotFoundError(f"{bin} not found. Did you add it to the PATH?")


def _sh(cmd, spawn_process=True):
if not spawn_process:
return os.system(" ".join([shlex.quote(c) for c in cmd]))

try:
subprocess.check_output(
cmd,
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as e:
error_message = e.output.decode()
print(f"{os.getcwd()}$:")
print(" ".join(cmd))
print("Error:", error_message)
raise subprocess.CalledProcessError(
e.returncode, e.cmd, output=error_message
)
def _sh(cmd, spawn_process=False):
if spawn_process:
try:
subprocess.check_output(
cmd,
stderr=subprocess.STDOUT,
)
except subprocess.CalledProcessError as e:
error_message = e.output.decode()
print(f"{os.getcwd()}$:")
print(" ".join(cmd))
print("Error:", error_message)
raise subprocess.CalledProcessError(
e.returncode, e.cmd, output=error_message
)
else:
cmdq = " ".join([shlex.quote(c) for c in cmd])
if os.system(cmdq) != 0:
raise RuntimeError(f"Error running {cmdq}")


def generate_synapse_forest(
Expand Down
Loading

0 comments on commit f2402b5

Please sign in to comment.