Skip to content

Commit

Permalink
Merge pull request datajoint#2 from ttngu207/iub-lulab
Browse files Browse the repository at this point in the history
Add computed table `LabeledVideo`
  • Loading branch information
kushalbakshi authored Aug 12, 2024
2 parents db3cabe + 8c5a5e9 commit 353315d
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 2 deletions.
108 changes: 107 additions & 1 deletion element_deeplabcut/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@


schema = dj.schema()
logger = dj.logger

_linking_module = None
logger = dj.logger

Expand Down Expand Up @@ -766,10 +768,16 @@ def make(self, key):
find_full_path(get_dlc_root_data_dir(), fp).as_posix()
for fp in video_relpaths
]
analyze_video_params = (PoseEstimationTask & key).fetch1(
pose_estimation_params = (PoseEstimationTask & key).fetch1(
"pose_estimation_params"
) or {}

# expect a nested dictionary with "analyze_videos" params
# if not, assume "pose_estimation_params" as a flat dictionary that include relevant "analyze_videos" params
analyze_video_params = (
pose_estimation_params.get("analyze_videos") or pose_estimation_params
)

@memoized_result(
uniqueness_dict={
**analyze_video_params,
Expand Down Expand Up @@ -911,6 +919,104 @@ def get_trajectory(cls, key: dict, body_parts: list = "all") -> pd.DataFrame:
return df


@schema
class LabeledVideo(dj.Computed):
definition = """
-> PoseEstimation
"""

class File(dj.Part):
definition = """
-> master
-> VideoRecording.File
---
labeled_video_path: varchar(255) # relative path to labeled video
labeled_video_file=null: filepath@dlc-processed
"""

@property
def key_source(self):
return PoseEstimation & RecordingInfo

def make(self, key):
import deeplabcut

pose_estimation_params = (PoseEstimationTask & key).fetch1(
"pose_estimation_params"
) or {}

# expect a nested dictionary with "create_labeled_video" and "extract_outlier_frames" params
# if not, assume "pose_estimation_params" as a flat dictionary
create_labeled_video_params = (
pose_estimation_params.get("create_labeled_video") or pose_estimation_params
)

outputframerate = create_labeled_video_params.pop(
"outputframerate", 5
) # final labeled video FPS defaults to 5 Hz

dlc_model_ = (Model & key).fetch1()
fps, nframes = (RecordingInfo & key).fetch1("fps", "nframes")
output_dir = (PoseEstimationTask & key).fetch1("pose_estimation_output_dir")
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)

project_path = find_full_path(
get_dlc_root_data_dir(), dlc_model_["project_path"]
)

try:
dlc_config = next(output_dir.glob("dj_dlc_config*.yaml"))
dlc_config = project_path / dlc_config.name
assert dlc_config.exists()
except (StopIteration, AssertionError):
dlc_config = next(project_path.glob("dj_dlc_config*.yaml"))
logger.warning(
f"No dj_dlc_config*.yaml file found in {output_dir} - this is unexpected.\nUsing {dlc_config}"
)

entries = []
for vkey in (VideoRecording.File & key).fetch("KEY"):
video_file = (VideoRecording.File & vkey).fetch1("file_path")
video_file = find_full_path(get_dlc_root_data_dir(), video_file)

# -- create labeled video --
create_labeled_video_kwargs = {
k: v
for k, v in create_labeled_video_params.items()
if k in inspect.signature(deeplabcut.create_labeled_video).parameters
}
create_labeled_video_kwargs.update(
dict(
config=dlc_config.as_posix(),
videos=[video_file.as_posix()],
shuffle=dlc_model_["shuffle"],
trainingsetindex=dlc_model_["trainingsetindex"],
modelprefix=dlc_model_["model_prefix"],
destfolder=output_dir.as_posix(),
Frames2plot=np.arange(0, nframes, int(fps / outputframerate)),
outputframerate=outputframerate,
)
)
deeplabcut.create_labeled_video(**create_labeled_video_kwargs)

labeled_video_path = next(
output_dir.glob(f"{video_file.stem}*_labeled.mp4")
)
entries.append(
{
**key,
**vkey,
"labeled_video_path": labeled_video_path.relative_to(
get_dlc_processed_data_dir()
).as_posix(),
"labeled_video_file": labeled_video_path.as_posix(),
}
)

self.insert1(key)
self.File.insert(entries)


def str_to_bool(value) -> bool:
"""Return whether the provided string represents true. Otherwise false.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@
"tests": ["pytest", "pytest-cov", "shutils"],
"dlc-pytorch": [
"deeplabcut @ git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc"
]
],
},
)

0 comments on commit 353315d

Please sign in to comment.