diff --git a/pyneuroml/plot/PlotMorphology.py b/pyneuroml/plot/PlotMorphology.py index c9bb7c9ae..bb214946d 100644 --- a/pyneuroml/plot/PlotMorphology.py +++ b/pyneuroml/plot/PlotMorphology.py @@ -75,10 +75,17 @@ def process_args(): help="Plane to plot on for 2D plot", ) + parser.add_argument( + "-pointFraction", + type=str, + metavar="", + default=DEFAULTS["pointFraction"], + help="Fraction of network to plot as point cells", + ) parser.add_argument( "-plotType", type=str, - metavar="", + metavar="", default=DEFAULTS["plotType"], help="Level of detail to plot in", ) @@ -147,6 +154,7 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str): verbose=a.v, plot_type=a.plot_type, theme=a.theme, + plot_spec={"point_fraction": a.point_fraction}, ) else: plot_2D( @@ -158,6 +166,7 @@ def plot_from_console(a: typing.Optional[typing.Any] = None, **kwargs: str): a.save_to_file, a.square, a.plot_type, + plot_spec={"point_fraction": a.point_fraction}, ) @@ -172,6 +181,9 @@ def plot_2D( plot_type: str = "detailed", title: typing.Optional[str] = None, close_plot: bool = False, + plot_spec: typing.Optional[ + typing.Dict[str, typing.Union[str, typing.List[int], float]] + ] = None, ): """Plot cells in a 2D plane. @@ -205,6 +217,7 @@ def plot_2D( - "constant": show morphology, but use constant line widths - "schematic": only plot each unbranched segment group as a straight line, not following each segment + - "point": show all cells as points This is only applicable for neuroml.Cell cells (ones with some morphology) @@ -214,20 +227,36 @@ def plot_2D( :type title: str :param close_plot: call pyplot.close() to close plot after plotting :type close_plot: bool + :param plot_spec: dictionary that allows passing some specifications that + control how a plot is generated. This is mostly useful for large + network plots where one may want to have a mix of full morphology and + schematic, and point representations of cells. Possible keys are: + + - point_fraction: what fraction of each population to plot as point cells: + these cells will be randomly selected + - points_cells: list of cell ids to plot as point cells + - schematic_cells: list of cell ids to plot as schematics + - constant_cells: list of cell ids to plot as constant widths + + The last three lists override the point_fraction setting. If a cell id + is not included in the spec here, it will follow the plot_type provided + before. """ - if plot_type not in ["detailed", "constant", "schematic"]: + if plot_type not in ["detailed", "constant", "schematic", "point"]: raise ValueError( - "plot_type must be one of 'detailed', 'constant', or 'schematic'" + "plot_type must be one of 'detailed', 'constant', 'schematic', 'point'" ) if verbose: print("Plotting %s" % nml_file) - if type(nml_file) == str: + # do not recursive read the file, the extract_position_info function will + # do that for us, from a copy of the model + if type(nml_file) is str: nml_model = read_neuroml2_file( nml_file, - include_includes=True, + include_includes=False, check_validity_pre_include=False, verbose=False, optimized=True, @@ -250,7 +279,9 @@ def plot_2D( positions, pop_id_vs_color, pop_id_vs_radii, - ) = extract_position_info(nml_model, verbose) + ) = extract_position_info( + nml_model, verbose, nml_file if type(nml_file) is str else "" + ) if title is None: if len(nml_model.networks) > 0: @@ -268,12 +299,45 @@ def plot_2D( fig, ax = get_new_matplotlib_morph_plot(title, plane2d) axis_min_max = [float("inf"), -1 * float("inf")] - for pop_id in pop_id_vs_cell: - cell = pop_id_vs_cell[pop_id] - pos_pop = positions[pop_id] + # process plot_spec + point_cells = [] # type: typing.List[int] + schematic_cells = [] # type: typing.List[int] + constant_cells = [] # type: typing.List[int] + detailed_cells = [] # type: typing.List[int] + if plot_spec is not None: + try: + point_cells = plot_spec["point_cells"] + except KeyError: + pass + try: + schematic_cells = plot_spec["schematic_cells"] + except KeyError: + pass + try: + constant_cells = plot_spec["constant_cells"] + except KeyError: + pass + try: + detailed_cells = plot_spec["detailed_cells"] + except KeyError: + pass + + for pop_id, cell in pop_id_vs_cell.items(): + pos_pop = positions[pop_id] # type: typing.Dict[typing.Any, typing.List[float]] + + # reinit point_cells for each loop + point_cells_pop = [] + if len(point_cells) == 0 and plot_spec is not None: + cell_indices = list(pos_pop.keys()) + try: + point_cells_pop = random.sample( + cell_indices, + int(len(cell_indices) * float(plot_spec["point_fraction"])), + ) + except KeyError: + pass - for cell_index in pos_pop: - pos = pos_pop[cell_index] + for cell_index, pos in pos_pop.items(): radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None @@ -291,12 +355,36 @@ def plot_2D( nogui=True, ) else: - if plot_type == "schematic": + if ( + plot_type == "point" + or cell_index in point_cells_pop + or cell.id in point_cells + ): + # assume that soma is 0, plot point at where soma should be + soma_x_y_z = cell.get_actual_proximal(0) + pos1 = [ + pos[0] + soma_x_y_z.x, + pos[1] + soma_x_y_z.y, + pos[2] + soma_x_y_z.z, + ] + plot_2D_point_cells( + offset=pos1, + plane2d=plane2d, + color=color, + soma_radius=radius, + verbose=verbose, + ax=ax, + fig=fig, + autoscale=False, + scalebar=False, + nogui=True, + ) + elif plot_type == "schematic" or cell.id in schematic_cells: plot_2D_schematic( offset=pos, cell=cell, segment_groups=None, - labels=True, + labels=False, plane2d=plane2d, verbose=verbose, fig=fig, @@ -306,7 +394,12 @@ def plot_2D( autoscale=False, square=False, ) - else: + elif ( + plot_type == "detailed" + or cell.id in detailed_cells + or plot_type == "constant" + or cell.id in constant_cells + ): plot_2D_cell_morphology( offset=pos, cell=cell, diff --git a/pyneuroml/plot/PlotMorphologyVispy.py b/pyneuroml/plot/PlotMorphologyVispy.py index 82bf931c4..af5eca998 100644 --- a/pyneuroml/plot/PlotMorphologyVispy.py +++ b/pyneuroml/plot/PlotMorphologyVispy.py @@ -11,22 +11,17 @@ import logging -import typing -import numpy +import random import textwrap -from vispy import scene, app +import typing -from pyneuroml.utils.plot import ( - DEFAULTS, - get_cell_bound_box, - get_next_hex_color, -) +import numpy +from neuroml import Cell, NeuroMLDocument, Segment, SegmentGroup +from neuroml.neuro_lex_ids import neuro_lex_ids from pyneuroml.pynml import read_neuroml2_file from pyneuroml.utils import extract_position_info - -from neuroml import Cell, NeuroMLDocument, SegmentGroup, Segment -from neuroml.neuro_lex_ids import neuro_lex_ids - +from pyneuroml.utils.plot import DEFAULTS, get_cell_bound_box, get_next_hex_color +from vispy import app, scene logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -296,6 +291,9 @@ def plot_interactive_3D( title: typing.Optional[str] = None, theme: str = "light", nogui: bool = False, + plot_spec: typing.Optional[ + typing.Dict[str, typing.Union[str, typing.List[int], float]] + ] = None, ): """Plot interactive plots in 3D using Vispy @@ -316,6 +314,7 @@ def plot_interactive_3D( - "constant": show morphology, but use constant line widths - "schematic": only plot each unbranched segment group as a straight line, not following each segment + - "point": show all cells as points This is only applicable for neuroml.Cell cells (ones with some morphology) @@ -327,19 +326,33 @@ def plot_interactive_3D( :type theme: str :param nogui: toggle showing gui (for testing only) :type nogui: bool + :param plot_spec: dictionary that allows passing some specifications that + control how a plot is generated. This is mostly useful for large + network plots where one may want to have a mix of full morphology and + schematic, and point representations of cells. Possible keys are: + + - point_fraction: what fraction of each population to plot as point cells: + these cells will be randomly selected + - points_cells: list of cell ids to plot as point cells + - schematic_cells: list of cell ids to plot as schematics + - constant_cells: list of cell ids to plot as constant widths + + The last three lists override the point_fraction setting. If a cell id + is not included in the spec here, it will follow the plot_type provided + before. """ - if plot_type not in ["detailed", "constant", "schematic"]: + if plot_type not in ["detailed", "constant", "schematic", "point"]: raise ValueError( - "plot_type must be one of 'detailed', 'constant', or 'schematic'" + "plot_type must be one of 'detailed', 'constant', 'schematic', 'point'" ) if verbose: print(f"Plotting {nml_file}") - if type(nml_file) == str: + if type(nml_file) is str: nml_model = read_neuroml2_file( nml_file, - include_includes=True, + include_includes=False, check_validity_pre_include=False, verbose=False, optimized=True, @@ -360,7 +373,9 @@ def plot_interactive_3D( positions, pop_id_vs_color, pop_id_vs_radii, - ) = extract_position_info(nml_model, verbose) + ) = extract_position_info( + nml_model, verbose, nml_file if type(nml_file) is str else "" + ) # Collect all markers and only plot one markers object # this is more efficient than multiple markers, one for each point. @@ -429,12 +444,45 @@ def plot_interactive_3D( logger.debug(f"figure extents are: {view_min}, {view_max}") - for pop_id in pop_id_vs_cell: - cell = pop_id_vs_cell[pop_id] - pos_pop = positions[pop_id] + # process plot_spec + point_cells = [] # type: typing.List[int] + schematic_cells = [] # type: typing.List[int] + constant_cells = [] # type: typing.List[int] + detailed_cells = [] # type: typing.List[int] + if plot_spec is not None: + try: + point_cells = plot_spec["point_cells"] + except KeyError: + pass + try: + schematic_cells = plot_spec["schematic_cells"] + except KeyError: + pass + try: + constant_cells = plot_spec["constant_cells"] + except KeyError: + pass + try: + detailed_cells = plot_spec["detailed_cells"] + except KeyError: + pass + + for pop_id, cell in pop_id_vs_cell.items(): + pos_pop = positions[pop_id] # type: typing.Dict[typing.Any, typing.List[float]] - for cell_index in pos_pop: - pos = pos_pop[cell_index] + # reinit point_cells for each loop + point_cells_pop = [] + if len(point_cells) == 0 and plot_spec is not None: + cell_indices = list(pos_pop.keys()) + try: + point_cells_pop = random.sample( + cell_indices, + int(len(cell_indices) * float(plot_spec["point_fraction"])), + ) + except KeyError: + pass + + for cell_index, pos in pos_pop.items(): radius = pop_id_vs_radii[pop_id] if pop_id in pop_id_vs_radii else 10 color = pop_id_vs_color[pop_id] if pop_id in pop_id_vs_color else None @@ -448,7 +496,24 @@ def plot_interactive_3D( marker_sizes.extend([radius]) marker_colors.extend([color]) else: - if plot_type == "schematic": + if ( + plot_type == "point" + or cell_index in point_cells_pop + or cell.id in point_cells + ): + # assume that soma is 0, plot point at where soma should be + soma_x_y_z = cell.get_actual_proximal(0) + pos1 = [ + pos[0] + soma_x_y_z.x, + pos[1] + soma_x_y_z.y, + pos[2] + soma_x_y_z.z, + ] + marker_points.extend([pos1]) + # larger than the default soma width, which would be too + # small + marker_sizes.extend([25]) + marker_colors.extend([color]) + elif plot_type == "schematic" or cell.id in schematic_cells: plot_3D_schematic( offset=pos, cell=cell, @@ -459,7 +524,12 @@ def plot_interactive_3D( current_view=current_view, nogui=True, ) - else: + elif ( + plot_type == "detailed" + or cell.id in detailed_cells + or plot_type == "constant" + or cell.id in constant_cells + ): pts, sizes, colors = plot_3D_cell_morphology( offset=pos, cell=cell, diff --git a/pyneuroml/utils/__init__.py b/pyneuroml/utils/__init__.py index aff6861f9..f266da0c6 100644 --- a/pyneuroml/utils/__init__.py +++ b/pyneuroml/utils/__init__.py @@ -5,6 +5,7 @@ Copyright 2023 NeuroML Contributors """ + import copy import datetime import logging @@ -30,7 +31,7 @@ def extract_position_info( - nml_model: neuroml.NeuroMLDocument, verbose: bool = False + nml_model: neuroml.NeuroMLDocument, verbose: bool = False, nml_file_path: str = "" ) -> tuple: """Extract position information from a NeuroML model @@ -46,17 +47,29 @@ def extract_position_info( :type nml_model: NeuroMLDocument :param verbose: toggle function verbosity :type verbose: bool + :param nml_file_path: path of file corresponding to the model + :type nml_file_path: str :returns: [cell id vs cell dict, pop id vs cell dict, positions dict, pop id vs colour dict, pop id vs radii dict] :rtype: tuple of dicts """ + base_path = os.path.dirname(os.path.realpath(nml_file_path)) + # create a copy of the original model that we can manipulate as required nml_model_copy = copy.deepcopy(nml_model) - # add any included cells to the main document - for inc in nml_model_copy.includes: - inc = read_neuroml2_file(inc.href) - for acell in inc.cells: - nml_model_copy.add(acell) + # remove bits of the model we don't need + model_members = list(vars(nml_model_copy).keys()) + required_members = [ + "id", + "cells", + "cell2_ca_poolses", + "networks", + "populations", + "includes", + ] + for m in model_members: + if m not in required_members: + setattr(nml_model_copy, m, None) cell_id_vs_cell = {} positions = {} @@ -65,16 +78,47 @@ def extract_position_info( pop_id_vs_radii = {} cell_elements = [] - cell_elements.extend(nml_model_copy.cells) - cell_elements.extend(nml_model_copy.cell2_ca_poolses) - - for cell in cell_elements: - cell_id_vs_cell[cell.id] = cell + popElements = [] + # if the model contains a network, use it if len(nml_model_copy.networks) > 0: - popElements = nml_model_copy.networks[0].populations + # remove network members we don't need + network_members = list(vars(nml_model_copy.networks[0]).keys()) + for m in network_members: + if m != "populations": + setattr(nml_model_copy.networks[0], m, None) + + # get a list of what cell types are used in the various populations + required_cell_types = [ + pop.component for pop in nml_model_copy.networks[0].populations + ] + + # add only required cells that are included in populations to the + # document + for inc in nml_model_copy.includes: + incl_loc = os.path.abspath(os.path.join(base_path, inc.href)) + inc = read_neuroml2_file(incl_loc) + for acell in inc.cells: + if acell.id in required_cell_types: + acell.biophysical_properties = None + nml_model_copy.add(acell) + + cell_elements.extend(nml_model_copy.cells) + cell_elements.extend(nml_model_copy.cell2_ca_poolses) + # if the model does not include a network, plot all the cells in the + # model in new dummy populations else: - popElements = [] + # add any included cells to the main document + for inc in nml_model_copy.includes: + incl_loc = os.path.abspath(os.path.join(base_path, inc.href)) + inc = read_neuroml2_file(incl_loc) + for acell in inc.cells: + acell.biophysical_properties = None + nml_model_copy.add(acell) + + cell_elements.extend(nml_model_copy.cells) + cell_elements.extend(nml_model_copy.cell2_ca_poolses) + net = neuroml.Network(id="x") nml_model_copy.networks.append(net) cell_str = "" @@ -86,17 +130,20 @@ def extract_position_info( cell_str += cell.id + "__" net.id = cell_str[:-2] - popElements = nml_model_copy.networks[0].populations + popElements = nml_model_copy.networks[0].populations + + for cell in cell_elements: + cell_id_vs_cell[cell.id] = cell for pop in popElements: name = pop.id celltype = pop.component instances = pop.instances - if pop.component in cell_id_vs_cell.keys(): - pop_id_vs_cell[pop.id] = cell_id_vs_cell[pop.component] + if celltype in cell_id_vs_cell.keys(): + pop_id_vs_cell[name] = cell_id_vs_cell[celltype] else: - pop_id_vs_cell[pop.id] = None + pop_id_vs_cell[name] = None info = "Population: %s has %i positioned cells of type: %s" % ( name, diff --git a/pyneuroml/utils/plot.py b/pyneuroml/utils/plot.py index c37b1f605..896bf129a 100644 --- a/pyneuroml/utils/plot.py +++ b/pyneuroml/utils/plot.py @@ -34,6 +34,7 @@ "square": False, "plotType": "detailed", "theme": "light", + "pointFraction": 0, } # type: dict[str, typing.Any] diff --git a/setup.cfg b/setup.cfg index 798b5c56e..3ec1df270 100644 --- a/setup.cfg +++ b/setup.cfg @@ -126,4 +126,4 @@ doc = pydata-sphinx-theme [flake8] -ignore = E501, E502, F403, F405 +extend-ignore = E501, E502, F403, F405, W503, W504 diff --git a/tests/plot/test_morphology_plot.py b/tests/plot/test_morphology_plot.py index 2e24b0b5b..4120c5d61 100644 --- a/tests/plot/test_morphology_plot.py +++ b/tests/plot/test_morphology_plot.py @@ -115,14 +115,60 @@ def test_2d_morphology_plotter_data_overlay(self): self.assertIsFile(filename) pl.Path(filename).unlink() + def test_2d_plotter_network_with_spec(self): + """Test plot_2D function with a network of a few cells with specs.""" + nml_file = "tests/plot/L23-example/TestNetwork.net.nml" + ofile = pl.Path(nml_file).name + # percentage + for plane in ["xy", "yz", "xz"]: + filename = f"test_morphology_plot_2d_spec_{ofile.replace('.', '_', 100)}_{plane}.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D( + nml_file, + nogui=True, + plane2d=plane, + save_to_file=filename, + plot_spec={"point_fraction": 0.5}, + ) + + self.assertIsFile(filename) + pl.Path(filename).unlink() + + # more detailed plot_spec + for plane in ["xy", "yz", "xz"]: + filename = f"test_morphology_plot_2d_spec_{ofile.replace('.', '_', 100)}_{plane}.png" + # remove the file first + try: + pl.Path(filename).unlink() + except FileNotFoundError: + pass + + plot_2D( + nml_file, + nogui=True, + plane2d=plane, + save_to_file=filename, + plot_spec={ + "point_cells": ["HL23VIP"], + "detailed_cells": ["HL23PYR"], + "schematic_cells": ["HL23PV"], + "constant_cells": ["HL23SST"], + }, + ) + self.assertIsFile(filename) + pl.Path(filename).unlink() + def test_2d_plotter_network(self): """Test plot_2D function with a network of a few cells.""" nml_file = "tests/plot/L23-example/TestNetwork.net.nml" ofile = pl.Path(nml_file).name for plane in ["xy", "yz", "xz"]: - filename = ( - f"tests/plot/test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" - ) + filename = f"tests/plot/test_morphology_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" # remove the file first try: pl.Path(filename).unlink() @@ -132,7 +178,7 @@ def test_2d_plotter_network(self): plot_2D(nml_file, nogui=True, plane2d=plane, save_to_file=filename) self.assertIsFile(filename) - pl.Path(filename).unlink() + # pl.Path(filename).unlink() def test_2d_constant_plotter_network(self): """Test plot_2D_schematic function with a network of a few cells.""" @@ -196,7 +242,36 @@ def test_3d_schematic_plotter(self): def test_3d_morphology_plotter_vispy_network(self): """Test plot_3D_cell_morphology_vispy function.""" nml_file = "tests/plot/L23-example/TestNetwork.net.nml" - plot_interactive_3D(nml_file, min_width=1, nogui=True, theme="dark") + plot_interactive_3D(nml_file, min_width=1, nogui=False, theme="dark") + + @pytest.mark.localonly + def test_3d_morphology_plotter_vispy_network_with_spec(self): + """Test plot_3D_cell_morphology_vispy function.""" + nml_file = "tests/plot/L23-example/TestNetwork.net.nml" + plot_interactive_3D( + nml_file, + min_width=1, + nogui=True, + theme="dark", + plot_spec={"point_fraction": 0.5}, + ) + + @pytest.mark.localonly + def test_3d_morphology_plotter_vispy_network_with_spec2(self): + """Test plot_3D_cell_morphology_vispy function.""" + nml_file = "tests/plot/L23-example/TestNetwork.net.nml" + plot_interactive_3D( + nml_file, + min_width=1, + nogui=True, + theme="dark", + plot_spec={ + "point_cells": ["HL23VIP"], + "detailed_cells": ["HL23PYR"], + "schematic_cells": ["HL23PV"], + "constant_cells": ["HL23SST"], + }, + ) @pytest.mark.localonly def test_3d_plotter_vispy(self): @@ -221,7 +296,9 @@ def test_3d_plotter_plotly(self): nml_files = ["tests/plot/Cell_497232312.cell.nml", "tests/plot/test.cell.nml"] for nml_file in nml_files: ofile = pl.Path(nml_file).name - filename = f"tests/plot/test_morphology_plot_3d_{ofile.replace('.', '_', 100)}.png" + filename = ( + f"tests/plot/test_morphology_plot_3d_{ofile.replace('.', '_', 100)}.png" + ) # remove the file first try: pl.Path(filename).unlink() @@ -248,9 +325,7 @@ def test_2d_schematic_plotter(self): for plane in ["xy", "yz", "xz"]: # olm cell - filename = ( - f"tests/plot/test_schematic_plot_2d_{olm_ofile.replace('.', '_', 100)}_{plane}.png" - ) + filename = f"tests/plot/test_schematic_plot_2d_{olm_ofile.replace('.', '_', 100)}_{plane}.png" try: pl.Path(filename).unlink() except FileNotFoundError: @@ -265,9 +340,7 @@ def test_2d_schematic_plotter(self): ) # more complex cell - filename = ( - f"tests/plot/test_schematic_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" - ) + filename = f"tests/plot/test_schematic_plot_2d_{ofile.replace('.', '_', 100)}_{plane}.png" # remove the file first try: pl.Path(filename).unlink()