Skip to content

Commit

Permalink
style cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Feb 1, 2024
1 parent f3ea979 commit 0173963
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
50 changes: 27 additions & 23 deletions src/spyglass/spikesorting/analysis/v1/group.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datajoint as dj
import numpy as np
from itertools import compress
from ripple_detection import get_multiunit_population_firing_rate

from spyglass.common import Session # noqa: F401
Expand All @@ -18,24 +19,27 @@ class SortedSpikesGroupUnitSelectionParams(SpyglassMixin, dj.Manual):
include_labels = Null: longblob
exclude_labels = Null: longblob
"""

@property
def default_params(self):
    """Return the default unit-selection parameter entry.

    Selects every unit: no include/exclude label filtering is applied.
    """
    return dict(
        unit_filter_params_name="all_units",
        include_labels=[],
        exclude_labels=[],
    )
# Default rows for this table; each entry is
# [unit_filter_params_name, include_labels, exclude_labels].
contents = [
    ["all_units", [], []],
    ["exclude_noise", [], ["noise", "mua"]],
    ["default_exclusion", [], ["noise", "mua"]],
]

@classmethod
def insert_default(cls, **kwargs):
    """
    Insert the default unit-selection parameter set into this table.

    Inserts the entry produced by ``cls().default_params`` (the
    "all_units" parameter set), skipping it if already present.
    ``**kwargs`` is accepted but unused here.
    """
    # skip_duplicates makes repeated calls idempotent.
    cls.insert1(
        {**cls().default_params},
        skip_duplicates=True,
    )
def insert_default(cls):
cls.insert(cls.contents, skip_duplicates=True)


@schema
Expand Down Expand Up @@ -64,15 +68,13 @@ def create_group(
"nwb_file_name": nwb_file_name,
"unit_filter_params_name": unit_filter_params_name,
}
parts_insert = [{**key, **group_key} for key in keys]

self.insert1(
group_key,
skip_duplicates=True,
)
for key in keys:
self.SortGroup.insert1(
{**key, **group_key},
skip_duplicates=True,
)
self.SortGroup.insert(parts_insert)

@staticmethod
def filter_units(
Expand Down Expand Up @@ -133,8 +135,10 @@ def fetch_spike_data(key, time_slice=None):
include_unit = SortedSpikesGroup.filter_units(
group_label_list, include_labels, exclude_labels
)
from itertools import compress # worth bumping to top of script
sorting_spike_times = list(compress(sorting_spike_times, include_unit))

sorting_spike_times = list(
compress(sorting_spike_times, include_unit)
)

# filter the spike times based on the time slice if provided
if time_slice is not None:
Expand Down
10 changes: 7 additions & 3 deletions src/spyglass/spikesorting/imported.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def add_annotation(
unit id
label : List[str], optional
list of str labels for the unit, by default None
annotations : _type_, optional
annotations : dict, optional
dictionary of other annotation values for unit, by default None
merge_annotations : bool, optional
whether to merge with existing annotations, by default False
Expand All @@ -90,11 +90,15 @@ def add_annotation(
label = [label]
query = self & key
if not len(query) == 1:
raise ValueError(f"ImportedSpikeSorting key must be unique. Found: {query}")
raise ValueError(
f"ImportedSpikeSorting key must be unique. Found: {query}"
)
unit_key = {**key, "id": id}
annotation_query = ImportedSpikeSorting.Annotations & unit_key
if annotation_query and not merge_annotations:
raise ValueError(f"Unit already has annotations: {annotation_query}")
raise ValueError(
f"Unit already has annotations: {annotation_query}"
)
elif annotation_query:
existing_annotations = annotation_query.fetch1()
existing_annotations["label"] += label
Expand Down

0 comments on commit 0173963

Please sign in to comment.