diff --git a/CHANGELOG.md b/CHANGELOG.md
index adf11fa98..ae91a3131 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@
 - Ensure merge tables are declared during file insertion #1205
 - Update URL for DANDI Docs #1210
 - Improve cron job documentation and script #1226
+- Update export process to include `~external` tables #XXXX
 
 ### Pipelines
 
diff --git a/src/spyglass/common/common_usage.py b/src/spyglass/common/common_usage.py
index 58b28c0f6..89285abb4 100644
--- a/src/spyglass/common/common_usage.py
+++ b/src/spyglass/common/common_usage.py
@@ -228,8 +228,6 @@ def _add_externals_to_restr_graph(
         )
         restr_graph.visited.add(analysis_name)
 
-        restr_graph.visited.update({raw_name, analysis_name})
-
         return restr_graph
 
     def get_restr_graph(
@@ -255,9 +253,14 @@ def get_restr_graph(
         )
 
         restr_graph = RestrGraph(
-            seed_table=self, leaves=leaves, verbose=verbose, cascade=cascade
+            seed_table=self, leaves=leaves, verbose=verbose, cascade=False
         )
-        return self._add_externals_to_restr_graph(restr_graph, key)
+        restr_graph = self._add_externals_to_restr_graph(restr_graph, key)
+
+        if cascade:
+            restr_graph.cascade()
+
+        return restr_graph
 
     def preview_tables(self, **kwargs) -> list[dj.FreeTable]:
         """Return a list of restricted FreeTables for a given restriction/key.
diff --git a/src/spyglass/position/position_merge.py b/src/spyglass/position/position_merge.py
index 76d9b40f8..a118060ca 100644
--- a/src/spyglass/position/position_merge.py
+++ b/src/spyglass/position/position_merge.py
@@ -75,3 +75,25 @@ def fetch1_dataframe(self) -> DataFrame:
             & key
         )
         return query.fetch1_dataframe()
+
+    def fetch_pose_dataframe(self):
+        key = self.merge_restrict(self.fetch("KEY", as_dict=True)).fetch(
+            "KEY", as_dict=True
+        )
+        query = (
+            source_class_dict[
+                to_camel_case(self.merge_get_parent(self.proj()).table_name)
+            ]
+            & key
+        )
+        return query.fetch_pose_dataframe()
+
+    def fetch_video_path(self, key=dict()):
+        key = self.merge_restrict((self & key).proj()).proj()
+        query = (
+            source_class_dict[
+                to_camel_case(self.merge_get_parent(self.proj()).table_name)
+            ]
+            & key
+        )
+        return query.fetch_video_path()
diff --git a/src/spyglass/position/v1/position_dlc_pose_estimation.py b/src/spyglass/position/v1/position_dlc_pose_estimation.py
index 04f7724f4..15124de9b 100644
--- a/src/spyglass/position/v1/position_dlc_pose_estimation.py
+++ b/src/spyglass/position/v1/position_dlc_pose_estimation.py
@@ -354,7 +354,7 @@ def _logged_make(self, key):
 
     def fetch_dataframe(self, *attrs, **kwargs) -> pd.DataFrame:
         """Fetch a concatenated dataframe of all bodyparts."""
-        entries = (self.BodyPart & self).fetch("KEY")
+        entries = (self.BodyPart & self).fetch("KEY", log_export=False)
         nwb_data_dict = {
             entry["bodypart"]: (self.BodyPart() & entry).fetch_nwb()[0]
             for entry in entries
diff --git a/src/spyglass/utils/dj_graph.py b/src/spyglass/utils/dj_graph.py
index 1e6137a02..f387abfad 100644
--- a/src/spyglass/utils/dj_graph.py
+++ b/src/spyglass/utils/dj_graph.py
@@ -9,9 +9,11 @@
 from functools import cached_property
 from hashlib import md5 as hash_md5
 from itertools import chain as iter_chain
+from pathlib import Path
 from typing import Any, Dict, Iterable, List, Set, Tuple, Union
 
 from datajoint import FreeTable, Table
+from datajoint import config as dj_config
 from datajoint.condition import make_condition
 from datajoint.hash import key_hash
 from datajoint.user_tables import TableMeta
@@ -782,6 +784,12 @@ def analysis_file_tbl(self) -> Table:
 
         return AnalysisNwbfile()
 
+    @property
+    def file_externals(self):
+        from spyglass.common.common_nwbfile import schema
+
+        return schema.external
+
     def cascade_files(self):
         """Set node attribute for analysis files."""
         analysis_pk = self.analysis_file_tbl.primary_key
@@ -791,6 +799,34 @@ def cascade_files(self):
             files = list(ft.fetch(*analysis_pk))
             self._set_node(ft, "files", files)
 
+        stores = dj_config["stores"]
+
+        analysis_paths = [
+            str(Path(p).relative_to(stores["analysis"]["location"]))
+            for p in self._get_ft(
+                self.analysis_file_tbl.full_table_name, with_restr=True
+            ).fetch("analysis_file_abs_path")
+        ]
+        self._set_restr(
+            self.file_externals["analysis"],
+            f"filepath in {tuple(analysis_paths)}",
+        )
+
+        raw_paths = [
+            str(Path(p).relative_to(stores["raw"]["location"]))
+            for p in self._get_ft(
+                "`common_nwbfile`.`nwbfile`", with_restr=True
+            ).fetch("nwb_file_abs_path")
+        ]
+        if len(raw_paths) == 1:
+            self._set_restr(
+                self.file_externals["raw"], f"filepath = '{raw_paths[0]}'"
+            )
+        elif len(raw_paths) > 1:
+            self._set_restr(
+                self.file_externals["raw"], f"filepath in {tuple(raw_paths)}"
+            )
+
     @property
     def file_dict(self) -> Dict[str, List[str]]:
         """Return dictionary of analysis files from all visited nodes.
diff --git a/src/spyglass/utils/dj_merge_tables.py b/src/spyglass/utils/dj_merge_tables.py
index e5e30c848..9b732bbd5 100644
--- a/src/spyglass/utils/dj_merge_tables.py
+++ b/src/spyglass/utils/dj_merge_tables.py
@@ -261,8 +261,8 @@ def _merge_repr(
             datajoint.expression.Union
         """
 
-        parts = [
-            cls() * p  # join with master to include sec key (i.e., 'source')
+        parts = [  # join with master to include sec key (i.e., 'source')
+            cls().join(p, log_export=False)
             for p in cls._merge_restrict_parts(
                 restriction=restriction,
                 add_invalid_restrict=False,
diff --git a/src/spyglass/utils/mixins/export.py b/src/spyglass/utils/mixins/export.py
index 222963ebb..648842b8d 100644
--- a/src/spyglass/utils/mixins/export.py
+++ b/src/spyglass/utils/mixins/export.py
@@ -7,7 +7,8 @@
 from os import environ
 from re import match as re_match
 
-from datajoint.condition import AndList, make_condition
+from datajoint.condition import AndList, Top, make_condition
+from datajoint.expression import QueryExpression
 from datajoint.table import Table
 from packaging.version import parse as version_parse
 
@@ -290,11 +291,20 @@ def _run_with_log(self, method, *args, log_export=True, **kwargs):
         if getattr(method, "__name__", None) == "join":  # special case
             self._run_join(**kwargs)
         else:
-            self._log_fetch(restriction=kwargs.get("restriction"))
+            restr = kwargs.get("restriction")
+            if isinstance(restr, QueryExpression) and getattr(
+                restr, "restriction"
+            ):
+                restr = restr.restriction
+            self._log_fetch(restriction=restr)
 
         logger.debug(f"Export: {self._called_funcs()}")
         return ret
 
+    def is_restr(self, restr) -> bool:
+        """Check if a restriction is actually restricting."""
+        return bool(restr) and restr != True and not isinstance(restr, Top)
+
     # -------------------------- Intercept DJ methods --------------------------
 
     def fetch(self, *args, log_export=True, **kwargs):
@@ -317,11 +327,14 @@ def restrict(self, restriction):
         """Log restrict for export."""
         if not self.export_id:
             return super().restrict(restriction)
+
         log_export = "fetch_nwb" not in self._called_funcs()
+        if self.is_restr(restriction) and self.is_restr(self.restriction):
+            combined = AndList([restriction, self.restriction])
+        else:
+            combined = restriction or self.restriction
         return self._run_with_log(
-            super().restrict,
-            restriction=AndList([restriction, self.restriction]),
-            log_export=log_export,
+            super().restrict, restriction=combined, log_export=log_export
         )
 
     def join(self, other, log_export=True, *args, **kwargs):
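
A minimal sketch of the path-restriction logic added to RestrGraph.cascade_files above: absolute file paths are relativized against the configured DataJoint store location and rendered as a `filepath` restriction on the schema's `~external` table; a single path is special-cased (as in the raw-file branch) because a one-element Python tuple renders with a trailing comma that MySQL rejects inside an IN clause. The helper name filepath_restriction and the example store name are illustrative assumptions, not spyglass API; dj.config["stores"] is the store configuration the diff reads.

from pathlib import Path

import datajoint as dj


def filepath_restriction(abs_paths, store_name):
    """Build a filepath restriction for an `~external` table.

    Hypothetical helper mirroring the logic added in cascade_files:
    absolute paths are made relative to the configured store location,
    then rendered as a SQL-style restriction string.
    """
    location = dj.config["stores"][store_name]["location"]
    rel_paths = [str(Path(p).relative_to(location)) for p in abs_paths]
    if not rel_paths:
        return None  # nothing to restrict
    if len(rel_paths) == 1:
        # a one-element tuple would render as ('x',); use equality instead
        return f"filepath = '{rel_paths[0]}'"
    return f"filepath in {tuple(rel_paths)}"


# Example, assuming dj.config["stores"]["raw"]["location"] == "/data/raw":
# filepath_restriction(["/data/raw/sub1.nwb"], "raw") -> "filepath = 'sub1.nwb'"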