Skip to content

Commit

Permalink
disable video decoding for first-rows
Browse files Browse the repository at this point in the history
  • Loading branch information
lhoestq committed Oct 28, 2024
1 parent 31bf035 commit d884ff8
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions services/worker/src/worker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from itertools import count, islice
from typing import Literal, Optional, TypeVar, Union, overload
from typing import Any, Literal, Optional, TypeVar, Union, overload
from unittest.mock import patch
from urllib.parse import quote

import PIL
Expand Down Expand Up @@ -38,6 +39,13 @@
# ^ see https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.MAX_IMAGE_PIXELS


def _no_op_decode_example(self: Any, value: dict[str, Any], token_per_repo_id: Any = None) -> dict[str, Any]: # noqa: ARG001
return value


disable_video_decoding = patch("datasets.Video.decode_example", _no_op_decode_example)


@retry(on=[ConnectionError])
def get_rows(
dataset: str,
Expand Down Expand Up @@ -65,7 +73,8 @@ def get_rows(
raise TypeError("load_dataset should return a Dataset in normal mode")
if column_names:
ds = ds.select_columns(column_names)
rows_plus_one = list(itertools.islice(ds, rows_max_number + 1))
with disable_video_decoding:
rows_plus_one = list(itertools.islice(ds, rows_max_number + 1))
# ^^ to be able to detect if a split has exactly ROWS_MAX_NUMBER rows
rows = rows_plus_one[:rows_max_number]
all_fetched = len(rows_plus_one) <= rows_max_number
Expand Down

0 comments on commit d884ff8

Please sign in to comment.