Skip to content

Commit

Permalink
Add the option for multiple views loading - take2 (#346)
Browse files Browse the repository at this point in the history
* moving to new branch

* final test fix

* Update movement/io/load_poses.py

Co-authored-by: Niko Sirmpilatze <[email protected]>

* doc change

* fixed doc

---------

Co-authored-by: Niko Sirmpilatze <[email protected]>
  • Loading branch information
vigji and niksirbi authored Nov 22, 2024
1 parent a3956c4 commit 174817d
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 1 deletion.
15 changes: 14 additions & 1 deletion docs/source/getting_started/movement_dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ To discuss the specifics of both types of `movement` datasets, it is useful to c
To learn more about `xarray` data structures in general, see the relevant
[documentation](xarray:user-guide/data-structures.html).


## Dataset structure

```{figure} ../_static/dataset_structure.png
Expand Down Expand Up @@ -135,6 +134,20 @@ In both cases, appropriate **coordinates** are assigned to each **dimension**.
- `space` is labelled with either `x`, `y` (2D) or `x`, `y`, `z` (3D). Note that bounding boxes datasets are restricted to 2D space.
- `time` is labelled in seconds if `fps` is provided, otherwise the **coordinates** are expressed in frames (ascending 0-indexed integers).

:::{dropdown} Additional dimensions
:color: info
:icon: info
The above **dimensions** and **coordinates** are created
by default when loading a `movement` dataset from a single
file containing pose or bounding boxes tracks.

In some cases, you may encounter or create datasets with extra
**dimensions**. For example, the
{func}`movement.io.load_poses.from_multiview_files()` function
creates an additional `views` **dimension**,
with the **coordinates** being the names given to each camera view.
:::

### Data variables
The data variables in a `movement` dataset are the arrays that hold the actual data, as {class}`xarray.DataArray` objects.

Expand Down
35 changes: 35 additions & 0 deletions movement/io/load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,41 @@ def from_dlc_file(
)


def from_multiview_files(
file_path_dict: dict[str, Path | str],
source_software: Literal["DeepLabCut", "SLEAP", "LightningPose"],
fps: float | None = None,
) -> xr.Dataset:
"""Load and merge pose tracking data from multiple views (cameras).
Parameters
----------
file_path_dict : dict[str, Union[Path, str]]
A dict whose keys are the view names and values are the paths to load.
source_software : {'LightningPose', 'SLEAP', 'DeepLabCut'}
The source software of the file.
fps : float, optional
The number of frames per second in the video. If None (default),
the `time` coordinates will be in frame numbers.
Returns
-------
xarray.Dataset
``movement`` dataset containing the pose tracks, confidence scores,
and associated metadata, with an additional ``views`` dimension.
"""
views_list = list(file_path_dict.keys())
new_coord_views = xr.DataArray(views_list, dims="view")

dataset_list = [
from_file(f, source_software=source_software, fps=fps)
for f in file_path_dict.values()
]

return xr.concat(dataset_list, dim=new_coord_views)


def _ds_from_lp_or_dlc_file(
file_path: Path | str,
source_software: Literal["LightningPose", "DeepLabCut"],
Expand Down
17 changes: 17 additions & 0 deletions tests/test_unit/test_load_poses.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,20 @@ def test_from_numpy_valid(
source_software=source_software,
)
self.assert_dataset(ds, expected_source_software=source_software)

def test_from_multiview_files(self):
"""Test that the from_file() function delegates to the correct
loader function according to the source_software.
"""
view_names = ["view_0", "view_1"]
file_path_dict = {
view: DATA_PATHS.get("DLC_single-wasp.predictions.h5")
for view in view_names
}
multi_view_ds = load_poses.from_multiview_files(
file_path_dict, source_software="DeepLabCut"
)

assert isinstance(multi_view_ds, xr.Dataset)
assert "view" in multi_view_ds.dims
assert multi_view_ds.view.values.tolist() == view_names

0 comments on commit 174817d

Please sign in to comment.