Skip to content

Commit

Permalink
refactor: flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
gtdang committed Oct 23, 2024
1 parent fec7b57 commit 30fa1bb
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
1 change: 0 additions & 1 deletion hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,6 @@ def _get_line_hex_colors(fig):
colors=dict_mapping)



def test_network_plotter_init(setup_net):
"""Test init keywords of NetworkPlotter class."""
net = setup_net
Expand Down
11 changes: 7 additions & 4 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,19 +559,22 @@ def plot_spikes_raster(cell_response, trial_idx=None, ax=None, show=True,
default_cell_types = cell_types

# Set default colors
default_colors = plt.rcParams['axes.prop_cycle'].by_key()['color'][:len(default_cell_types)]
cell_colors = {cell: color for cell, color in zip(default_cell_types, default_colors)}
default_colors = (plt.rcParams['axes.prop_cycle']
.by_key()['color'][:len(default_cell_types)])
cell_colors = {cell: color
for cell, color in zip(default_cell_types, default_colors)}

# validate colors argument
_validate_type(colors, (list, dict, None),'color', 'list of str, or dict')
_validate_type(colors, (list, dict, None), 'color', 'list of str, or dict')
if colors:
if isinstance(colors, list):
if len(colors) != len(default_cell_types):
raise ValueError(
f"Number of colors must be equal to number of "
f"cell types. {len(colors)} colors provided "
f"for {len(default_cell_types)} cell types.")
cell_colors = {cell: color for cell, color in zip(default_cell_types, colors)}
cell_colors = {cell: color
for cell, color in zip(default_cell_types, colors)}

if isinstance(colors, dict):
# Check valid cell types
Expand Down

0 comments on commit 30fa1bb

Please sign in to comment.