Skip to content

Commit

Permalink
Restructure simulator module to allow for direct library-like usage
Browse files Browse the repository at this point in the history
  • Loading branch information
frthjf committed Sep 22, 2023
1 parent ce9ab69 commit c82d7ce
Show file tree
Hide file tree
Showing 31 changed files with 238 additions and 247 deletions.
10 changes: 4 additions & 6 deletions docs/api/simulator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ Simulator
.. autosummary::
:toctree: _toctree

distribute_synapse_locations
generate_distance_connections
generate_input_features
generate_input_spike_trains
generate_soma_coordinates
make_h5types
distribute_synapses
generate_connections
generate_network_architecture
generate_synapse_forest
measure_distances
2 changes: 2 additions & 0 deletions src/miv_simulator/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

# Definitions

SWCFilePath = str


class SWCTypesDef(IntEnum):
soma = 1
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@

from machinable import Component
from pydantic import BaseModel, Field
from miv_simulator import config
from miv_simulator import config, simulator
from typing import Optional, Dict
from miv_simulator.simulator import distance_connections
from miv_simulator.utils import from_yaml
from mpi4py import MPI


class DistanceConnections(Component):
class Connections(Component):
class Config(BaseModel):
filepath: str = Field("???")
forest_filepath: str = Field("???")
Expand Down Expand Up @@ -44,7 +43,7 @@ def output_filepath(self):

def __call__(self):
logging.basicConfig(level=logging.INFO)
distance_connections(
simulator.generate_connections(
filepath=self.config.filepath,
forest_filepath=self.config.forest_filepath,
include_forest_populations=self.config.include_forest_populations,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from machinable import Component
from machinable.config import Field
from miv_simulator.simulator import measure_distances
from miv_simulator import config
from miv_simulator import config, simulator
from mpi4py import MPI


Expand All @@ -32,7 +31,7 @@ class Config(BaseModel):

def __call__(self):
logging.basicConfig(level=logging.INFO)
measure_distances(
simulator.measure_distances(
filepath=self.config.filepath,
geometry_filepath=self.config.geometry_filepath,
coordinate_namespace=self.config.coordinate_namespace,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from machinable import Component
from machinable.element import normversion
from machinable.types import VersionType
from pydantic import BaseModel, Field
from miv_simulator import config
from mpi4py import MPI
from miv_simulator.utils import io as io_utils, from_yaml
from typing import Dict


class CreateH5(Component):
class H5Types(Component):
class Config(BaseModel):
cell_distributions: config.CellDistributions = Field("???")
synapses: config.Synapses = Field("???")
Expand Down
2 changes: 1 addition & 1 deletion src/miv_simulator/interface/legacy/derive_spike_trains.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from machinable import Component
from pydantic import Field, BaseModel
from miv_simulator.simulator import (
from miv_simulator.input_spike_trains import (
generate_input_spike_trains,
import_input_spike_train,
)
Expand Down
2 changes: 1 addition & 1 deletion src/miv_simulator/interface/legacy/input_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic import Field, BaseModel
from machinable.element import normversion
from machinable.types import VersionType
from miv_simulator.simulator import generate_input_features
from miv_simulator.input_features import generate_input_features


class InputFeatures(Component):
Expand Down
15 changes: 6 additions & 9 deletions src/miv_simulator/interface/legacy/prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,17 @@ def __call__(self):
self.distance_connections = {}
self.synapse_forest = {}
for dependency in self.uses:
if dependency.module == "miv_simulator.interface.create_network":
if (
dependency.module
== "miv_simulator.interface.network_architecture"
):
self.network = dependency
elif (
dependency.module
== "miv_simulator.interface.legacy.derive_spike_trains"
):
self.spike_trains = dependency
elif (
dependency.module
== "miv_simulator.interface.distance_connections"
):
elif dependency.module == "miv_simulator.interface.connections":
populations = read_population_names(
dependency.config.forest_filepath
)
Expand All @@ -74,10 +74,7 @@ def __call__(self):
f"defined in {self.synapse_forest[dependency.config.population]}"
)
self.synapse_forest[dependency.config.population] = dependency
elif (
dependency.module
== "miv_simulator.interface.distribute_synapses"
):
elif dependency.module == "miv_simulator.interface.synapses":
if dependency.config.population in self.synapses:
# check for duplicates
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,14 @@
from machinable import Component
from machinable.element import normversion
from machinable.types import VersionType
from miv_simulator import config
from miv_simulator.simulator.soma_coordinates import (
generate as generate_soma_coordinates,
)
from miv_simulator import config, simulator
from miv_simulator.utils import io as io_utils, from_yaml
from mpi4py import MPI
from pydantic import BaseModel, Field


class CreateNetwork(Component):
"""Creates neural H5 type definitions and soma coordinates within specified layer geometry."""
class NetworkArchitecture(Component):
"""Creates the network architecture by generating the soma coordinates within specified layer geometry."""

class Config(BaseModel):
filepath: str = Field("???")
Expand Down Expand Up @@ -45,7 +42,7 @@ def on_write_meta_data(self):

def __call__(self) -> None:
logging.basicConfig(level=logging.INFO)
generate_soma_coordinates(
simulator.generate_network_architecture(
output_filepath=self.config.filepath,
cell_distributions=self.config.cell_distributions,
layer_extents=self.config.layer_extents,
Expand All @@ -67,7 +64,7 @@ def __call__(self) -> None:

def measure_distances(self, version: VersionType = None):
return self.derive(
"miv_simulator.interface.measure_distances",
"miv_simulator.interface.distances",
[
{
"filepath": self.config.filepath,
Expand All @@ -87,7 +84,9 @@ def measure_distances(self, version: VersionType = None):
uses=self,
)

def synapse_forest(self, version: VersionType = None) -> "Component":
def generate_synapse_forest(
self, version: VersionType = None
) -> "Component":
return self.derive(
"miv_simulator.interface.synapse_forest",
[
Expand All @@ -101,14 +100,14 @@ def synapse_forest(self, version: VersionType = None) -> "Component":

def distribute_synapses(self, version: VersionType = None):
return self.derive(
"miv_simulator.interface.distribute_synapses",
"miv_simulator.interface.synapses",
[] + normversion(version),
uses=self,
)

def distance_connections(self, version: VersionType = None):
def generate_connections(self, version: VersionType = None):
return self.derive(
"miv_simulator.interface.distance_connections",
"miv_simulator.interface.connections",
[
{
"filepath": self.config.filepath,
Expand Down
81 changes: 9 additions & 72 deletions src/miv_simulator/interface/synapse_forest.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,13 @@
import os
import shutil
import subprocess

import h5py
from machinable import Component
from miv_simulator import config
from miv_simulator import config, simulator
from pydantic import BaseModel, Field


def _bin_check(bin: str) -> None:
if not shutil.which(bin):
raise FileNotFoundError(f"{bin} not found. Did you add it to the PATH?")


SWCFilePath = str


class GenerateSynapseForest(Component):
class Config(BaseModel):
filepath: str = Field("???")
population: config.PopulationName = Field("???")
morphology: SWCFilePath = Field("???")
morphology: config.SWCFilePath = Field("???")

@property
def tree_output_filepath(self) -> str:
Expand All @@ -31,60 +18,10 @@ def output_filepath(self) -> str:
return self.local_directory("forest.h5")

def __call__(self) -> None:
# create tree
if not os.path.isfile(self.tree_output_filepath):
_bin_check("neurotrees_import")
assert (
subprocess.run(
[
"neurotrees_import",
self.config.population,
self.tree_output_filepath,
self.config.morphology,
]
).returncode
== 0
)
assert (
subprocess.run(
[
"h5copy",
"-p",
"-s",
"/H5Types",
"-d",
"/H5Types",
"-i",
self.config.filepath,
"-o",
self.tree_output_filepath,
]
).returncode
== 0
)

if not os.path.isfile(self.output_filepath):
# determine population ranges
with h5py.File(self.config.filepath, "r") as f:
idx = list(
reversed(
f["H5Types"]["Population labels"].dtype.metadata["enum"]
)
).index(self.config.population)
offset = f["H5Types"]["Populations"][idx][0]

_bin_check("neurotrees_copy")
assert (
subprocess.run(
[
"neurotrees_copy",
"--fill",
"--output",
self.output_filepath,
self.tree_output_filepath,
self.config.population,
str(offset),
]
).returncode
== 0
)
simulator.generate_synapse_forest(
filepath=self.config.filepath,
tree_output_filepath=self.tree_output_filepath,
output_filepath=self.output_filepath,
population=self.config.population,
morphology=self.config.morphology,
)
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,18 @@
from mpi4py import MPI


class DistributeSynapses(Component):
class Synapses(Component):
class Config(BaseModel):
forest_filepath: str = Field("???")
cell_types: config.CellTypes = Field("???")
population: str = Field("???")
distribution: str = "uniform"
mechanisms_path: str = "./mechanisms"
template_path: str = "./templates"
dt: float = 0.025
tstop: float = 0.0
celsius: Optional[float] = 35.0
io_size: int = -1
write_size: int = 1
chunk_size: int = 1000
value_chunk_size: int = 1000
use_coreneuron: bool = False
ranks_: int = 8
nodes_: int = 1

Expand All @@ -44,15 +40,11 @@ def __call__(self):
distribution=self.config.distribution,
mechanisms_path=self.config.mechanisms_path,
template_path=self.config.template_path,
dt=self.config.dt,
tstop=self.config.tstop,
celsius=self.config.celsius,
output_filepath=self.output_filepath,
io_size=self.config.io_size,
write_size=self.config.write_size,
chunk_size=self.config.chunk_size,
value_chunk_size=self.config.value_chunk_size,
use_coreneuron=self.config.use_coreneuron,
seed=self.seed,
dry_run=False,
)
Expand Down
28 changes: 10 additions & 18 deletions src/miv_simulator/simulator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
from miv_simulator.simulator.distribute_synapse_locations import (
distribute_synapse_locations,
distribute_synapses,
)
from miv_simulator.simulator.generate_distance_connections import (
generate_distance_connections,
distance_connections,
)
from miv_simulator.simulator.generate_input_features import (
generate_input_features,
)
from miv_simulator.simulator.generate_input_spike_trains import (
generate_input_spike_trains,
import_input_spike_train,
__doc__ = """Contains the end-user public API of the MiV-Simulator"""

from miv_simulator.utils.io import create_neural_h5
from miv_simulator.simulator.generate_network_architecture import (
generate_network_architecture,
)
from miv_simulator.simulator.soma_coordinates import generate_soma_coordinates
from miv_simulator.simulator.measure_distances import measure_distances

# !deprecated, use io_utils directly
from miv_simulator.simulator.make_h5types import make_h5types
from miv_simulator.simulator.generate_synapse_forest import (
generate_synapse_forest,
)
from miv_simulator.simulator.distribute_synapses import distribute_synapses
from miv_simulator.simulator.generate_connections import generate_connections
Loading

0 comments on commit c82d7ce

Please sign in to comment.