From 6d712ec0073dc75d10c833fca4c1c5413a673e38 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Mon, 9 Dec 2024 17:56:39 +0100 Subject: [PATCH 1/7] Add option to load a dataset in the raw data visualizer upon startup --- src/gui/rawdata_visualizer.py | 77 ++++++++++++++------------- src/iblphotometry_tests/test_plots.py | 8 +++ 2 files changed, 49 insertions(+), 36 deletions(-) diff --git a/src/gui/rawdata_visualizer.py b/src/gui/rawdata_visualizer.py index ba764fd..36aa0d4 100644 --- a/src/gui/rawdata_visualizer.py +++ b/src/gui/rawdata_visualizer.py @@ -40,7 +40,7 @@ def init_ui(self): # Layout for file loading and selection file_layout = QHBoxLayout() self.load_button = QPushButton('Load File', self) - self.load_button.clicked.connect(self.load_file) + self.load_button.clicked.connect(self.open_dialog) file_layout.addWidget(self.load_button) self.column_selector = QComboBox(self) @@ -95,47 +95,50 @@ def init_ui(self): self.setWindowTitle('DataFrame Plotter') self.setGeometry(300, 100, 800, 600) - def load_file(self): + def load_file(self, file_path): + try: + if ( + file_path.endswith('.csv') + or file_path.endswith('.pqt') + or file_path.endswith('.parquet') + ): + self.dfs = from_raw_neurophotometrics_file(file_path) + else: + raise ValueError('Unsupported file format') + + if 'GCaMP' in self.dfs.keys(): + self.df = self.dfs['GCaMP'] + self.times = self.dfs['GCaMP'].index.values + self.plot_time_index = np.arange(0, len(self.times)) + self.filtered_df = None + else: + raise ValueError('No GCaMP found') + + if 'Isosbestic' in self.dfs.keys(): + self.dfiso = self.dfs['Isosbestic'] + + # Display the dataframe in the table + # self.display_dataframe() + # Update the column selector + self.update_column_selector() + + # Load into Pynapple dataframe + self.dfs = from_raw_neurophotometrics_file(file_path) + + # Set filter combo box + self.filter_selector.setCurrentIndex(0) # Reset to "Select Filter" + + except Exception as e: + print(f'Error loading file: {e}') + + def open_dialog(self): # Open a file dialog to choose the CSV or PQT file file_path, _ = QFileDialog.getOpenFileName( self, 'Open File', '', 'CSV and PQT Files (*.csv *.pqt);;All Files (*)' ) if file_path: # Load the file into a DataFrame based on its extension - try: - if ( - file_path.endswith('.csv') - or file_path.endswith('.pqt') - or file_path.endswith('.parquet') - ): - self.dfs = from_raw_neurophotometrics_file(file_path) - else: - raise ValueError('Unsupported file format') - - if 'GCaMP' in self.dfs.keys(): - self.df = self.dfs['GCaMP'] - self.times = self.dfs['GCaMP'].index.values - self.plot_time_index = np.arange(0, len(self.times)) - self.filtered_df = None - else: - raise ValueError('No GCaMP found') - - if 'Isosbestic' in self.dfs.keys(): - self.dfiso = self.dfs['Isosbestic'] - - # Display the dataframe in the table - # self.display_dataframe() - # Update the column selector - self.update_column_selector() - - # Load into Pynapple dataframe - self.dfs = from_raw_neurophotometrics_file(file_path) - - # Set filter combo box - self.filter_selector.setCurrentIndex(0) # Reset to "Select Filter" - - except Exception as e: - print(f'Error loading file: {e}') + self.load_file(file_path) # TODO this does not work with pynapple as format, convert back to pandas DF # def display_dataframe(self): @@ -286,5 +289,7 @@ def filter_mad(self, df): if __name__ == '__main__': app = QApplication(sys.argv) window = DataFrameVisualizerApp() + if len(sys.argv) >= 2: + window.load_file(sys.argv[1]) window.show() sys.exit(app.exec_()) diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index db7ca41..b3489e7 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -1,3 +1,4 @@ +import unittest import pandas as pd from pathlib import Path @@ -186,3 +187,10 @@ def test_plot_event_tick(self): df_nph, t_events, fs = self.get_synthetic_data() plots.plot_event_tick(t_events) plt.close('all') + + +if __name__ == '__main__': + suite = unittest.TestSuite() + suite.addTest(TestPlotters("test_class_plotsignalresponse")) + runner = unittest.TextTestRunner() + runner.run(suite) From 251af9ac1f48f2c9b938e6d61e97296f5882e9e6 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 10 Dec 2024 19:25:48 +0100 Subject: [PATCH 2/7] WIP: adding fiber behavior GUI --- src/gui/rawdata_visualizer.py | 103 +++++++++++++++++++++++++- src/iblphotometry/plots.py | 29 +++++--- src/iblphotometry_tests/test_plots.py | 29 +++++++- 3 files changed, 143 insertions(+), 18 deletions(-) diff --git a/src/gui/rawdata_visualizer.py b/src/gui/rawdata_visualizer.py index 36aa0d4..ae8bc20 100644 --- a/src/gui/rawdata_visualizer.py +++ b/src/gui/rawdata_visualizer.py @@ -18,6 +18,7 @@ import iblphotometry.preprocessing as ffpr import numpy as np +import pandas as pd class DataFrameVisualizerApp(QWidget): @@ -30,6 +31,8 @@ def __init__(self): self.plot_time_index = None + self.behavior_gui = None + self.filtered_df = None # Filtered DataFrame used for plotting only self.init_ui() @@ -56,6 +59,10 @@ def init_ui(self): self.filter_selector.currentIndexChanged.connect(self.apply_filter) file_layout.addWidget(self.filter_selector) + self.behavior_button = QPushButton('Open behavior GUI', self) + self.behavior_button.clicked.connect(self.open_behavior_gui) + file_layout.addWidget(self.behavior_button) + # # Table widget to display DataFrame # self.table = QTableWidget(self) # self.table.setSelectionMode(QTableWidget.SingleSelection) @@ -238,9 +245,9 @@ def on_column_header_clicked(self, logical_index): # Update the plots based on the selected column self.update_plots() - def apply_filter(self): + def apply_filter(self, filter_idx, filter_option=None): # Get the selected filter option from the filter dropdown - filter_option = self.filter_selector.currentText() + filter_option = filter_option or self.filter_selector.currentText() if filter_option == 'Select Filter': self.filtered_df = None @@ -285,6 +292,98 @@ def filter_mad(self, df): # filtered_df[col] = filtered_df[col].apply(lambda x: x if x <= 100 else None) # return filtered_df + def open_behavior_gui(self): + signal = self.plotobj.processed_signal + + if self.behavior_gui is None: + self.behavior_gui = BehaviorVisualizerGUI() + assert self.behavior_gui is not None + + if signal is None: + print("Apply a filter before opening the Behavior GUI") + else: + print("Opening Behavior GUI") + self.behavior_gui.set_data(signal, self.times) + self.behavior_gui.show() + + +class BehaviorVisualizerGUI(QWidget): + def __init__(self, ): + super().__init__() + self.trials = None + self.init_ui() + + def set_data(self, processed_signal, times): + assert processed_signal is not None + assert times is not None + self.processed_signal = processed_signal + self.times = times + + def init_ui(self): + # Create layout + main_layout = QVBoxLayout() + + # Layout for file loading and selection + file_layout = QHBoxLayout() + self.load_button = QPushButton('Load File', self) + self.load_button.clicked.connect(self.open_dialog) + file_layout.addWidget(self.load_button) + + main_layout.addLayout(file_layout) + + # Set up plots layout + self.plot_layout = QGridLayout() + self.plotobj = plots.PlotSignalResponse() + self.figure, self.axes = self.plotobj.set_fig_layout() + self.canvas = FigureCanvas(self.figure) + self.plot_layout.addWidget(self.canvas, 0, 0, 1, 3) + + # Create a NavigationToolbar + self.toolbar = NavigationToolbar(self.canvas, self) + + main_layout.addLayout(self.plot_layout) + self.setLayout(main_layout) + + self.setWindowTitle('Behavior Visualizer') + self.setGeometry(300, 100, 800, 600) + + def load_trials(self, trials): + assert trials is not None + self.trials = trials + self.update_plots() + + def load_file(self, file_path): + # load a trial file + try: + if ( + file_path.endswith('.pqt') + or file_path.endswith('.parquet') + ): + self.load_trials(pd.read_parquet(file_path)) + else: + raise ValueError('Unsupported file format') + except Exception as e: + print(f'Error loading file: {e}') + + def open_dialog(self): + file_path, _ = QFileDialog.getOpenFileName( + self, 'Open File', '', 'CSV and PQT Files (*.csv *.pqt);;All Files (*)' + ) + if file_path: + self.load_file(file_path) + + def update_plots(self): + self.figure.clear() + + self.plotobj.set_data( + self.trials, self.processed_signal, self.times, + ) + # NOTE: we need to update the layout as it depends on the data + self.figure, self.axes = self.plotobj.set_fig_layout(figure=self.figure) + self.plotobj.plot_trialsort_psth(self.axes) + + self.canvas.draw() + if __name__ == '__main__': app = QApplication(sys.argv) diff --git a/src/iblphotometry/plots.py b/src/iblphotometry/plots.py index 2f4bb67..5c5b27a 100644 --- a/src/iblphotometry/plots.py +++ b/src/iblphotometry/plots.py @@ -41,9 +41,6 @@ def set_axis_style(ax, fontsize=10, **kwargs): class PlotSignal: - # def __init__(self, *args, **kwargs): - # self.set_data(*args, **kwargs) - def set_data( self, raw_signal, times, raw_isosbestic=None, processed_signal=None, fs=None ): @@ -143,7 +140,10 @@ def raw_processed_figure2(self, axd): class PlotSignalResponse: - def __init__( + def __init__(self): + self.psth_dict = {} + + def set_data( self, trials, processed_signal, times, fs=None, event_window=np.array([-1, 2]) ): self.trials = trials @@ -187,10 +187,20 @@ def update_psth_dict(self, event): except KeyError: warnings.warn(f'Event {event} not found in trials table.') - def plot_trialsort_psth(self): - fig, axs = plt.subplots(2, len(self.psth_dict.keys()) - 1) + def set_fig_layout(self, figure=None): + n = max(1, len(self.psth_dict.keys()) - 1) + if figure is None: + figure, axs = plt.subplots(2, n, squeeze=False) + else: + axs = figure.subplots(2, n, squeeze=False) + figure.tight_layout() + return figure, axs + def plot_trialsort_psth(self, axs): signal_keys = [k for k in self.psth_dict.keys() if k != 'times'] + if axs.shape[1] < len(signal_keys): + raise ValueError("Error, skipping PSTH plotting") + for iaxs, event in enumerate(signal_keys): axs_plt = [axs[0, iaxs], axs[1, iaxs]] plot_psth(self.psth_dict[event], self.psth_dict['times'], axs=axs_plt) @@ -206,17 +216,12 @@ def plot_trialsort_psth(self): if iaxs > 0: axs[0, iaxs].axis('off') axs[1, iaxs].set_yticks([]) - fig.tight_layout() - return fig, axs - def plot_processed_trialtick(self, event_key='stimOn_times'): - fig, ax = plt.subplots(1, 1) - plt.figure(figsize=(10, 6)) + def plot_processed_trialtick(self, ax, event_key='stimOn_times'): events = self.trials[event_key] ax.set_ylim([-0.2, 0.1]) plot_event_tick(events, ax=ax, color='#FFC0CB', ls='-') plot_processed_signal(self.processed_signal, self.times, ax=ax) - return fig, ax """ diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index b3489e7..8b460ea 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -1,3 +1,4 @@ +import sys import unittest import pandas as pd from pathlib import Path @@ -7,6 +8,7 @@ from iblphotometry.behavior import psth, psth_times import iblphotometry.plots as plots +from gui.rawdata_visualizer import DataFrameVisualizerApp, BehaviorVisualizerGUI from iblphotometry.synthetic import synthetic101 import iblphotometry.preprocessing as ffpr from iblphotometry_tests.base_tests import PhotometryDataTestCase @@ -95,9 +97,13 @@ def test_class_plotsignalresponse(self): # eid = '77a6741c-81cc-475f-9454-a9b997be02a4' # trials = one.load_object(eid, 'trials') trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt']) - plotobj = plots.PlotSignalResponse(trials, processed_signal, times) - plotobj.plot_trialsort_psth() - plotobj.plot_processed_trialtick() + plotobj = plots.PlotSignalResponse() + plotobj.set_data(trials, processed_signal, times) + fig, axs = plotobj.set_fig_layout() + plotobj.plot_trialsort_psth(fig, axs) + fig, ax = plt.subplots(1, 1) + plotobj.plot_processed_trialtick(ax) + plt.show() plt.close('all') """ @@ -188,9 +194,24 @@ def test_plot_event_tick(self): plots.plot_event_tick(t_events) plt.close('all') + def test_gui(self): + df_nph, _, fs = self.get_test_data() + processed_signal = df_nph['signal_processed'].values + times = df_nph['times'].values + trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt']) + + from PyQt5.QtWidgets import QApplication + app = QApplication(sys.argv) + window = BehaviorVisualizerGUI() + window.set_data(processed_signal, times) + window.load_trials(trials) + window.show() + # Uncomment to debug + # app.exec_() + if __name__ == '__main__': suite = unittest.TestSuite() - suite.addTest(TestPlotters("test_class_plotsignalresponse")) + suite.addTest(TestPlotters("test_gui")) runner = unittest.TextTestRunner() runner.run(suite) From efef55325d0381336395001946bc542c57a9267d Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 10 Dec 2024 19:42:38 +0100 Subject: [PATCH 3/7] Ruff --- src/iblphotometry_tests/test_plots.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index 8b460ea..2989837 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -8,7 +8,7 @@ from iblphotometry.behavior import psth, psth_times import iblphotometry.plots as plots -from gui.rawdata_visualizer import DataFrameVisualizerApp, BehaviorVisualizerGUI +from gui.rawdata_visualizer import BehaviorVisualizerGUI from iblphotometry.synthetic import synthetic101 import iblphotometry.preprocessing as ffpr from iblphotometry_tests.base_tests import PhotometryDataTestCase @@ -201,7 +201,7 @@ def test_gui(self): trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt']) from PyQt5.QtWidgets import QApplication - app = QApplication(sys.argv) + app = QApplication(sys.argv) # noqa window = BehaviorVisualizerGUI() window.set_data(processed_signal, times) window.load_trials(trials) From 3ca5b9621c5a3b288601dcb93fc416351a56be85 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 10 Dec 2024 19:47:23 +0100 Subject: [PATCH 4/7] Comment out GUI test for CI --- src/iblphotometry_tests/test_plots.py | 42 +++++++++++++-------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index 2989837..03cd278 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -194,24 +194,24 @@ def test_plot_event_tick(self): plots.plot_event_tick(t_events) plt.close('all') - def test_gui(self): - df_nph, _, fs = self.get_test_data() - processed_signal = df_nph['signal_processed'].values - times = df_nph['times'].values - trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt']) - - from PyQt5.QtWidgets import QApplication - app = QApplication(sys.argv) # noqa - window = BehaviorVisualizerGUI() - window.set_data(processed_signal, times) - window.load_trials(trials) - window.show() - # Uncomment to debug - # app.exec_() - - -if __name__ == '__main__': - suite = unittest.TestSuite() - suite.addTest(TestPlotters("test_gui")) - runner = unittest.TextTestRunner() - runner.run(suite) + # def test_gui(self): + # df_nph, _, fs = self.get_test_data() + # processed_signal = df_nph['signal_processed'].values + # times = df_nph['times'].values + # trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt']) + + # from PyQt5.QtWidgets import QApplication + # app = QApplication(sys.argv) + # window = BehaviorVisualizerGUI() + # window.set_data(processed_signal, times) + # window.load_trials(trials) + # window.show() + # # Uncomment to debug + # app.exec_() + + +# if __name__ == '__main__': +# suite = unittest.TestSuite() +# suite.addTest(TestPlotters("test_gui")) +# runner = unittest.TextTestRunner() +# runner.run(suite) From d317eef8cd50adb2b34843cc3e046bcfc34133af Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 10 Dec 2024 20:28:16 +0100 Subject: [PATCH 5/7] Ruff --- src/iblphotometry_tests/test_plots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index 03cd278..4f165f3 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -1,5 +1,5 @@ -import sys -import unittest +# import sys +# import unittest import pandas as pd from pathlib import Path @@ -8,7 +8,7 @@ from iblphotometry.behavior import psth, psth_times import iblphotometry.plots as plots -from gui.rawdata_visualizer import BehaviorVisualizerGUI +# from gui.rawdata_visualizer import BehaviorVisualizerGUI from iblphotometry.synthetic import synthetic101 import iblphotometry.preprocessing as ffpr from iblphotometry_tests.base_tests import PhotometryDataTestCase From 5f28f8a543d65fcfc4c137fbd1a18eba03a85cac Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 10 Dec 2024 20:38:49 +0100 Subject: [PATCH 6/7] WIP: fixing tests --- src/iblphotometry_tests/test_plots.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index 4f165f3..6ec9744 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -1,5 +1,4 @@ -# import sys -# import unittest +import unittest import pandas as pd from pathlib import Path @@ -99,11 +98,11 @@ def test_class_plotsignalresponse(self): trials = pd.read_parquet(self.paths['trials_table_kcenia_pqt']) plotobj = plots.PlotSignalResponse() plotobj.set_data(trials, processed_signal, times) - fig, axs = plotobj.set_fig_layout() - plotobj.plot_trialsort_psth(fig, axs) - fig, ax = plt.subplots(1, 1) + _, axs = plotobj.set_fig_layout() + plotobj.plot_trialsort_psth(axs) + _, ax = plt.subplots(1, 1) plotobj.plot_processed_trialtick(ax) - plt.show() + # plt.show() plt.close('all') """ @@ -210,8 +209,9 @@ def test_plot_event_tick(self): # app.exec_() -# if __name__ == '__main__': -# suite = unittest.TestSuite() -# suite.addTest(TestPlotters("test_gui")) -# runner = unittest.TextTestRunner() -# runner.run(suite) +if __name__ == '__main__': + unittest.main() + # suite = unittest.TestSuite() + # suite.addTest(TestPlotters()) + # runner = unittest.TextTestRunner() + # runner.run(suite) From b3a8d18f95a54b1bd342c6d46cdf37994d0575e9 Mon Sep 17 00:00:00 2001 From: Cyrille Rossant Date: Tue, 10 Dec 2024 20:54:40 +0100 Subject: [PATCH 7/7] Ruff format --- src/gui/rawdata_visualizer.py | 17 +++++++++-------- src/iblphotometry/plots.py | 2 +- src/iblphotometry_tests/test_plots.py | 1 + 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/gui/rawdata_visualizer.py b/src/gui/rawdata_visualizer.py index ae8bc20..e84ef5a 100644 --- a/src/gui/rawdata_visualizer.py +++ b/src/gui/rawdata_visualizer.py @@ -300,15 +300,17 @@ def open_behavior_gui(self): assert self.behavior_gui is not None if signal is None: - print("Apply a filter before opening the Behavior GUI") + print('Apply a filter before opening the Behavior GUI') else: - print("Opening Behavior GUI") + print('Opening Behavior GUI') self.behavior_gui.set_data(signal, self.times) self.behavior_gui.show() class BehaviorVisualizerGUI(QWidget): - def __init__(self, ): + def __init__( + self, + ): super().__init__() self.trials = None self.init_ui() @@ -355,10 +357,7 @@ def load_trials(self, trials): def load_file(self, file_path): # load a trial file try: - if ( - file_path.endswith('.pqt') - or file_path.endswith('.parquet') - ): + if file_path.endswith('.pqt') or file_path.endswith('.parquet'): self.load_trials(pd.read_parquet(file_path)) else: raise ValueError('Unsupported file format') @@ -376,7 +375,9 @@ def update_plots(self): self.figure.clear() self.plotobj.set_data( - self.trials, self.processed_signal, self.times, + self.trials, + self.processed_signal, + self.times, ) # NOTE: we need to update the layout as it depends on the data self.figure, self.axes = self.plotobj.set_fig_layout(figure=self.figure) diff --git a/src/iblphotometry/plots.py b/src/iblphotometry/plots.py index 5c5b27a..8f2a1c7 100644 --- a/src/iblphotometry/plots.py +++ b/src/iblphotometry/plots.py @@ -199,7 +199,7 @@ def set_fig_layout(self, figure=None): def plot_trialsort_psth(self, axs): signal_keys = [k for k in self.psth_dict.keys() if k != 'times'] if axs.shape[1] < len(signal_keys): - raise ValueError("Error, skipping PSTH plotting") + raise ValueError('Error, skipping PSTH plotting') for iaxs, event in enumerate(signal_keys): axs_plt = [axs[0, iaxs], axs[1, iaxs]] diff --git a/src/iblphotometry_tests/test_plots.py b/src/iblphotometry_tests/test_plots.py index 6ec9744..b97e0ce 100644 --- a/src/iblphotometry_tests/test_plots.py +++ b/src/iblphotometry_tests/test_plots.py @@ -7,6 +7,7 @@ from iblphotometry.behavior import psth, psth_times import iblphotometry.plots as plots + # from gui.rawdata_visualizer import BehaviorVisualizerGUI from iblphotometry.synthetic import synthetic101 import iblphotometry.preprocessing as ffpr