Skip to content

Commit

Permalink
fix the failing deprecated tests for merging next PR
Browse files Browse the repository at this point in the history
  • Loading branch information
oliche committed May 29, 2024
1 parent eaa2e30 commit 61be186
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
3 changes: 2 additions & 1 deletion src/viewephys/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from one.alf.files import get_session_path
import spikeglx
from ibldsp import voltage, utils
from iblatlas import BrainRegions
from iblatlas.atlas import BrainRegions

from viewephys.gui import viewephys, SNS_PALETTE

Expand Down Expand Up @@ -65,6 +65,7 @@ class ProbeData:
channels: Union[dict, pd.DataFrame] = field(default_factory=dict)
trials: Union[dict, pd.DataFrame] = field(default_factory=dict)
sr: Union[spikeglx.Reader, Streamer, str, Path] = None

def __post_init__(self):
if isinstance(self.sr, str) or isinstance(self.sr, Path):
self.sr = spikeglx.Reader(self.ap_file)
Expand Down
13 changes: 6 additions & 7 deletions src/viewephys/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import pandas as pd
from brainbox.tests.test_metrics import multiple_spike_trains
from neuropixel import trace_header
from iblutil.util import Bunch


def test_model_dataclass():
# TODO: make a meta-data file fixture to create the bin file for the test
# this test won't run on a machine without the bin file below
st, sa, sc = multiple_spike_trains()
spikes = dict(times=st, clusters=sc, amps=sa)
clusters = dict(channels=np.random.randint(0, 384, np.max(sc)))
channels = trace_header(version=1)
spikes = Bunch(dict(times=st, clusters=sc, amps=sa, depths=sa * 0 + 100))
clusters = Bunch(dict(channels=np.random.randint(0, 384, np.max(sc))))
channels = Bunch(trace_header(version=1))

ProbeData(spikes=spikes, clusters=clusters, channels=channels, ap_file='toto.bin')
ProbeData(spikes=pd.DataFrame(spikes), clusters=pd.DataFrame(clusters), channels=pd.DataFrame(channels), ap_file='toto.bin')
ProbeData(spikes=spikes, clusters=clusters, channels=channels)
ProbeData(spikes=pd.DataFrame(spikes), clusters=pd.DataFrame(clusters), channels=pd.DataFrame(channels))

0 comments on commit 61be186

Please sign in to comment.