diff --git a/services/worker/src/worker/utils.py b/services/worker/src/worker/utils.py index 1dbc91efe..3ca0a5b36 100644 --- a/services/worker/src/worker/utils.py +++ b/services/worker/src/worker/utils.py @@ -8,7 +8,8 @@ from collections.abc import Iterable from dataclasses import dataclass, field from itertools import count, islice -from typing import Literal, Optional, TypeVar, Union, overload +from typing import Any, Literal, Optional, TypeVar, Union, overload +from unittest.mock import patch from urllib.parse import quote import PIL @@ -38,6 +39,13 @@ # ^ see https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.MAX_IMAGE_PIXELS +def _no_op_decode_example(self: Any, value: dict[str, Any], token_per_repo_id: Any = None) -> dict[str, Any]: # noqa: ARG001 + return value + + +disable_video_decoding = patch("datasets.Video.decode_example", _no_op_decode_example) + + @retry(on=[ConnectionError]) def get_rows( dataset: str, @@ -65,7 +73,8 @@ def get_rows( raise TypeError("load_dataset should return a Dataset in normal mode") if column_names: ds = ds.select_columns(column_names) - rows_plus_one = list(itertools.islice(ds, rows_max_number + 1)) + with disable_video_decoding: + rows_plus_one = list(itertools.islice(ds, rows_max_number + 1)) # ^^ to be able to detect if a split has exactly ROWS_MAX_NUMBER rows rows = rows_plus_one[:rows_max_number] all_fetched = len(rows_plus_one) <= rows_max_number