Skip to content

Commit

Permalink
Merge pull request #3 from int-brain-lab/gc_pick
Browse files Browse the repository at this point in the history
Pick tested
  • Loading branch information
GaelleChapuis authored May 30, 2024
2 parents 61be186 + f45cc8b commit 15b3bad
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 28 deletions.
124 changes: 96 additions & 28 deletions src/viewephys/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,19 +122,83 @@ def closeEvent(self, event):
self.close()


class PickSpikes():

def __init__(self):
default_df = self.init_df()
self.update_pick(default_df)

def init_df(self, nrow=0):
init_df = pd.DataFrame({
'sample': np.zeros(nrow, dtype=np.int32),
'trace': np.zeros(nrow, dtype=np.int32) * -1,
'amp': np.zeros(nrow, dtype=np.int32),
'group': np.zeros(nrow, dtype=np.int32),
})
return init_df

def update_pick(self, df):
self.picks = df
self.pick_index = df.shape[0] # Last index of spike picked (== len of df table)
self.pick_group = df['group'].max() # Last group created

def load_df(self, df):
'''
Load a dataframe that contains already picked spikes
:return:
'''
default_df = self.init_df()

if isinstance(df, pd.DataFrame):
# check all keys are in
indxmissing = np.where(~df.columns.isin(default_df.columns))[0]
if len(indxmissing) > 0:
raise ValueError(f'df does not contain column {default_df.columns[indxmissing]}')
self.update_pick(df)
else:
raise ValueError('df input is not pd.DataFrame')

def new_row_frompick(self, sample=None, trace=None, amp=None, group=None):
new_row = self.init_df(nrow=1)
new_row['sample'] = sample
new_row['trace'] = trace
new_row['amp'] = amp
new_row['group'] = group
return new_row

def add_spike(self, new_row):
df = self.picks
# Check columns of new row
indxmissing = np.where(~df.columns.isin(new_row.columns))[0]
if len(indxmissing) > 0:
raise ValueError(f'new_row does not contain column {df.columns[indxmissing]}')
# Append new row
df_updated = pd.concat([df, new_row])
df_updated = df_updated.reset_index(drop=True)
self.update_pick(df_updated)

def remove_spike(self, indx_remove):
df = self.picks
if df.shape[0] > 0 and len(indx_remove) > 0: # Update only if non-empty
df_updated = df.drop(indx_remove).copy()
df_updated = df_updated.reset_index(drop=True)
self.update_pick(df_updated)


def indx_select(self, sample, trace, s_range=0.5 * 30000, tr_range=3):
iclose = np.where(np.logical_and(
np.abs(self.picks['sample'] - sample) <= (s_range + 1),
np.abs(self.picks['trace'] - trace) <= (tr_range + 1)
))[0]
return iclose

class EphysViewer(EasyQC):
keyPressed = QtCore.pyqtSignal(int)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ctrl.model.picks = pd.DataFrame({
'sample': np.zeros(N_SAMPLES_INIT, dtype=np.int32),
'trace': np.zeros(N_SAMPLES_INIT, dtype=np.int32) * -1,
'amp': np.zeros(N_SAMPLES_INIT, dtype=np.int32),
'group': np.zeros(N_SAMPLES_INIT, dtype=np.int32),
})
self.ctrl.model.pick_index = 0
self.ctrl.model.pick_group = 0

self.ctrl.model.pickspikes = PickSpikes()
self.menufile.setEnabled(True)
self.settings = QtCore.QSettings('int-brain-lab', 'EphysViewer')
self.header_curves = {}
Expand Down Expand Up @@ -225,28 +289,29 @@ def mouseClickPickingEvent(self, event):
"""

if event.buttons() == QtCore.Qt.RightButton:
self.ctrl.model.pick_group += 1
self.ctrl.model.pickspikes.pick_group += 1 # TODO check logic of incrementing here
if event.buttons() != QtCore.Qt.LeftButton:
return
TR_RANGE = 3
S_RANGE = int(0.5 / self.ctrl.model.si)
qxy = self.imageItem_seismic.mapFromScene(event.scenePos())
s, tr = (qxy.x(), qxy.y())
# if event.buttons() == QtCore.Qt.MiddleButton:
match event.modifiers():
match event.modifiers(): # upon clicking:
# --- Remove a spike when shift key is pressed
case QtCore.Qt.KeyboardModifier.ShiftModifier:
iclose = np.where(np.logical_and(
np.abs(self.ctrl.model.picks['sample'] - s) <= (S_RANGE + 1),
np.abs(self.ctrl.model.picks['trace'] - tr) <= (TR_RANGE + 1)
))[0]
self.ctrl.model.picks.drop(iclose, inplace=True)
self.ctrl.model.pick_index -= iclose.size
return
i_remv = self.ctrl.model.pickspikes.indx_select(
sample=s, trace=tr, s_range=S_RANGE, tr_range=TR_RANGE)
self.ctrl.model.pickspikes.remove_spike(i_remv)
tmax = None

# --- Add a spike
case QtCore.Qt.ControlModifier:
# the control modifier prevents wrapping around the maximum number of picks
# the control modifier prevents wrapping around the nearby maximal voltage
tmax, xmax = (int(round(s)), int(round(tr)))
# this is the automatic wrapping around the maximum number of picks

case _:
# if no key is pressed and click, automatic wrapping around the nearby maximal voltage
xscale = np.arange(-TR_RANGE, TR_RANGE + 1) + np.round(tr).astype(np.int32)
tscale = np.arange(-S_RANGE, S_RANGE + 1) + np.round(s).astype(np.int32)
ix = slice(xscale[0], xscale[-1] + 1)
Expand All @@ -259,16 +324,19 @@ def mouseClickPickingEvent(self, event):
tmax, xmax = np.unravel_index(np.argmax(np.abs(self.ctrl.model.data[it, ix])),
(S_RANGE * 2 + 1, TR_RANGE * 2 + 1))
tmax, xmax = (tscale[tmax], xscale[xmax])
# we add the spike to the dataframe
i = self.ctrl.model.pick_index
self.ctrl.model.picks.at[i, 'sample'] = tmax
self.ctrl.model.picks.at[i, 'trace'] = xmax
self.ctrl.model.picks.at[i, 'amp'] = self.ctrl.model.data[tmax, xmax]
self.ctrl.model.picks.at[i, 'group'] = self.ctrl.model.pick_group
self.ctrl.model.pick_index += 1

if tmax is not None: # When spike is added
# we add the spike to the dataframe
amp = self.ctrl.model.data[tmax, xmax]
group = 0 # TODO group
# Create new row
new_row = self.ctrl.model.pickspikes.new_row_frompick(
sample=tmax, trace=xmax, amp=amp, group=group)
self.ctrl.model.pickspikes.add_spike(new_row=new_row)

# updates scatter plot
self.ctrl.add_scatter(self.ctrl.model.picks['sample'] * self.ctrl.model.si,
self.ctrl.model.picks['trace'],
self.ctrl.add_scatter(self.ctrl.model.pickspikes.picks['sample'] * self.ctrl.model.si,
self.ctrl.model.pickspikes.picks['trace'],
label='_picks', rgb=PICK_COLOR)

def save_current_plot(self, filename):
Expand Down
1 change: 1 addition & 0 deletions src/viewephys/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ def test_model_dataclass():

ProbeData(spikes=spikes, clusters=clusters, channels=channels)
ProbeData(spikes=pd.DataFrame(spikes), clusters=pd.DataFrame(clusters), channels=pd.DataFrame(channels))

86 changes: 86 additions & 0 deletions src/viewephys/tests/test_pick.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from viewephys.gui import PickSpikes
import numpy as np
import pandas as pd

ps = PickSpikes()
DEFAULT_DF_COLUMNS = ['sample', 'trace', 'amp', 'group']


def test_init_df():
df = ps.init_df(nrow=0)
# Check size
np.testing.assert_equal(df.shape[0], 0)
# Check column names
indxmissing = np.where(~df.columns.isin(DEFAULT_DF_COLUMNS))[0]
np.testing.assert_(len(indxmissing) == 0)


def test_new_row_frompick():
new_row = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4)
# Check size
np.testing.assert_(new_row.shape[0] == 1)
# Check column names
indxmissing = np.where(~new_row.columns.isin(DEFAULT_DF_COLUMNS))[0]
np.testing.assert_(len(indxmissing) == 0)
# Check values
np.testing.assert_(new_row['sample'][0] == 1)
np.testing.assert_(new_row['trace'][0] == 2)
np.testing.assert_(new_row['amp'][0] == 3)
np.testing.assert_(new_row['group'][0] == 4)


def test_update_pick():
# Create empty df
df = ps.init_df(nrow=0)
# Update
ps.update_pick(df)
# Test
pd.testing.assert_frame_equal(ps.picks, df)
np.testing.assert_(ps.pick_index == 0)
np.testing.assert_(np.isnan(ps.pick_group))

# ----
# Create filled df (2 rows)
df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4)
df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5)
df = pd.concat([df1, df2])
# Update
ps.update_pick(df)
# Test
pd.testing.assert_frame_equal(ps.picks, df)
np.testing.assert_(ps.pick_index == 2)
np.testing.assert_(ps.pick_group == 5)


def test_add_spike():
df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4)
df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5)
df = pd.concat([df1, df2])
df = df.reset_index(drop=True)
ps.update_pick(df1)
ps.add_spike(new_row=df2)
# Test
pd.testing.assert_frame_equal(ps.picks, df)
np.testing.assert_(ps.pick_index == 2)
np.testing.assert_(ps.pick_group == 5)
np.testing.assert_(ps.picks.index[0] == 0) # This is redundant with testing equal to df
np.testing.assert_(ps.picks.index[1] == 1) # But better be 100% sure


def test_remove_spike():
df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4)
df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5)
df3 = ps.new_row_frompick(sample=6, trace=6, amp=3, group=5)
df4 = ps.new_row_frompick(sample=7, trace=6, amp=3, group=5)
df = pd.concat([df1, df2, df3, df4])
df = df.reset_index(drop=True)
# Update
ps.update_pick(df)
# Remove 2nd and 4th spikes (== index 1, 3; Warning: index starts at 0)
indx_remove = np.array([1, 3])
# So we keep only the 1st and 3rd spike
df_test = pd.concat([df1, df3])
df_test = df_test.reset_index(drop=True)

ps.remove_spike(indx_remove=indx_remove)
pd.testing.assert_frame_equal(ps.picks, df_test)

0 comments on commit 15b3bad

Please sign in to comment.