Skip to content

Commit

Permalink
Export updates
Browse files Browse the repository at this point in the history
  • Loading branch information
CBroz1 committed Feb 19, 2025
1 parent 4812727 commit c9e05ee
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Ensure merge tables are declared during file insertion #1205
- Update URL for DANDI Docs #1210
- Improve cron job documentation and script #1226
- Update export process to include `~external` tables #XXXX

### Pipelines

Expand Down
11 changes: 7 additions & 4 deletions src/spyglass/common/common_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,6 @@ def _add_externals_to_restr_graph(
)
restr_graph.visited.add(analysis_name)

restr_graph.visited.update({raw_name, analysis_name})

return restr_graph

def get_restr_graph(
Expand All @@ -255,9 +253,14 @@ def get_restr_graph(
)

restr_graph = RestrGraph(
seed_table=self, leaves=leaves, verbose=verbose, cascade=cascade
seed_table=self, leaves=leaves, verbose=verbose, cascade=False
)
return self._add_externals_to_restr_graph(restr_graph, key)
restr_graph = self._add_externals_to_restr_graph(restr_graph, key)

if cascade:
restr_graph.cascade()

return restr_graph

def preview_tables(self, **kwargs) -> list[dj.FreeTable]:
"""Return a list of restricted FreeTables for a given restriction/key.
Expand Down
22 changes: 22 additions & 0 deletions src/spyglass/position/position_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,25 @@ def fetch1_dataframe(self) -> DataFrame:
& key
)
return query.fetch1_dataframe()

def fetch_pose_dataframe(self):
    """Fetch the pose dataframe from the resolved source table.

    Restricts the merge table to the current selection, resolves which
    source (part-parent) table the entries come from, and delegates to
    that table's ``fetch_pose_dataframe``.
    """
    all_keys = self.fetch("KEY", as_dict=True)
    restricted_keys = self.merge_restrict(all_keys).fetch(
        "KEY", as_dict=True
    )
    parent_name = self.merge_get_parent(self.proj()).table_name
    source_table = source_class_dict[to_camel_case(parent_name)]
    return (source_table & restricted_keys).fetch_pose_dataframe()

def fetch_video_path(self, key=None):
    """Fetch the video path from the resolved source table.

    Parameters
    ----------
    key : dict, optional
        Extra restriction applied on top of the current restriction
        before resolving the merge source. Defaults to no extra
        restriction.

    Returns
    -------
    The video path as returned by the source table's
    ``fetch_video_path``.
    """
    # NOTE: `key=None` sentinel replaces the mutable default `dict()`
    # (shared-across-calls pitfall); an empty dict restricts nothing,
    # so behavior is unchanged for callers omitting `key`.
    if key is None:
        key = {}
    key = self.merge_restrict((self & key).proj()).proj()
    query = (
        source_class_dict[
            to_camel_case(self.merge_get_parent(self.proj()).table_name)
        ]
        & key
    )
    return query.fetch_video_path()
2 changes: 1 addition & 1 deletion src/spyglass/position/v1/position_dlc_pose_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def _logged_make(self, key):

def fetch_dataframe(self, *attrs, **kwargs) -> pd.DataFrame:
"""Fetch a concatenated dataframe of all bodyparts."""
entries = (self.BodyPart & self).fetch("KEY")
entries = (self.BodyPart & self).fetch("KEY", log_export=False)
nwb_data_dict = {
entry["bodypart"]: (self.BodyPart() & entry).fetch_nwb()[0]
for entry in entries
Expand Down
36 changes: 36 additions & 0 deletions src/spyglass/utils/dj_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
from functools import cached_property
from hashlib import md5 as hash_md5
from itertools import chain as iter_chain
from pathlib import Path
from typing import Any, Dict, Iterable, List, Set, Tuple, Union

from datajoint import FreeTable, Table
from datajoint import config as dj_config
from datajoint.condition import make_condition
from datajoint.hash import key_hash
from datajoint.user_tables import TableMeta
Expand Down Expand Up @@ -782,6 +784,12 @@ def analysis_file_tbl(self) -> Table:

return AnalysisNwbfile()

@property
def file_externals(self):
    """External (``~external``) store tables of the NWB-file schema."""
    # Import locally to avoid a circular import at module load time.
    from spyglass.common.common_nwbfile import schema as nwbfile_schema

    return nwbfile_schema.external

def cascade_files(self):
"""Set node attribute for analysis files."""
analysis_pk = self.analysis_file_tbl.primary_key
Expand All @@ -791,6 +799,34 @@ def cascade_files(self):
files = list(ft.fetch(*analysis_pk))
self._set_node(ft, "files", files)

stores = dj_config["stores"]

analysis_paths = [
str(Path(p).relative_to(stores["analysis"]["location"]))
for p in self._get_ft(
self.analysis_file_tbl.full_table_name, with_restr=True
).fetch("analysis_file_abs_path")
]
self._set_restr(
self.file_externals["analysis"],
f"filepath in {tuple(analysis_paths)}",
)

raw_paths = [
str(Path(p).relative_to(stores["raw"]["location"]))
for p in self._get_ft(
"`common_nwbfile`.`nwbfile`", with_restr=True
).fetch("nwb_file_abs_path")
]
if len(raw_paths) == 1:
self._set_restr(
self.file_externals["raw"], f"filepath = '{raw_paths[0]}'"
)
elif len(raw_paths) > 1:
self._set_restr(
self.file_externals["raw"], f"filepath in {tuple(raw_paths)}"
)

@property
def file_dict(self) -> Dict[str, List[str]]:
"""Return dictionary of analysis files from all visited nodes.
Expand Down
4 changes: 2 additions & 2 deletions src/spyglass/utils/dj_merge_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ def _merge_repr(
datajoint.expression.Union
"""

parts = [
cls() * p # join with master to include sec key (i.e., 'source')
parts = [ # join with master to include sec key (i.e., 'source')
cls().join(p, log_export=False)
for p in cls._merge_restrict_parts(
restriction=restriction,
add_invalid_restrict=False,
Expand Down
24 changes: 19 additions & 5 deletions src/spyglass/utils/mixins/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from os import environ
from re import match as re_match

from datajoint.condition import AndList, make_condition
from datajoint.condition import AndList, Top, make_condition
from datajoint.expression import QueryExpression
from datajoint.table import Table
from packaging.version import parse as version_parse

Expand Down Expand Up @@ -262,6 +263,7 @@ def _run_join(self, **kwargs):
table_list.append(other) # can other._log_fetch
else:
logger.warning(f"Cannot export log join for\n{other}")
__import__("pdb").set_trace()

joined = self.proj().join(other.proj(), log_export=False)
for table in table_list: # log separate for unique pks
Expand Down Expand Up @@ -290,11 +292,20 @@ def _run_with_log(self, method, *args, log_export=True, **kwargs):
if getattr(method, "__name__", None) == "join": # special case
self._run_join(**kwargs)
else:
self._log_fetch(restriction=kwargs.get("restriction"))
restr = kwargs.get("restriction")
if isinstance(restr, QueryExpression) and getattr(
restr, "restriction"
):
restr = restr.restriction
self._log_fetch(restriction=restr)
logger.debug(f"Export: {self._called_funcs()}")

return ret

def is_restr(self, restr) -> bool:
    """Return True if ``restr`` actually narrows a query.

    A restriction is treated as non-restricting when it is falsy
    (``None``, ``''``, ``{}``, ``[]``), the literal ``True`` (match
    everything), or a ``Top`` instance (limit/order, not a filter).
    """
    # `is not True` (identity, per PEP 8 / E712) replaces `!= True`,
    # which also — incorrectly — matched the integer 1 via equality.
    return bool(restr) and restr is not True and not isinstance(restr, Top)

# -------------------------- Intercept DJ methods --------------------------

def fetch(self, *args, log_export=True, **kwargs):
Expand All @@ -317,11 +328,14 @@ def restrict(self, restriction):
"""Log restrict for export."""
if not self.export_id:
return super().restrict(restriction)

log_export = "fetch_nwb" not in self._called_funcs()
if self.is_restr(restriction) and self.is_restr(self.restriction):
combined = AndList([restriction, self.restriction])
else:
combined = restriction or self.restriction
return self._run_with_log(
super().restrict,
restriction=AndList([restriction, self.restriction]),
log_export=log_export,
super().restrict, restriction=combined, log_export=log_export
)

def join(self, other, log_export=True, *args, **kwargs):
Expand Down

0 comments on commit c9e05ee

Please sign in to comment.