diff --git a/docs/user-guide/bestpractices.rst b/docs/user-guide/bestpractices.rst index 89c87ecb..6162c470 100644 --- a/docs/user-guide/bestpractices.rst +++ b/docs/user-guide/bestpractices.rst @@ -24,6 +24,16 @@ Handling GPU Out-of-Memory (OOM) Errors NeMo Curator is designed to be scalable with large amounts of text data, but OOM errors occur when the available GPU memory is insufficient for a given task. To help avoid these issues and ensure efficient processing, here are some strategies for managing memory usage and mitigating OOM challenges. +Controlling Partition Sizes +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Consider using ``files_per_partition`` or ``blocksize`` when reading data. Processing large datasets in smaller chunks reduces the memory load. + +#. The ``blocksize`` argument is available for ``jsonl`` and ``parquet`` files. However, for ``parquet`` files, it is currently only available when ``add_filename=False``. + +#. For the ``blocksize`` parameter, we recommend using about 1/32 of the total GPU memory. For example, if you have a GPU with 32GB of memory, you can set ``blocksize="1GB"``. + + Utilize RMM Options ~~~~~~~~~~~~~~~~~~~ `RAPIDS Memory Manager (RMM) `_ is a package that enables you to allocate device memory in a highly configurable way. @@ -59,6 +69,7 @@ Alternatively, you can set these flags while initializing your own Dask client, client = Client(cluster) + Fuzzy Deduplication Guidelines ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Fuzzy deduplication is one of the most computationally expensive algorithms within the NeMo Curator pipeline. diff --git a/nemo_curator/_compat.py b/nemo_curator/_compat.py index 26fb0574..5de25ebd 100644 --- a/nemo_curator/_compat.py +++ b/nemo_curator/_compat.py @@ -23,6 +23,15 @@ # When mocking with autodoc the dask version is not there _dask_version = parse_version("2024.06.0") + +try: + import dask_cudf + + _dask_cudf_version = parse_version(dask_cudf.__version__) +except (ImportError, TypeError): + # When mocking with autodoc the dask_cudf version is not there + _dask_cudf_version = parse_version("24.06.00") + try: import cudf @@ -40,6 +49,7 @@ DASK_SHUFFLE_METHOD_ARG = _dask_version > parse_version("2024.1.0") DASK_P2P_ERROR = _dask_version < parse_version("2023.10.0") DASK_SHUFFLE_CAST_DTYPE = _dask_version > parse_version("2023.12.0") +DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA = _dask_cudf_version > parse_version("24.12") # Query-planning check (and cache) _DASK_QUERY_PLANNING_ENABLED = None diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py index 3bebbd7d..10ffe223 100644 --- a/nemo_curator/datasets/doc_dataset.py +++ b/nemo_curator/datasets/doc_dataset.py @@ -50,7 +50,8 @@ def read_json( cls, input_files: Union[str, List[str]], backend: Literal["pandas", "cudf"] = "pandas", - files_per_partition: int = 1, + files_per_partition: Optional[int] = None, + blocksize: Optional[str] = "1gb", add_filename: bool = False, input_meta: Union[str, dict] = None, columns: Optional[List[str]] = None, @@ -74,8 +75,9 @@ def read_json( input_files=input_files, file_type="jsonl", backend=backend, - files_per_partition=files_per_partition, add_filename=add_filename, + files_per_partition=files_per_partition, + blocksize=blocksize, input_meta=input_meta, columns=columns, **kwargs, @@ -87,8 +89,9 @@ def read_parquet( cls, input_files: Union[str, List[str]], backend: Literal["pandas", "cudf"] = "pandas", - files_per_partition: int = 1, - add_filename: bool = False, + files_per_partition: Optional[int] = 
None, + blocksize: Optional[str] = "1gb", + add_filename=False, columns: Optional[List[str]] = None, **kwargs, ) -> "DocumentDataset": @@ -109,8 +112,9 @@ def read_parquet( input_files=input_files, file_type="parquet", backend=backend, - files_per_partition=files_per_partition, add_filename=add_filename, + files_per_partition=files_per_partition, + blocksize=blocksize, columns=columns, **kwargs, ) @@ -121,8 +125,6 @@ def read_pickle( cls, input_files: Union[str, List[str]], backend: Literal["pandas", "cudf"] = "pandas", - files_per_partition: int = 1, - add_filename: bool = False, columns: Optional[List[str]] = None, **kwargs, ) -> "DocumentDataset": @@ -142,8 +144,6 @@ def read_pickle( input_files=input_files, file_type="pickle", backend=backend, - files_per_partition=files_per_partition, - add_filename=add_filename, columns=columns, **kwargs, ) @@ -234,8 +234,9 @@ def _read_json_or_parquet( input_files: Union[str, List[str]], file_type: str, backend: Literal["cudf", "pandas"], - files_per_partition: int, add_filename: bool, + files_per_partition: Optional[int] = None, + blocksize: Optional[str] = None, input_meta: Union[str, dict] = None, columns: Optional[List[str]] = None, **kwargs, @@ -267,6 +268,7 @@ def _read_json_or_parquet( file_type=file_type, backend=backend, files_per_partition=files_per_partition, + blocksize=blocksize, add_filename=add_filename, input_meta=input_meta, columns=columns, @@ -286,6 +288,7 @@ def _read_json_or_parquet( file_type=file_type, backend=backend, files_per_partition=files_per_partition, + blocksize=blocksize, add_filename=add_filename, input_meta=input_meta, columns=columns, @@ -311,6 +314,7 @@ def _read_json_or_parquet( file_type=file_type, backend=backend, files_per_partition=files_per_partition, + blocksize=blocksize, add_filename=add_filename, input_meta=input_meta, columns=columns, diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index 75931b95..89b4415f 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -17,6 +17,10 @@ import os import shutil +import dask + +from nemo_curator._compat import DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA + os.environ["RAPIDS_NO_INITIALIZE"] = "1" import random import warnings @@ -24,7 +28,7 @@ from datetime import datetime from itertools import zip_longest from pathlib import Path -from typing import Dict, List, Literal, Optional, Union +from typing import Callable, Dict, List, Literal, Optional, Union import dask.dataframe as dd import numpy as np @@ -269,15 +273,30 @@ def _set_torch_to_use_rmm(): torch.cuda.memory.change_current_allocator(rmm_torch_allocator) +def select_columns( + df: Union[dd.DataFrame, pd.DataFrame, "cudf.DataFrame"], + columns: List[str], + filetype: Literal["jsonl", "json", "parquet"], + add_filename: bool, +) -> Union[dd.DataFrame, pd.DataFrame, "cudf.DataFrame"]: + # We exclude parquet because the parquet readers already support column selection + if filetype in ["jsonl", "json"] and columns is not None: + if add_filename and "filename" not in columns: + columns.append("filename") + df = df[columns] + + return df + + def read_single_partition( files: List[str], backend: Literal["cudf", "pandas"] = "cudf", filetype: str = "jsonl", add_filename: bool = False, input_meta: Union[str, dict] = None, - columns: Optional[List[str]] = None, + io_columns: Optional[List[str]] = None, **kwargs, -) -> Union[cudf.DataFrame, pd.DataFrame]: +) -> Union["cudf.DataFrame", pd.DataFrame]: """ This function reads a file 
with cuDF, sorts the columns of the DataFrame and adds a "filename" column. @@ -315,9 +334,13 @@ def read_single_partition( read_kwargs["dtype"] = ( ast.literal_eval(input_meta) if type(input_meta) == str else input_meta ) + # because pandas doesn't support `prune_columns`, it'll always return all columns even when input_meta is specified + # to maintain consistency we explicitly set `io_columns` here + if backend == "pandas" and not io_columns: + io_columns = list(read_kwargs["dtype"].keys()) elif filetype == "parquet": - read_kwargs = {"columns": columns} + read_kwargs = {"columns": io_columns} if backend == "cudf": read_f = cudf.read_parquet else: @@ -346,18 +369,133 @@ def read_single_partition( df = read_f(file, **read_kwargs, **kwargs) if add_filename: df["filename"] = os.path.basename(file) + df = select_columns(df, io_columns, filetype, add_filename) df_ls.append(df) + df = concat_f(df_ls, ignore_index=True) else: df = read_f(files, **read_kwargs, **kwargs) + df = select_columns(df, io_columns, filetype, add_filename) + return df - if filetype in ["jsonl", "json"] and columns is not None: - if add_filename and "filename" not in columns: - columns.append("filename") - df = df[columns] - df = df[sorted(df.columns)] - return df +def read_data_blocksize( + input_files: List[str], + backend: Literal["cudf", "pandas"], + file_type: Literal["parquet", "jsonl"], + blocksize: str, + add_filename: bool = False, + input_meta: Union[str, dict] = None, + columns: Optional[List[str]] = None, + **kwargs, +) -> dd.DataFrame: + + read_kwargs = dict() + + postprocessing_func: Optional[Callable[[dd.DataFrame], dd.DataFrame]] = None + if file_type == "jsonl": + warnings.warn( + "If underlying JSONL data does not have a consistent schema, reading with blocksize will fail. " + "Please use files_per_partition approach instead." + ) + + if backend == "pandas": + warnings.warn( + "Pandas backend with blocksize cannot read multiple JSONL files into a single partition. " + "Please use files_per_partition if blocksize exceeds average file size." + ) + read_func = dd.read_json + read_kwargs["lines"] = True + if input_meta is not None: + if backend == "cudf": + # To save GPU memory, we prune columns while reading, and keep only those that are + # specified in the input_meta + read_kwargs["prune_columns"] = True + + read_kwargs["dtype"] = ( + ast.literal_eval(input_meta) + if isinstance(input_meta, str) + else input_meta + ) + + if not columns: + # To maintain consistency with the behavior of `read_data_fpp` where passing `input_meta` + # only returns those columns, we explicitly set `columns` here + columns = list(read_kwargs["dtype"].keys()) + if add_filename: + + def extract_filename(path: str) -> str: + return os.path.basename(path) + + read_kwargs["include_path_column"] = add_filename + read_kwargs["path_converter"] = extract_filename + postprocessing_func = lambda df: df.rename(columns={"path": "filename"}) + + elif file_type == "parquet": + if backend == "cudf" and not DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA: + warnings.warn( + "If underlying Parquet data does not have consistent schema, reading with blocksize will fail. " + "Please update underlying RAPIDS package to version 25.02 or higher, or use files_per_partition approach instead." + ) + elif backend == "pandas": + warnings.warn( + "If underlying Parquet data does not have a consistent column order, reading with blocksize might fail. " + "Please use files_per_partition approach instead." 
+ ) + + if add_filename: + msg = "add_filename and blocksize cannot be set at the same time for Parquet files." + raise ValueError(msg) + read_func = dd.read_parquet + read_kwargs["columns"] = columns + # In dask_cudf >= 24.12, aggregate_files is not required, but we've kept here until + # it gets in dask (pandas) as well + read_kwargs["aggregate_files"] = True + else: + msg = f"Reading with blocksize is only supported for JSONL and Parquet files, not {file_type=}" + raise ValueError(msg) + + with dask.config.set({"dataframe.backend": backend}): + df = read_func(input_files, blocksize=blocksize, **read_kwargs, **kwargs) + if postprocessing_func is not None: + df = postprocessing_func(df) + + output = select_columns(df, columns, file_type, add_filename) + return output[sorted(output.columns)] + + +def read_data_files_per_partition( + input_files: List[str], + file_type: Literal["parquet", "json", "jsonl"], + backend: Literal["cudf", "pandas"] = "cudf", + add_filename: bool = False, + files_per_partition: Optional[int] = None, + input_meta: Union[str, dict] = None, + columns: Optional[List[str]] = None, + **kwargs, +) -> dd.DataFrame: + input_files = sorted(input_files) + if files_per_partition > 1: + input_files = [ + input_files[i : i + files_per_partition] + for i in range(0, len(input_files), files_per_partition) + ] + else: + input_files = [[file] for file in input_files] + + output = dd.from_map( + read_single_partition, + input_files, + filetype=file_type, + backend=backend, + add_filename=add_filename, + input_meta=input_meta, + enforce_metadata=False, + io_columns=columns, + **kwargs, + ) + output = output[sorted(output.columns)] + return output def read_pandas_pickle( @@ -390,12 +528,13 @@ def read_data( input_files: Union[str, List[str]], file_type: str = "pickle", backend: Literal["cudf", "pandas"] = "cudf", - files_per_partition: int = 1, + blocksize: Optional[str] = None, + files_per_partition: Optional[int] = 1, add_filename: bool = False, input_meta: Union[str, dict] = None, columns: Optional[List[str]] = None, **kwargs, -) -> Union[dd.DataFrame, dask_cudf.DataFrame]: +) -> dd.DataFrame: """ This function can read multiple data formats and returns a Dask-cuDF DataFrame. @@ -414,13 +553,8 @@ def read_data( A Dask-cuDF or a Dask-pandas DataFrame. """ - if backend == "cudf": - # Try using cuDF. If not availible will throw an error. - test_obj = cudf.Series - if isinstance(input_files, str): input_files = [input_files] - if file_type == "pickle": df = read_pandas_pickle( input_files[0], add_filename=add_filename, columns=columns, **kwargs @@ -442,30 +576,44 @@ def read_data( "function with the `keep_extensions` parameter." 
) - print(f"Reading {len(input_files)} files", flush=True) - input_files = sorted(input_files) - - if files_per_partition > 1: - input_files = [ - input_files[i : i + files_per_partition] - for i in range(0, len(input_files), files_per_partition) - ] - - else: - input_files = [[file] for file in input_files] - - return dd.from_map( - read_single_partition, - input_files, - filetype=file_type, - backend=backend, - add_filename=add_filename, - input_meta=input_meta, - enforce_metadata=False, - columns=columns, - **kwargs, + print( + f"Reading {len(input_files)} files with {blocksize=} / {files_per_partition=}", + flush=True, ) - + if blocksize is not None and files_per_partition is not None: + msg = "blocksize and files_per_partition cannot be set at the same time" + raise ValueError(msg) + + if blocksize is not None and ( + file_type == "jsonl" or (file_type == "parquet" and not add_filename) + ): + return read_data_blocksize( + input_files, + backend=backend, + file_type=file_type, + blocksize=blocksize, + add_filename=add_filename, + input_meta=input_meta, + columns=columns, + **kwargs, + ) + else: + if backend == "cudf" and ( + file_type == "jsonl" or (file_type == "parquet" and not add_filename) + ): + warnings.warn( + "Consider passing in blocksize for better control over memory usage." + ) + return read_data_files_per_partition( + input_files, + file_type=file_type, + backend=backend, + add_filename=add_filename, + files_per_partition=files_per_partition, + input_meta=input_meta, + columns=columns, + **kwargs, + ) else: raise RuntimeError("Could not read data, please check file type") diff --git a/tests/test_io.py b/tests/test_io.py index 76fd6ade..432c00b3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -113,9 +113,7 @@ def test_meta_dict(self, jsonl_dataset): output_meta = str({col: str(dtype) for col, dtype in dataset.df.dtypes.items()}) - expected_meta = ( - "{'date': 'datetime64[ns, UTC]', 'id': 'float64', 'text': 'object'}" - ) + expected_meta = "{'id': 'float64'}" assert ( output_meta == expected_meta @@ -139,9 +137,7 @@ def test_meta_str(self, jsonl_dataset): output_meta = str({col: str(dtype) for col, dtype in dataset.df.dtypes.items()}) - expected_meta = ( - "{'date': 'datetime64[ns, UTC]', 'id': 'float64', 'text': 'object'}" - ) + expected_meta = "{'id': 'float64'}" assert ( output_meta == expected_meta @@ -240,6 +236,7 @@ def test_multifile_multi_partition(self, tmp_path, file_ext, read_f): got_df = read_f( str(tmp_path / file_ext), + blocksize=None, files_per_partition=2, backend="pandas", add_filename=True, diff --git a/tests/test_read_data.py b/tests/test_read_data.py new file mode 100644 index 00000000..a619be3a --- /dev/null +++ b/tests/test_read_data.py @@ -0,0 +1,586 @@ +import tempfile + +import pandas as pd +import pytest + +from nemo_curator._compat import DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA +from nemo_curator.utils.distributed_utils import ( + read_data, + read_data_blocksize, + read_data_files_per_partition, +) +from nemo_curator.utils.file_utils import get_all_files_paths_under + +NUM_FILES = 5 +NUM_RECORDS = 100 + + +# Fixture to create multiple small JSONL files +@pytest.fixture +def mock_multiple_jsonl_files(tmp_path): + file_paths = [] + for file_id in range(NUM_FILES): + jsonl_file = tmp_path / f"test_{file_id}.jsonl" + with open(jsonl_file, "w") as f: + for record_id in range(NUM_RECORDS): + # 100 rows are ~5kb + f.write( + f'{{"id": "id_{file_id}_{record_id}", "text": "A longish string {file_id}_{record_id}"}}\n' + ) + 
file_paths.append(str(jsonl_file)) + return file_paths + + +# Fixture to create multiple small Parquet files +@pytest.fixture +def mock_multiple_parquet_files(tmp_path): + file_paths = [] + for file_id in range(NUM_FILES): + # 100 rows are ~5kb + parquet_file = tmp_path / f"test_{file_id}.parquet" + df = pd.DataFrame( + [ + { + "id": f"id_{file_id}_{record_id}", + "text": f"A string {file_id}_{record_id}", + } + for record_id in range(NUM_RECORDS) + ] + ) + # We specify row_group_size so that we can test splitting a single big file into smaller chunks + df.to_parquet(parquet_file, compression=None, row_group_size=10) + file_paths.append(str(parquet_file)) + return file_paths + + +@pytest.fixture +def mock_multiple_jsonl_files_different_cols(tmp_path): + file_paths = [] + for file_id in range(NUM_FILES): + jsonl_file = tmp_path / f"different_cols_test_{file_id}.jsonl" + + def make_record_without_meta(file_id, record_id): + return { + "id": f"id_{file_id}_{record_id}", + "text": f"A string {file_id}_{record_id}", + } + + def make_record_with_meta(file_id, record_id): + return { + "text": f"A string {file_id}_{record_id}", + "meta1": [ + {"field1": "field_one", "field2": "field_two"}, + ], + "id": f"id_{file_id}_{record_id}", + } + + df = pd.DataFrame( + [ + ( + make_record_without_meta(file_id, record_id) + if file_id == 0 + else make_record_with_meta(file_id, record_id) + ) + for record_id in range(NUM_RECORDS) + ] + ) + + df.to_json(jsonl_file, orient="records", lines=True) + file_paths.append(str(jsonl_file)) + return file_paths + + +# Fixture to create multiple small Parquet files +@pytest.fixture +def mock_multiple_parquet_files_different_cols(tmp_path): + file_paths = [] + for file_id in range(NUM_FILES): + # 100 rows are ~5kb + parquet_file = tmp_path / f"test_diff_cols_{file_id}.parquet" + + def make_record_without_meta(file_id, record_id): + return { + "id": f"id_{file_id}_{record_id}", + "text": f"A string {file_id}_{record_id}", + } + + def make_record_with_meta(file_id, record_id): + return { + "text": f"A string {file_id}_{record_id}", + "meta1": [ + {"field1": "field_one", "field2": "field_two"}, + ], + "id": f"id_{file_id}_{record_id}", + } + + df = pd.DataFrame( + [ + ( + make_record_without_meta(file_id, record_id) + if file_id == 0 + else make_record_with_meta(file_id, record_id) + ) + for record_id in range(NUM_RECORDS) + ] + ) + df.to_parquet(parquet_file, compression=None, row_group_size=10) + file_paths.append(str(parquet_file)) + return file_paths + + +@pytest.mark.gpu +@pytest.mark.parametrize("file_type", ["jsonl", "parquet"]) +@pytest.mark.parametrize("blocksize", ["1kb", "5kb", "10kb"]) +def test_cudf_read_data_blocksize_partitioning( + mock_multiple_jsonl_files, mock_multiple_parquet_files, file_type, blocksize +): + import cudf + + input_files = ( + mock_multiple_jsonl_files + if file_type == "jsonl" + else mock_multiple_parquet_files + ) + + df = read_data_blocksize( + input_files=input_files, + backend="cudf", + file_type=file_type, + blocksize=blocksize, + add_filename=False, + input_meta=None, + columns=None, + ) + + # Compute the number of partitions in the resulting DataFrame + num_partitions = df.optimize().npartitions + # Assert that we have two partitions (since we have ~15KB total data and a blocksize of 10KB) + if blocksize == "1kb": + assert ( + num_partitions > NUM_FILES + ), f"Expected > {NUM_FILES} partitions but got {num_partitions}" + elif blocksize == "5kb": + assert ( + num_partitions == NUM_FILES + ), f"Expected {NUM_FILES} partitions but got 
{num_partitions}" + elif blocksize == "10kb": + assert ( + num_partitions < NUM_FILES + ), f"Expected < {NUM_FILES} partitions but got {num_partitions}" + else: + raise ValueError(f"Invalid blocksize: {blocksize}") + total_rows = len(df) + assert ( + total_rows == NUM_FILES * NUM_RECORDS + ), f"Expected {NUM_FILES * NUM_RECORDS} rows but got {total_rows}" + + assert isinstance(df["id"].compute(), cudf.Series) + + +@pytest.mark.parametrize("file_type", ["jsonl", "parquet"]) +@pytest.mark.parametrize("blocksize", ["1kb", "5kb", "10kb"]) +def test_pandas_read_data_blocksize_partitioning( + mock_multiple_jsonl_files, mock_multiple_parquet_files, file_type, blocksize +): + input_files = ( + mock_multiple_jsonl_files + if file_type == "jsonl" + else mock_multiple_parquet_files + ) + + df = read_data_blocksize( + input_files=input_files, + backend="pandas", + file_type=file_type, + blocksize=blocksize, + add_filename=False, + input_meta=None, + columns=None, + ) + + # Compute the number of partitions in the resulting DataFrame + num_partitions = df.npartitions + # Our total data is ~25kb where each file is 5kb + if blocksize == "1kb": + assert ( + num_partitions > NUM_FILES + ), f"Expected > {NUM_FILES} partitions but got {num_partitions}" + elif blocksize == "5kb": + assert ( + num_partitions == NUM_FILES + ), f"Expected {NUM_FILES} partitions but got {num_partitions}" + elif blocksize == "10kb": + # Because pandas doesn't suppport reading json files together, a partition will only be as big as a single file + if file_type == "jsonl": + assert ( + num_partitions == NUM_FILES + ), f"Expected {NUM_FILES} partitions but got {num_partitions}" + # Parquet files can be read together + elif file_type == "parquet": + assert ( + num_partitions < NUM_FILES + ), f"Expected > {NUM_FILES} partitions but got {num_partitions}" + else: + raise ValueError(f"Invalid blocksize: {blocksize}") + total_rows = len(df) + assert ( + total_rows == NUM_FILES * NUM_RECORDS + ), f"Expected {NUM_FILES * NUM_RECORDS} rows but got {total_rows}" + + assert isinstance(df["id"].compute(), pd.Series) + + +@pytest.mark.parametrize( + "backend", + ["pandas", pytest.param("cudf", marks=pytest.mark.gpu)], +) +@pytest.mark.parametrize("file_type", ["jsonl", "parquet"]) +@pytest.mark.parametrize("fpp", [1, NUM_FILES // 2, NUM_FILES, NUM_FILES * 2]) +def test_read_data_fpp_partitioning( + mock_multiple_jsonl_files, mock_multiple_parquet_files, backend, file_type, fpp +): + input_files = ( + mock_multiple_jsonl_files + if file_type == "jsonl" + else mock_multiple_parquet_files + ) + + df = read_data_files_per_partition( + input_files=input_files, + backend=backend, + file_type=file_type, + files_per_partition=fpp, + add_filename=False, + input_meta=None, + columns=None, + ) + + # Compute the number of partitions in the resulting DataFrame + num_partitions = df.npartitions + # Assert that we have two partitions (since we have ~15KB total data and a blocksize of 10KB) + if fpp == 1: + assert ( + num_partitions == NUM_FILES + ), f"Expected {NUM_FILES} partitions but got {num_partitions}" + elif fpp == NUM_FILES // 2: + assert ( + num_partitions < NUM_FILES + ), f"Expected {NUM_FILES} partitions but got {num_partitions}" + elif fpp >= NUM_FILES: + assert num_partitions == 1, f"Expected 1 partition but got {num_partitions}" + else: + raise ValueError(f"Invalid fpp: {fpp}") + total_rows = len(df) + assert ( + total_rows == NUM_FILES * NUM_RECORDS + ), f"Expected {NUM_FILES * NUM_RECORDS} rows but got {total_rows}" + if backend == "cudf": + 
import cudf + + assert isinstance(df["id"].compute(), cudf.Series) + elif backend == "pandas": + assert isinstance(df["id"].compute(), pd.Series) + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +def test_read_data_blocksize_add_filename_jsonl(mock_multiple_jsonl_files, backend): + df = read_data_blocksize( + input_files=mock_multiple_jsonl_files, + backend=backend, + file_type="jsonl", + blocksize="128Mib", + add_filename=True, + input_meta=None, + columns=None, + ) + + assert "filename" in df.columns + file_names = df["filename"].unique().compute() + if backend == "cudf": + file_names = file_names.to_pandas() + + assert len(file_names) == NUM_FILES + assert set(file_names.values) == { + f"test_{file_id}.jsonl" for file_id in range(NUM_FILES) + } + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +def test_read_data_blocksize_add_filename_parquet(mock_multiple_parquet_files, backend): + with pytest.raises( + ValueError, + match="add_filename and blocksize cannot be set at the same time for Parquet files", + ): + read_data_blocksize( + input_files=mock_multiple_parquet_files, + backend=backend, + file_type="parquet", + blocksize="128Mib", + add_filename=True, + input_meta=None, + columns=None, + ) + + +@pytest.mark.parametrize( + "backend,file_type", + [ + pytest.param("cudf", "jsonl", marks=pytest.mark.gpu), + pytest.param("cudf", "parquet", marks=pytest.mark.gpu), + ("pandas", "jsonl"), + pytest.param( + "pandas", + "parquet", + marks=pytest.mark.xfail( + reason="filename column inaccessible with pandas backend and parquet" + ), + ), + ], +) +def test_read_data_fpp_add_filename( + mock_multiple_jsonl_files, mock_multiple_parquet_files, backend, file_type +): + input_files = ( + mock_multiple_jsonl_files + if file_type == "jsonl" + else mock_multiple_parquet_files + ) + + df = read_data_files_per_partition( + input_files=input_files, + backend=backend, + file_type=file_type, + files_per_partition=NUM_FILES, + add_filename=True, + input_meta=None, + columns=None, + ) + + assert list(df.columns) == list(df.head().columns) + assert set(df.columns) == {"filename", "id", "text"} + file_names = df["filename"].unique().compute() + if backend == "cudf": + file_names = file_names.to_pandas() + + assert len(file_names) == NUM_FILES + assert set(file_names.values) == { + f"test_{file_id}.{file_type}" for file_id in range(NUM_FILES) + } + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + "file_type,add_filename,function_name", + [ + *[ + ("jsonl", True, func) + for func in ["read_data_blocksize", "read_data_files_per_partition"] + ], + *[ + ("jsonl", False, func) + for func in ["read_data_blocksize", "read_data_files_per_partition"] + ], + *[ + ("parquet", False, func) + for func in ["read_data_blocksize", "read_data_files_per_partition"] + ], + *[("parquet", True, "read_data_files_per_partition")], + ], +) +@pytest.mark.parametrize( + "cols_to_select", [None, ["id"], ["text", "id"], ["id", "text"]] +) +def test_read_data_select_columns( + mock_multiple_jsonl_files, + mock_multiple_parquet_files, + backend, + file_type, + add_filename, + function_name, + cols_to_select, +): + input_files = ( + mock_multiple_jsonl_files + if file_type == "jsonl" + else mock_multiple_parquet_files + ) + if function_name == "read_data_files_per_partition": + func = read_data_files_per_partition + 
read_kwargs = {"files_per_partition": 1} + elif function_name == "read_data_blocksize": + func = read_data_blocksize + read_kwargs = {"blocksize": "128Mib"} + + df = func( + input_files=input_files, + backend=backend, + file_type=file_type, + add_filename=add_filename, + input_meta=None, + columns=list(cols_to_select) if cols_to_select else None, + **read_kwargs, + ) + if not cols_to_select: + cols_to_select = ["id", "text"] + + assert list(df.columns) == list(df.head().columns) + if not add_filename: + assert list(df.columns) == sorted(cols_to_select) + else: + assert list(df.columns) == sorted(cols_to_select + ["filename"]) + + +@pytest.mark.parametrize( + "backend", + [ + "pandas", + pytest.param("cudf", marks=pytest.mark.gpu), + ], +) +@pytest.mark.parametrize( + "function_name", ["read_data_blocksize", "read_data_files_per_partition"] +) +@pytest.mark.parametrize( + "input_meta", [{"id": "str"}, {"text": "str"}, {"id": "str", "text": "str"}] +) +def test_read_data_input_meta( + mock_multiple_jsonl_files, backend, function_name, input_meta +): + if function_name == "read_data_files_per_partition": + func = read_data_files_per_partition + read_kwargs = {"files_per_partition": 1} + elif function_name == "read_data_blocksize": + func = read_data_blocksize + read_kwargs = {"blocksize": "128Mib"} + + df = func( + input_files=mock_multiple_jsonl_files, + backend=backend, + file_type="jsonl", + add_filename=False, + input_meta=input_meta, + columns=None, + **read_kwargs, + ) + + assert list(df.columns) == list(input_meta.keys()) + + +def xfail_inconsistent_schema_jsonl(): + return pytest.mark.xfail( + reason="inconsistent schemas are not supported with jsonl files, " + "see https://github.com/dask/dask/issues/11595" + ) + + +@pytest.mark.parametrize( + "backend", + [ + pytest.param("pandas"), + pytest.param("cudf", marks=[pytest.mark.gpu]), + ], +) +@pytest.mark.parametrize("file_type", ["jsonl", "parquet"]) +@pytest.mark.parametrize("fpp", [1, 3, 5]) +def test_read_data_different_columns_files_per_partition( + mock_multiple_jsonl_files_different_cols, + mock_multiple_parquet_files_different_cols, + backend, + file_type, + fpp, +): + read_kwargs = {"columns": ["id", "text"]} + if file_type == "jsonl": + input_files = mock_multiple_jsonl_files_different_cols + read_kwargs["input_meta"] = {"id": "str", "text": "str"} + elif file_type == "parquet": + input_files = mock_multiple_parquet_files_different_cols + if backend == "cudf": + read_kwargs["allow_mismatched_pq_schemas"] = True + + df = read_data( + input_files=input_files, + file_type=file_type, + backend=backend, + add_filename=False, + files_per_partition=fpp, + blocksize=None, + **read_kwargs, + ) + assert list(df.columns) == ["id", "text"] + assert list(df.compute().columns) == ["id", "text"] + with tempfile.TemporaryDirectory() as tmpdir: + df.to_parquet(tmpdir) + assert len(df) == NUM_FILES * NUM_RECORDS + + +@pytest.mark.parametrize( + "backend,file_type", + [ + pytest.param( + "cudf", "jsonl", marks=[pytest.mark.gpu, xfail_inconsistent_schema_jsonl()] + ), + pytest.param("pandas", "jsonl", marks=[xfail_inconsistent_schema_jsonl()]), + pytest.param( + "cudf", + "parquet", + marks=[pytest.mark.gpu] + + ( + [xfail_inconsistent_schema_jsonl()] + if not DASK_CUDF_PARQUET_READ_INCONSISTENT_SCHEMA + else [] + ), + ), + pytest.param("pandas", "parquet"), + ], +) +@pytest.mark.parametrize("blocksize", ["1kb", "5kb", "10kb"]) +def test_read_data_different_columns_blocksize( + mock_multiple_jsonl_files_different_cols, + 
mock_multiple_parquet_files_different_cols, + backend, + file_type, + blocksize, +): + read_kwargs = {"columns": ["id", "text"]} + read_kwargs["columns"] = ["id", "text"] + if file_type == "jsonl": + input_files = mock_multiple_jsonl_files_different_cols + read_kwargs["input_meta"] = {"id": "str", "text": "str"} + elif file_type == "parquet": + input_files = mock_multiple_parquet_files_different_cols + if backend == "cudf": + read_kwargs["allow_mismatched_pq_schemas"] = True + + df = read_data( + input_files=input_files, + file_type=file_type, + blocksize=blocksize, + files_per_partition=None, + backend=backend, + add_filename=False, + **read_kwargs, + ) + assert list(df.columns) == ["id", "text"] + assert list(df.compute().columns) == ["id", "text"] + with tempfile.TemporaryDirectory() as tmpdir: + df.to_parquet(tmpdir) + assert len(df) == NUM_FILES * NUM_RECORDS diff --git a/tests/test_separate_by_metadata.py b/tests/test_separate_by_metadata.py index 3e01d7f0..020bf21d 100644 --- a/tests/test_separate_by_metadata.py +++ b/tests/test_separate_by_metadata.py @@ -61,6 +61,7 @@ def test_metadatasep( str(data_dir), backend=backend, files_per_partition=files_per_partition, + blocksize=None, add_filename=True, ).df separate_by_metadata( @@ -80,6 +81,7 @@ def test_metadatasep( str(output_dir / metadata), backend=backend, files_per_partition=1, + blocksize=None, add_filename=True, ).df dfs.append(meta_df)
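Usage sketch for the reading options introduced in this change. The keyword arguments mirror the updated ``DocumentDataset.read_json`` / ``read_parquet`` signatures from the diff above; the directory paths, column names, and ``input_meta`` dtypes are placeholders, not part of this change.

from nemo_curator.datasets import DocumentDataset

# Size partitions by bytes: read JSONL in ~1GB chunks.
# blocksize and files_per_partition are mutually exclusive in read_data,
# so the unused option is left as None.
jsonl_ds = DocumentDataset.read_json(
    "data/jsonl_dir",  # placeholder path
    backend="cudf",
    blocksize="1gb",
    files_per_partition=None,
    input_meta={"id": "str", "text": "str"},  # placeholder dtypes
    columns=["id", "text"],
)

# Fall back to files_per_partition when add_filename=True is needed for
# Parquet, or when the underlying files do not share a consistent schema.
parquet_ds = DocumentDataset.read_parquet(
    "data/parquet_dir",  # placeholder path
    backend="cudf",
    blocksize=None,
    files_per_partition=2,
    add_filename=True,
)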