From c8c87589cfead011421e07eb4c99ea2e6a16c0d2 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 22:10:51 -0500 Subject: [PATCH 01/16] add prepare multi plot func --- neurodsp/plts/utils.py | 42 +++++++++++++++++++++++++++++++ neurodsp/tests/plts/test_utils.py | 27 ++++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index 553791bb..0412e658 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -3,7 +3,9 @@ from copy import deepcopy from functools import wraps from os.path import join as pjoin +from itertools import repeat, cycle +import numpy as np import matplotlib.pyplot as plt from neurodsp.plts.settings import SUPTITLE_FONTSIZE @@ -155,3 +157,43 @@ def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6, **title_kwargs) return axes + + +def prepare_multi_plot_elements(xs, ys, labels=None, colors=None): + """Prepare inputs for plotting one or more elements in a loop. + + Parameters + ---------- + xs, ys : 1d or 2d array + Plot data. + labels : str or list + Label(s) for the plot input(s). + colors : str or iterable + Color(s) to plot input(s). + + Returns + ------- + xs, ys : iterable + Plot data. + labels : iterable + Label(s) for the plot input(s). + colors : iterable + Color(s) to plot input(s). + + Notes + ----- + This function takes inputs that can reflect one or more plot elements, and + prepares the inputs to be iterable for plotting in a loop. + """ + + xs = repeat(xs) if isinstance(xs, np.ndarray) and xs.ndim == 1 else xs + ys = [ys] if isinstance(ys, np.ndarray) and ys.ndim == 1 else ys + + if labels is not None: + labels = [labels] if not isinstance(labels, list) else labels + else: + labels = repeat(labels) + + colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) + + return xs, ys, labels, colors diff --git a/neurodsp/tests/plts/test_utils.py b/neurodsp/tests/plts/test_utils.py index f727c42e..8a0b4935 100644 --- a/neurodsp/tests/plts/test_utils.py +++ b/neurodsp/tests/plts/test_utils.py @@ -1,7 +1,9 @@ """Tests for neurodsp.plts.utils.""" import os +import itertools +import numpy as np import matplotlib as mpl from neurodsp.tests.settings import TEST_PLOTS_PATH @@ -74,3 +76,28 @@ def test_make_axes(): axes = make_axes(2, 2) assert axes.shape == (2, 2) assert isinstance(axes[0, 0], mpl.axes._axes.Axes) + +def test_prepare_multi_plot_elements(): + + xs1 = np.array([1, 2, 3]) + ys1 = np.array([1, 2, 3]) + labels1 = None + colors1 = None + + # 1 input + xs1o, ys1o, labels1o, colors1o = prepare_multi_plot_elements(xs1, ys1, labels1, colors1) + assert isinstance(xs1o, itertools.repeat) + assert isinstance(ys1o, list) + assert isinstance(labels1o, itertools.repeat) + assert isinstance(colors1o, itertools.repeat) + + # multiple inputs + xs2 = [np.array([1, 2, 3]), np.array([4, 5, 6])] + ys2 = [np.array([1, 2, 3]), np.array([4, 5, 6])] + labels2 = ['A', 'B'] + colors2 = ['blue', 'red'] + xs2o, ys2o, labels2o, colors2o = prepare_multi_plot_elements(xs2, ys2, labels2, colors2) + assert isinstance(xs2o, list) + assert isinstance(ys2o, list) + assert isinstance(labels2o, list) + assert isinstance(colors2o, itertools.cycle) From b1190f58cdb1236e7ff37752ed136620571a3e10 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 22:16:54 -0500 Subject: [PATCH 02/16] use prep multi plot for spectral --- neurodsp/plts/spectral.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index c1ddfdad..cd5a5a39 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot_elements ################################################################################################### ################################################################################################### @@ -47,18 +47,8 @@ def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwarg ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 6))) - freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs - powers = [powers] if isinstance(powers, np.ndarray) and powers.ndim == 1 else powers - - if labels is not None: - labels = [labels] if not isinstance(labels, list) else labels - else: - labels = repeat(labels) - - colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) - - for freq, power, color, label in zip(freqs, powers, colors, labels): - ax.loglog(freq, power, color=color, label=label) + for freq, power, label, color in zip(*prepare_multi_plot_elements(freqs, powers, labels, colors)): + ax.loglog(freq, power, label=label, color=color) ax.set_xlabel('Frequency (Hz)') ax.set_ylabel('Power ($V^2/Hz$)') From 417f4ddb72423c67c83cb3d3c6325d4d8994a28c Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 22:38:13 -0500 Subject: [PATCH 03/16] use prep multi plot in ts plots --- neurodsp/plts/time_series.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/neurodsp/plts/time_series.py b/neurodsp/plts/time_series.py index 114fb2b5..8b07d26f 100644 --- a/neurodsp/plts/time_series.py +++ b/neurodsp/plts/time_series.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot_elements from neurodsp.utils.data import create_samples from neurodsp.utils.checks import check_param_options @@ -49,18 +49,13 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs): ax = check_ax(ax, kwargs.pop('figsize', (15, 3))) - sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs times, xlabel = _check_times(times, sigs) - if labels is not None: - labels = [labels] if not isinstance(labels, list) else labels - else: - labels = repeat(labels) + times, sigs, colors, labels = prepare_multi_plot_elements(times, sigs, colors, labels) # If not provided, default colors for up to two signals to be black & red - if not colors and len(sigs) <= 2: + if isinstance(colors, repeat) and next(colors) is None and len(sigs) <= 2: colors = ['k', 'r'] - colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) for time, sig, color, label in zip(times, sigs, colors, labels): ax.plot(time, sig, color=color, label=label) @@ -174,22 +169,19 @@ def plot_multi_time_series(times, sigs, colors=None, ax=None, **plt_kwargs): Keyword arguments for customizing the plot. """ + times, xlabel = _check_times(times, sigs) colors = 'black' if not colors else colors - colors = repeat(colors) if isinstance(colors, str) else iter(colors) ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5))) - sigs = [sigs] if (isinstance(sigs, np.ndarray) and sigs.ndim == 1) else sigs - times, xlabel = _check_times(times, sigs) + times, sigs, _, colors = prepare_multi_plot_elements(times, sigs, None, colors) step = 0.8 * np.ptp(sigs[0]) for ind, (time, sig) in enumerate(zip(times, sigs)): ax.plot(time, sig+step*ind, color=next(colors), **plt_kwargs) - ax.set(yticks=[]) - ax.set_xlabel(xlabel) - ax.set_ylabel('Channels') + ax.set(xlabel=xlabel, ylabel='Channels', yticks=[]) def _check_times(times, sigs): @@ -200,6 +192,4 @@ def _check_times(times, sigs): times = create_samples(len(sigs[0])) xlabel = 'Samples' - times = repeat(times) if (isinstance(times, np.ndarray) and times.ndim == 1) else times - return times, xlabel From 8965c75d56f6bd202d4b1f9c5c5073661daf440e Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 22:39:41 -0500 Subject: [PATCH 04/16] tweak prep multi plot name --- neurodsp/plts/spectral.py | 4 ++-- neurodsp/plts/time_series.py | 6 +++--- neurodsp/plts/utils.py | 2 +- neurodsp/tests/plts/test_utils.py | 6 +++--- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index cd5a5a39..0f9860f6 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot_elements +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot ################################################################################################### ################################################################################################### @@ -47,7 +47,7 @@ def plot_power_spectra(freqs, powers, labels=None, colors=None, ax=None, **kwarg ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 6))) - for freq, power, label, color in zip(*prepare_multi_plot_elements(freqs, powers, labels, colors)): + for freq, power, label, color in zip(*prepare_multi_plot(freqs, powers, labels, colors)): ax.loglog(freq, power, label=label, color=color) ax.set_xlabel('Frequency (Hz)') diff --git a/neurodsp/plts/time_series.py b/neurodsp/plts/time_series.py index 8b07d26f..75439e3b 100644 --- a/neurodsp/plts/time_series.py +++ b/neurodsp/plts/time_series.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot_elements +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot from neurodsp.utils.data import create_samples from neurodsp.utils.checks import check_param_options @@ -51,7 +51,7 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs): times, xlabel = _check_times(times, sigs) - times, sigs, colors, labels = prepare_multi_plot_elements(times, sigs, colors, labels) + times, sigs, colors, labels = prepare_multi_plot(times, sigs, colors, labels) # If not provided, default colors for up to two signals to be black & red if isinstance(colors, repeat) and next(colors) is None and len(sigs) <= 2: @@ -174,7 +174,7 @@ def plot_multi_time_series(times, sigs, colors=None, ax=None, **plt_kwargs): ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5))) - times, sigs, _, colors = prepare_multi_plot_elements(times, sigs, None, colors) + times, sigs, _, colors = prepare_multi_plot(times, sigs, None, colors) step = 0.8 * np.ptp(sigs[0]) diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index 0412e658..c0a087ed 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -159,7 +159,7 @@ def make_axes(n_rows, n_cols, figsize=None, row_size=4, col_size=3.6, return axes -def prepare_multi_plot_elements(xs, ys, labels=None, colors=None): +def prepare_multi_plot(xs, ys, labels=None, colors=None): """Prepare inputs for plotting one or more elements in a loop. Parameters diff --git a/neurodsp/tests/plts/test_utils.py b/neurodsp/tests/plts/test_utils.py index 8a0b4935..346d593e 100644 --- a/neurodsp/tests/plts/test_utils.py +++ b/neurodsp/tests/plts/test_utils.py @@ -77,7 +77,7 @@ def test_make_axes(): assert axes.shape == (2, 2) assert isinstance(axes[0, 0], mpl.axes._axes.Axes) -def test_prepare_multi_plot_elements(): +def test_prepare_multi_plot(): xs1 = np.array([1, 2, 3]) ys1 = np.array([1, 2, 3]) @@ -85,7 +85,7 @@ def test_prepare_multi_plot_elements(): colors1 = None # 1 input - xs1o, ys1o, labels1o, colors1o = prepare_multi_plot_elements(xs1, ys1, labels1, colors1) + xs1o, ys1o, labels1o, colors1o = prepare_multi_plot(xs1, ys1, labels1, colors1) assert isinstance(xs1o, itertools.repeat) assert isinstance(ys1o, list) assert isinstance(labels1o, itertools.repeat) @@ -96,7 +96,7 @@ def test_prepare_multi_plot_elements(): ys2 = [np.array([1, 2, 3]), np.array([4, 5, 6])] labels2 = ['A', 'B'] colors2 = ['blue', 'red'] - xs2o, ys2o, labels2o, colors2o = prepare_multi_plot_elements(xs2, ys2, labels2, colors2) + xs2o, ys2o, labels2o, colors2o = prepare_multi_plot(xs2, ys2, labels2, colors2) assert isinstance(xs2o, list) assert isinstance(ys2o, list) assert isinstance(labels2o, list) From 63048b611930685f3eb51e3212550eff3f9d1b8f Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 23:28:11 -0500 Subject: [PATCH 05/16] add plot_autocorr func --- neurodsp/plts/aperiodic.py | 35 +++++++++++++++++++++++++++ neurodsp/tests/plts/test_aperiodic.py | 25 +++++++++++++++++++ 2 files changed, 60 insertions(+) create mode 100644 neurodsp/plts/aperiodic.py create mode 100644 neurodsp/tests/plts/test_aperiodic.py diff --git a/neurodsp/plts/aperiodic.py b/neurodsp/plts/aperiodic.py new file mode 100644 index 00000000..5bc602fa --- /dev/null +++ b/neurodsp/plts/aperiodic.py @@ -0,0 +1,35 @@ +"""Plotting functions for neurodsp.aperiodic.""" + +from neurodsp.plts.style import style_plot +from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot + +#################################################################################################### +#################################################################################################### + +@savefig +@style_plot +def plot_autocorr(timepoints, autocorrs, labels=None, colors=None, ax=None, **kwargs): + """Plot autocorrelation results. + + Parameters + ---------- + timepoints : 1d array + Time points, in samples, at which autocorrelations are computed. + autocorrs : array + Autocorrelation values, across time lags. + labels : str or list of str, optional + Labels for each time series. + colors : str or list of str + Colors to use to plot lines. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. + **kwargs + Keyword arguments for customizing the plot. + """ + + ax = check_ax(ax, figsize=kwargs.pop('figsize', (6, 5))) + + for time, ac, label, color in zip(*prepare_multi_plot(timepoints, autocorrs, labels, colors)): + ax.plot(time, ac, label=label, color=color) + + ax.set(xlabel='Lag (Samples)', ylabel='Autocorrelation') diff --git a/neurodsp/tests/plts/test_aperiodic.py b/neurodsp/tests/plts/test_aperiodic.py new file mode 100644 index 00000000..cfa675f7 --- /dev/null +++ b/neurodsp/tests/plts/test_aperiodic.py @@ -0,0 +1,25 @@ +"""Tests for neurodsp.plts.aperiodic.""" + +from neurodsp.aperiodic.autocorr import compute_autocorr + +from neurodsp.tests.settings import TEST_PLOTS_PATH, FS +from neurodsp.tests.tutils import plot_test + +from neurodsp.plts.aperiodic import * + +################################################################################################### +################################################################################################### + +def tests_plot_autocorr(tsig, tsig_comb): + + times1, acs1 = compute_autocorr(tsig, max_lag=150) + times2, acs2 = compute_autocorr(tsig_comb, max_lag=150) + + plot_autocorr(times1, acs1, + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_autocorr-1.png') + + plot_autocorr([times1, times2], [acs1, acs2], + labels=['first', 'second'], colors=['k', 'r'], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_autocorr-2.png') From 905b1df95b77959a40efbc6ba7d4a83577283f36 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 23:28:23 -0500 Subject: [PATCH 06/16] add plot_autocorr to API list --- doc/api.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index d34b0fe6..de1f4e31 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -465,6 +465,15 @@ Time Frequency plot_timefrequency +Aperiodic +~~~~~~~~~ + +.. currentmodule:: neurodsp.plts +.. autosummary:: + :toctree: generated/ + + plot_autocorr + Combined ~~~~~~~~ From 6fa2811b85f213d4640bbde2d860f87380baef47 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 23:28:42 -0500 Subject: [PATCH 07/16] use new plot autocorr func in tutorial --- tutorials/aperiodic/plot_Autocorr.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tutorials/aperiodic/plot_Autocorr.py b/tutorials/aperiodic/plot_Autocorr.py index 8adc6580..83b9da29 100644 --- a/tutorials/aperiodic/plot_Autocorr.py +++ b/tutorials/aperiodic/plot_Autocorr.py @@ -15,12 +15,12 @@ # sphinx_gallery_thumbnail_number = 1 import numpy as np -import matplotlib.pyplot as plt from neurodsp.sim import sim_powerlaw, sim_oscillation # Import the function for computing autocorrelation from neurodsp.aperiodic import compute_autocorr +from neurodsp.plts import plot_autocorr ################################################################################################### # Autocorrelation Measures @@ -92,9 +92,7 @@ ################################################################################################### # Plot autocorrelations -_, ax = plt.subplots(figsize=(6, 4)) -ax.plot(timepoints_osc1, autocorrs_osc1) -ax.set(xlabel='lag (samples)', ylabel='autocorrelation'); +plot_autocorr(timepoints_osc1, autocorrs_osc1) ################################################################################################### # @@ -109,11 +107,9 @@ ################################################################################################### # Plot autocorrelations for two different sinusoids -_, ax = plt.subplots(figsize=(6, 4)) -ax.plot(timepoints_osc1, autocorrs_osc1, alpha=0.75, label='10 Hz') -ax.plot(timepoints_osc2, autocorrs_osc2, alpha=0.75, label='20 Hz') -ax.set(xlabel='lag (samples)', ylabel='autocorrelation') -plt.legend(loc='upper right') +plot_autocorr([timepoints_osc1, timepoints_osc2], + [autocorrs_osc1, autocorrs_osc2], + labels=['10 Hz', '20 Hz']) ################################################################################################### # @@ -152,11 +148,9 @@ ################################################################################################### # Plot the autocorrelations of the aperiodic signals -_, ax = plt.subplots(figsize=(5, 4)) -ax.plot(timepoints_wn, autocorrs_wn, label='White Noise') -ax.plot(timepoints_pn, autocorrs_pn, label='Pink Noise') -ax.set(xlabel="lag (samples)", ylabel="autocorrelation") -plt.legend() +plot_autocorr([timepoints_wn, timepoints_pn], + [autocorrs_wn, autocorrs_pn], + labels=['White Noise', 'Pink Noise']) ################################################################################################### # @@ -165,4 +159,4 @@ # # By comparison, the pink noise signal has a pattern of decreasing autocorrelation # across increasing lags. This is characteristic of powerlaw data. -# \ No newline at end of file +# From 49765244e8706a6422b669cb03c964b93b0f6c90 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Fri, 29 Nov 2024 23:30:57 -0500 Subject: [PATCH 08/16] tweaks & lints --- neurodsp/plts/__init__.py | 1 + neurodsp/plts/spectral.py | 2 -- neurodsp/plts/time_series.py | 12 ++++++------ 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/neurodsp/plts/__init__.py b/neurodsp/plts/__init__.py index f0fb58bf..3c836be5 100644 --- a/neurodsp/plts/__init__.py +++ b/neurodsp/plts/__init__.py @@ -7,4 +7,5 @@ from .spectral import (plot_power_spectra, plot_spectral_hist, plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix) from .timefrequency import plot_timefrequency +from .aperiodic import plot_autocorr from .combined import plot_timeseries_and_spectra diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index 0f9860f6..7a87406f 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -1,7 +1,5 @@ """Plotting functions for neurodsp.spectral.""" -from itertools import repeat, cycle - import numpy as np import matplotlib.pyplot as plt diff --git a/neurodsp/plts/time_series.py b/neurodsp/plts/time_series.py index 75439e3b..18a68c1d 100644 --- a/neurodsp/plts/time_series.py +++ b/neurodsp/plts/time_series.py @@ -1,6 +1,6 @@ """Plots for time series.""" -from itertools import repeat, cycle +from itertools import repeat import numpy as np import matplotlib.pyplot as plt @@ -50,7 +50,6 @@ def plot_time_series(times, sigs, labels=None, colors=None, ax=None, **kwargs): ax = check_ax(ax, kwargs.pop('figsize', (15, 3))) times, xlabel = _check_times(times, sigs) - times, sigs, colors, labels = prepare_multi_plot(times, sigs, colors, labels) # If not provided, default colors for up to two signals to be black & red @@ -169,11 +168,11 @@ def plot_multi_time_series(times, sigs, colors=None, ax=None, **plt_kwargs): Keyword arguments for customizing the plot. """ - times, xlabel = _check_times(times, sigs) - colors = 'black' if not colors else colors - ax = check_ax(ax, figsize=plt_kwargs.pop('figsize', (15, 5))) + colors = 'black' if not colors else colors + + times, xlabel = _check_times(times, sigs) times, sigs, _, colors = prepare_multi_plot(times, sigs, None, colors) step = 0.8 * np.ptp(sigs[0]) @@ -189,7 +188,8 @@ def _check_times(times, sigs): xlabel = 'Time (s)' if times is None: - times = create_samples(len(sigs[0])) + n_samples = len(sigs[0]) if isinstance(sigs, np.ndarray) and sigs.ndim == 2 else len(sigs) + times = create_samples(n_samples) xlabel = 'Samples' return times, xlabel From da60bcffe2bd96caea77dfd0469a488436bdf5bd Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 3 Dec 2024 10:45:25 -0500 Subject: [PATCH 09/16] add plot_spectra_3d --- neurodsp/plts/spectral.py | 65 ++++++++++++++++++++++++++++ neurodsp/tests/plts/test_spectral.py | 15 +++++++ 2 files changed, 80 insertions(+) diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index 7a87406f..7ff2295d 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -223,3 +223,68 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None, if spectrum is not None: plt_inds = np.logical_and(spectrum_freqs >= freqs[0], spectrum_freqs <= freqs[-1]) ax.plot(spectrum_freqs[plt_inds], np.log10(spectrum[plt_inds]), color='w', alpha=0.8) + + +@savefig +@style_plot +def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, + colors=None, orientation=(20, -50), **kwargs): + """Plot a series of power spectra in a 3D plot. + + Parameters + ---------- + freqs : 1d or 2d array or list of 1d array + Frequency vector. + powers : 2d array or list of 1d array + Power values. + log_freqs : bool, optional, default: False + Whether to plot the frequency values in log10 space. + log_powers : bool, optional, default: True + Whether to plot the power values in log10 space. + colors : str or list of str + Colors to use to plot lines. + orientation : tuple of int + Orientation to set the 3D plot. + **kwargs + Keyword arguments for customizing the plot. + + Examples + -------- + Plot power spectra in 3D: + + >>> from neurodsp.sim import sim_combined + >>> from neurodsp.spectral import compute_spectrum + >>> sig1 = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {'exponent' : -1}, + ... 'sim_bursty_oscillation' : {'freq': 10}}) + >>> sig2 = sim_combined(n_seconds=10, fs=500, + ... components={'sim_powerlaw': {'exponent' : -1.5}, + ... 'sim_bursty_oscillation' : {'freq': 10}}) + >>> freqs1, powers1 = compute_spectrum(sig1, fs=500) + >>> freqs2, powers2 = compute_spectrum(sig2, fs=500) + >>> plot_spectra_3D([freqs1, freqs2], [powers1, powers2]) + """ + + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + n_spectra = len(powers) + + for ind, (freq, power, _, color) in \ + enumerate(zip(*prepare_multi_plot(freqs, powers, None, colors))): + ax.plot(xs=np.log10(freq) if log_freqs else freq, + ys=[ind] * len(freq), + zs=np.log10(power) if log_powers else power, + color=color, + **kwargs) + + ax.set( + xlabel='Frequency (Hz)', + ylabel='Channels', + zlabel='Power', + ylim=[0, n_spectra - 1], + ) + + yticks = list(range(n_spectra)) + ax.set_yticks(yticks, yticks) + ax.view_init(*orientation) diff --git a/neurodsp/tests/plts/test_spectral.py b/neurodsp/tests/plts/test_spectral.py index 049ce8bb..410b7f1f 100644 --- a/neurodsp/tests/plts/test_spectral.py +++ b/neurodsp/tests/plts/test_spectral.py @@ -68,3 +68,18 @@ def test_plot_spectral_hist(tsig_comb): spectrum=spectrum, spectrum_freqs=spectrum_freqs, save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectral_hist.png') + +@plot_test +def test_plot_spectra_3D(tsig_comb, tsig_burst): + + freqs1, powers1 = compute_spectrum(tsig_comb, FS) + freqs2, powers2 = compute_spectrum(tsig_burst, FS) + + plot_spectra_3D([freqs1, freqs2], [powers1, powers2], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectral3D_1.png') + + plot_spectra_3D(freqs1, [powers1, powers2, powers1, powers2], + colors=['r', 'y', 'b', 'g'], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectral3D_2.png') From a6f32341eabc98505efc9f9ade68553f9f476b11 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Tue, 3 Dec 2024 10:46:03 -0500 Subject: [PATCH 10/16] add 3d plot to API list --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index de1f4e31..d4199fc1 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -430,6 +430,7 @@ Spectral :toctree: generated/ plot_power_spectra + plot_spectra_3D plot_scv plot_scv_rs_lines plot_scv_rs_matrix From 17ab4e8691a6111c82db7bd7deebb4363640ff39 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Thu, 5 Dec 2024 17:24:48 -0500 Subject: [PATCH 11/16] fix list input to _check_times --- neurodsp/plts/time_series.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/neurodsp/plts/time_series.py b/neurodsp/plts/time_series.py index 18a68c1d..1db61142 100644 --- a/neurodsp/plts/time_series.py +++ b/neurodsp/plts/time_series.py @@ -188,7 +188,10 @@ def _check_times(times, sigs): xlabel = 'Time (s)' if times is None: - n_samples = len(sigs[0]) if isinstance(sigs, np.ndarray) and sigs.ndim == 2 else len(sigs) + if isinstance(sigs, list) or (isinstance(sigs, np.ndarray) and sigs.ndim == 2): + n_samples = len(sigs[0]) + else: + n_samples = len(sigs) times = create_samples(n_samples) xlabel = 'Samples' From 4d3f956a84cc23f683024bb59ebb09d3f73a4832 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 2 Feb 2025 15:55:04 -0500 Subject: [PATCH 12/16] review updates / fixes --- neurodsp/plts/__init__.py | 2 +- neurodsp/plts/style.py | 4 +++- neurodsp/plts/utils.py | 7 +++++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/neurodsp/plts/__init__.py b/neurodsp/plts/__init__.py index 3c836be5..4ce61552 100644 --- a/neurodsp/plts/__init__.py +++ b/neurodsp/plts/__init__.py @@ -4,7 +4,7 @@ plot_multi_time_series) from .filt import plot_filter_properties, plot_frequency_response, plot_impulse_response from .rhythm import plot_swm_pattern, plot_lagged_coherence -from .spectral import (plot_power_spectra, plot_spectral_hist, +from .spectral import (plot_power_spectra, plot_spectral_hist, plot_spectra_3D, plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix) from .timefrequency import plot_timefrequency from .aperiodic import plot_autocorr diff --git a/neurodsp/plts/style.py b/neurodsp/plts/style.py index e71506c8..ddd37bb2 100644 --- a/neurodsp/plts/style.py +++ b/neurodsp/plts/style.py @@ -113,10 +113,12 @@ def apply_custom_style(ax, **kwargs): if ax.get_title(): ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE)) - # Settings for the axis labels + # Settings for the axis labels, including checking & setting for 3D axis label_size = kwargs.pop('label_size', LABEL_SIZE) ax.xaxis.label.set_size(label_size) ax.yaxis.label.set_size(label_size) + if hasattr(ax, 'zaxis'): + ax.zaxis.label.set_size(label_size) # Settings for the axis ticks ax.tick_params(axis='both', which='major', diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index c0a087ed..4b1a38b1 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -189,11 +189,14 @@ def prepare_multi_plot(xs, ys, labels=None, colors=None): xs = repeat(xs) if isinstance(xs, np.ndarray) and xs.ndim == 1 else xs ys = [ys] if isinstance(ys, np.ndarray) and ys.ndim == 1 else ys + # Collect definition of collection items considered iterables to check against + iterables = (list, tuple, np.ndarray) + if labels is not None: - labels = [labels] if not isinstance(labels, list) else labels + labels = [labels] if not isinstance(labels, iterables) else labels else: labels = repeat(labels) - colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) + colors = repeat(colors) if not isinstance(colors, iterables) else cycle(colors) return xs, ys, labels, colors From f6882024541f6ae13fc921509ee127245a377b92 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 3 Feb 2025 23:09:44 -0500 Subject: [PATCH 13/16] add check_ax_3d --- neurodsp/plts/utils.py | 30 ++++++++++++++++++++++++++++++ neurodsp/tests/plts/test_utils.py | 18 ++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/neurodsp/plts/utils.py b/neurodsp/plts/utils.py index 4b1a38b1..aed83cf2 100644 --- a/neurodsp/plts/utils.py +++ b/neurodsp/plts/utils.py @@ -62,6 +62,36 @@ def check_ax(ax, figsize=None): return ax +def check_ax_3d(ax, figsize=None): + """Check whether a 3D figure axes object is defined, define if not. + + Parameters + ---------- + ax : matplotlib.Axes or None + Axes object to check if is defined. Must be 3D. + + Returns + ------- + ax : matplotlib.Axes + Figure axes object to use. + + Raises + ------ + ValueError + If the ax input is a defined axis, but is not 3D. + """ + + if ax and '3d' not in ax.name: + raise ValueError('Provided axis is not 3D.') + + if not ax: + + fig = plt.figure(figsize=figsize) + ax = fig.add_subplot(projection='3d') + + return ax + + def savefig(func): """Decorator function to save out figures.""" diff --git a/neurodsp/tests/plts/test_utils.py b/neurodsp/tests/plts/test_utils.py index 346d593e..527ff18d 100644 --- a/neurodsp/tests/plts/test_utils.py +++ b/neurodsp/tests/plts/test_utils.py @@ -1,5 +1,7 @@ """Tests for neurodsp.plts.utils.""" +from pytest import raises + import os import itertools @@ -42,6 +44,22 @@ def test_check_ax(): fig = plt.gcf() assert list(fig.get_size_inches()) == figsize +def test_check_ax_3d(): + + # Check running with None Input + ax = check_ax(None) + + # Check error if given a non 3D axis + with raises(ValueError): + _, ax = plt.subplots() + nax = check_ax_3d(ax) + + # Check running with pre-created axis + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + nax = check_ax(ax) + assert nax == ax + def test_savefig(): @savefig From af23327e86d2e584af15adf497490aa35289d817 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 3 Feb 2025 23:10:14 -0500 Subject: [PATCH 14/16] update plot spectra 3d --- neurodsp/plts/spectral.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index 7ff2295d..4e539b23 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt from neurodsp.plts.style import style_plot -from neurodsp.plts.utils import check_ax, savefig, prepare_multi_plot +from neurodsp.plts.utils import check_ax, check_ax_3d, savefig, prepare_multi_plot ################################################################################################### ################################################################################################### @@ -227,8 +227,8 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None, @savefig @style_plot -def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, - colors=None, orientation=(20, -50), **kwargs): +def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, colors=None, + orientation=(20, -50), zoom=1.0, ax=None, **kwargs): """Plot a series of power spectra in a 3D plot. Parameters @@ -241,10 +241,14 @@ def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, Whether to plot the frequency values in log10 space. log_powers : bool, optional, default: True Whether to plot the power values in log10 space. - colors : str or list of str + colors : str or list of str, optional Colors to use to plot lines. - orientation : tuple of int - Orientation to set the 3D plot. + orientation : tuple of int, optional, default: (20, -50) + Orientation to set the 3D plot. See `Axes3D.view_init` for more information. + zoom : float, optional, default: 1.0 + Zoom scaling for the figure axis. See `Axes3D.set_box_aspect` for more information. + ax : matplotlib.Axes, optional + Figure axes upon which to plot. Must be a 3D axis. **kwargs Keyword arguments for customizing the plot. @@ -265,8 +269,7 @@ def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, >>> plot_spectra_3D([freqs1, freqs2], [powers1, powers2]) """ - fig = plt.figure() - ax = fig.add_subplot(projection='3d') + ax = check_ax_3d(ax) n_spectra = len(powers) @@ -287,4 +290,6 @@ def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, yticks = list(range(n_spectra)) ax.set_yticks(yticks, yticks) + ax.view_init(*orientation) + ax.set_box_aspect(None, zoom=zoom) From fee975720e32b95ee96f3e6a2c867cc1432af323 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 3 Feb 2025 23:12:28 -0500 Subject: [PATCH 15/16] tweak 3d naming --- doc/api.rst | 2 +- neurodsp/plts/__init__.py | 2 +- neurodsp/plts/spectral.py | 2 +- neurodsp/tests/plts/test_spectral.py | 6 +++--- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index d4199fc1..e76e040b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -430,7 +430,7 @@ Spectral :toctree: generated/ plot_power_spectra - plot_spectra_3D + plot_spectra_3d plot_scv plot_scv_rs_lines plot_scv_rs_matrix diff --git a/neurodsp/plts/__init__.py b/neurodsp/plts/__init__.py index 4ce61552..5ce0b10d 100644 --- a/neurodsp/plts/__init__.py +++ b/neurodsp/plts/__init__.py @@ -4,7 +4,7 @@ plot_multi_time_series) from .filt import plot_filter_properties, plot_frequency_response, plot_impulse_response from .rhythm import plot_swm_pattern, plot_lagged_coherence -from .spectral import (plot_power_spectra, plot_spectral_hist, plot_spectra_3D, +from .spectral import (plot_power_spectra, plot_spectral_hist, plot_spectra_3d, plot_scv, plot_scv_rs_lines, plot_scv_rs_matrix) from .timefrequency import plot_timefrequency from .aperiodic import plot_autocorr diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index 4e539b23..069c50fc 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -227,7 +227,7 @@ def plot_spectral_hist(freqs, power_bins, spectral_hist, spectrum_freqs=None, @savefig @style_plot -def plot_spectra_3D(freqs, powers, log_freqs=False, log_powers=True, colors=None, +def plot_spectra_3d(freqs, powers, log_freqs=False, log_powers=True, colors=None, orientation=(20, -50), zoom=1.0, ax=None, **kwargs): """Plot a series of power spectra in a 3D plot. diff --git a/neurodsp/tests/plts/test_spectral.py b/neurodsp/tests/plts/test_spectral.py index 410b7f1f..fab0107c 100644 --- a/neurodsp/tests/plts/test_spectral.py +++ b/neurodsp/tests/plts/test_spectral.py @@ -70,16 +70,16 @@ def test_plot_spectral_hist(tsig_comb): file_name='test_plot_spectral_hist.png') @plot_test -def test_plot_spectra_3D(tsig_comb, tsig_burst): +def test_plot_spectra_3d(tsig_comb, tsig_burst): freqs1, powers1 = compute_spectrum(tsig_comb, FS) freqs2, powers2 = compute_spectrum(tsig_burst, FS) - plot_spectra_3D([freqs1, freqs2], [powers1, powers2], + plot_spectra_3d([freqs1, freqs2], [powers1, powers2], save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectral3D_1.png') - plot_spectra_3D(freqs1, [powers1, powers2, powers1, powers2], + plot_spectra_3d(freqs1, [powers1, powers2, powers1, powers2], colors=['r', 'y', 'b', 'g'], save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectral3D_2.png') From 5470350f106420d6f8d846c7c7b8644fddfe51de Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Mon, 3 Feb 2025 23:13:44 -0500 Subject: [PATCH 16/16] fix doctest --- neurodsp/plts/spectral.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neurodsp/plts/spectral.py b/neurodsp/plts/spectral.py index 069c50fc..02cf77db 100644 --- a/neurodsp/plts/spectral.py +++ b/neurodsp/plts/spectral.py @@ -266,7 +266,7 @@ def plot_spectra_3d(freqs, powers, log_freqs=False, log_powers=True, colors=None ... 'sim_bursty_oscillation' : {'freq': 10}}) >>> freqs1, powers1 = compute_spectrum(sig1, fs=500) >>> freqs2, powers2 = compute_spectrum(sig2, fs=500) - >>> plot_spectra_3D([freqs1, freqs2], [powers1, powers2]) + >>> plot_spectra_3d([freqs1, freqs2], [powers1, powers2]) """ ax = check_ax_3d(ax)