diff --git a/pyproject.toml b/pyproject.toml index 0b8be2f6..280991b7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ tests = [ "wheel", ] densepose = [ + "torch<2.1.0", # densepose broken after 2.1.0; https://github.com/facebookresearch/detectron2/issues/5110 "detectron2-densepose @ git+https://github.com/facebookresearch/detectron2@main#subdirectory=projects/DensePose" ] diff --git a/tests/test_densepose.py b/tests/test_densepose.py index c6dbe223..c673fca5 100644 --- a/tests/test_densepose.py +++ b/tests/test_densepose.py @@ -60,7 +60,10 @@ def test_image(model, chimp_image_path, tmp_path): ) # output to disk - assert anatomy_info.shape == (2, 44) + assert anatomy_info.shape in [ + (2, 44), + (1, 44), + ] # depends on number of chimps identified; varies by version assert (anatomy_info > 0).any().any() assert (tmp_path / f"anatomized_{model}.csv").stat().st_size > 0 @@ -106,7 +109,12 @@ def test_video(model, chimp_video_path, tmp_path): ) # output to disk - assert anatomy_info.shape == (10, 46) + assert anatomy_info.shape[0] in [ + 8, + 9, + 10, + ] # depends on number of chimps identified; varies by version + assert anatomy_info.shape[1] == 46 assert (anatomy_info > 0).any().any() assert (tmp_path / f"anatomized_{model}.csv").stat().st_size > 0