diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d179a2a57..baa968f47 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -40,9 +40,8 @@ jobs: # Explicitly install cython: https://github.com/VKCOM/YouTokenToMe/issues/94 run: | pip install wheel cython - pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com . + pip install --no-cache-dir . pip install pytest - name: Run tests - # TODO: Remove env variable when gpu dependencies are optional run: | - RAPIDS_NO_INITIALIZE=1 python -m pytest -v --cpu + python -m pytest -v --cpu diff --git a/README.md b/README.md index eb8c37abe..a17a573eb 100644 --- a/README.md +++ b/README.md @@ -37,12 +37,20 @@ These modules are designed to be flexible and allow for reordering with few exce ## Installation -NeMo Curator currently requires Python 3.10 and a GPU with CUDA 12 or above installed in order to be used. +NeMo Curator currently requires Python 3.10, and the GPU-accelerated modules require CUDA 12 or above to be used. -NeMo Curator can be installed manually by cloning the repository and installing as follows: +NeMo Curator can be installed manually by cloning the repository and installing as follows: + +For CPU-only modules: +``` +pip install . ``` -pip install --extra-index-url https://pypi.nvidia.com . + +For CPU + CUDA accelerated modules: ``` +pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]" +``` + ### NeMo Framework Container NeMo Curator is available in the [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo). The NeMo Framework Container provides an end-to-end platform for development of custom generative AI models anywhere. The latest release of NeMo Curator comes preinstalled in the container. diff --git a/nemo_curator/datasets/doc_dataset.py b/nemo_curator/datasets/doc_dataset.py index af45f290c..37592b188 100644 --- a/nemo_curator/datasets/doc_dataset.py +++ b/nemo_curator/datasets/doc_dataset.py @@ -13,7 +13,6 @@ # limitations under the License. import dask.dataframe as dd -import dask_cudf from nemo_curator.utils.distributed_utils import read_data, write_to_disk from nemo_curator.utils.file_utils import get_all_files_paths_under @@ -182,10 +181,7 @@ def _read_json_or_parquet( ) dfs.append(df) - if backend == "cudf": - raw_data = dask_cudf.concat(dfs, ignore_unknown_divisions=True) - else: - raw_data = dd.concat(dfs, ignore_unknown_divisions=True) + raw_data = dd.concat(dfs, ignore_unknown_divisions=True) elif isinstance(input_files, str): # Single file diff --git a/nemo_curator/gpu_deduplication/utils.py b/nemo_curator/gpu_deduplication/utils.py index ed69477be..f6faefe77 100644 --- a/nemo_curator/gpu_deduplication/utils.py +++ b/nemo_curator/gpu_deduplication/utils.py @@ -13,84 +13,8 @@ # limitations under the License. 
import argparse -import logging -import os -import socket -from contextlib import nullcontext from time import time -import cudf -from dask_cuda import LocalCUDACluster -from distributed import Client, performance_report - - -def create_logger(rank, log_file, name="logger", log_level=logging.INFO): - # Create the logger - logger = logging.getLogger(name) - logger.setLevel(log_level) - - myhost = socket.gethostname() - - extra = {"host": myhost, "rank": rank} - formatter = logging.Formatter( - "%(asctime)s | %(host)s | Rank %(rank)s | %(message)s" - ) - - # File handler for output - file_handler = logging.FileHandler(log_file, mode="a") - file_handler.setFormatter(formatter) - logger.addHandler(file_handler) - logger = logging.LoggerAdapter(logger, extra) - - return logger - - -# TODO: Remove below to use nemo_curator.distributed_utils.get_client -def get_client(args) -> Client: - if args.scheduler_address: - if args.scheduler_file: - raise ValueError( - "Only one of scheduler_address or scheduler_file can be provided" - ) - else: - return Client(address=args.scheduler_address, timeout="30s") - elif args.scheduler_file: - return Client(scheduler_file=args.scheduler_file, timeout="30s") - else: - extra_kwargs = ( - { - "enable_tcp_over_ucx": True, - "enable_nvlink": True, - "enable_infiniband": False, - "enable_rdmacm": False, - } - if args.nvlink_only and args.protocol == "ucx" - else {} - ) - - cluster = LocalCUDACluster( - rmm_pool_size=args.rmm_pool_size, - protocol=args.protocol, - rmm_async=True, - **extra_kwargs, - ) - return Client(cluster) - - -def performance_report_if(path=None, report_name="dask-profile.html"): - if path is not None: - return performance_report(os.path.join(path, report_name)) - else: - return nullcontext() - - -# TODO: Remove below to use nemo_curator.distributed_utils._enable_spilling -def enable_spilling(): - """ - Enables spilling to host memory for cudf - """ - cudf.set_option("spill", True) - def get_num_workers(client): """ diff --git a/nemo_curator/modules/__init__.py b/nemo_curator/modules/__init__.py index d7c099803..434ebecf4 100644 --- a/nemo_curator/modules/__init__.py +++ b/nemo_curator/modules/__init__.py @@ -19,14 +19,19 @@ # See https://github.com/NVIDIA/NeMo-Curator/issues/31 os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" +from nemo_curator.utils.import_utils import gpu_only_import_from + from .add_id import AddId from .exact_dedup import ExactDuplicates from .filter import Filter, Score, ScoreFilter -from .fuzzy_dedup import LSH, MinHash from .meta import Sequential from .modify import Modify from .task import TaskDecontamination +# GPU packages +LSH = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup", "LSH") +MinHash = gpu_only_import_from("nemo_curator.modules.fuzzy_dedup", "MinHash") + # Pytorch related imports must come after all imports that require cugraph, # because of context cleanup issues b/w pytorch and cugraph # See this issue: https://github.com/rapidsai/cugraph/issues/2718 diff --git a/nemo_curator/modules/exact_dedup.py b/nemo_curator/modules/exact_dedup.py index 5d960ac6e..2831f516f 100644 --- a/nemo_curator/modules/exact_dedup.py +++ b/nemo_curator/modules/exact_dedup.py @@ -28,7 +28,8 @@ from nemo_curator._compat import DASK_P2P_ERROR from nemo_curator.datasets import DocumentDataset -from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if +from nemo_curator.log import create_logger +from nemo_curator.utils.distributed_utils import performance_report_if from nemo_curator.utils.gpu_utils 
import is_cudf_type diff --git a/nemo_curator/modules/fuzzy_dedup.py b/nemo_curator/modules/fuzzy_dedup.py index 3b0576058..b51499678 100644 --- a/nemo_curator/modules/fuzzy_dedup.py +++ b/nemo_curator/modules/fuzzy_dedup.py @@ -22,12 +22,12 @@ from typing import List, Tuple, Union import cudf -import cugraph import cugraph.dask as dcg import cugraph.dask.comms.comms as Comms import cupy as cp import dask_cudf import numpy as np +from cugraph import MultiGraph from dask import dataframe as dd from dask.dataframe.shuffle import shuffle as dd_shuffle from dask.utils import M @@ -39,12 +39,13 @@ filter_text_rows_by_bucket_batch, merge_left_to_shuffled_right, ) -from nemo_curator.gpu_deduplication.utils import create_logger, performance_report_if -from nemo_curator.utils.distributed_utils import get_current_client, get_num_workers -from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import ( - convert_str_id_to_int, - int_ids_to_str, +from nemo_curator.log import create_logger +from nemo_curator.utils.distributed_utils import ( + get_current_client, + get_num_workers, + performance_report_if, ) +from nemo_curator.utils.fuzzy_dedup_utils.id_mapping import int_ids_to_str from nemo_curator.utils.fuzzy_dedup_utils.io_utils import ( aggregated_anchor_docs_with_bk_read, get_restart_offsets, @@ -1120,7 +1121,7 @@ def _run_connected_components( df = df[[self.left_id, self.right_id]].astype(np.int64) df = dask_cudf.concat([df, self_edge_df]) - G = cugraph.MultiGraph(directed=False) + G = MultiGraph(directed=False) G.from_dask_cudf_edgelist( df, source=self.left_id, destination=self.right_id, renumber=False ) diff --git a/nemo_curator/scripts/compute_minhashes.py b/nemo_curator/scripts/compute_minhashes.py index c7a7e68b2..044653ceb 100644 --- a/nemo_curator/scripts/compute_minhashes.py +++ b/nemo_curator/scripts/compute_minhashes.py @@ -18,12 +18,13 @@ from nemo_curator import MinHash from nemo_curator.datasets import DocumentDataset from nemo_curator.gpu_deduplication.ioutils import strip_trailing_sep -from nemo_curator.gpu_deduplication.utils import ( - create_logger, - parse_nc_args, +from nemo_curator.gpu_deduplication.utils import parse_nc_args +from nemo_curator.log import create_logger +from nemo_curator.utils.distributed_utils import ( + get_client, performance_report_if, + read_data, ) -from nemo_curator.utils.distributed_utils import get_client, read_data from nemo_curator.utils.file_utils import get_all_files_paths_under diff --git a/nemo_curator/scripts/connected_components.py b/nemo_curator/scripts/connected_components.py index 1ab1282af..c04f0349d 100644 --- a/nemo_curator/scripts/connected_components.py +++ b/nemo_curator/scripts/connected_components.py @@ -15,7 +15,7 @@ import os import time -from nemo_curator.gpu_deduplication.utils import enable_spilling, parse_nc_args +from nemo_curator.gpu_deduplication.utils import parse_nc_args from nemo_curator.modules.fuzzy_dedup import ConnectedComponents from nemo_curator.utils.distributed_utils import get_client @@ -32,9 +32,10 @@ def main(args): st = time.time() output_path = os.path.join(args.output_dir, "connected_components.parquet") args.set_torch_to_use_rmm = False + args.enable_spilling = True + client = get_client(args, cluster_type="gpu") - enable_spilling() - client.run(enable_spilling) + components_stage = ConnectedComponents( cache_dir=args.cache_dir, jaccard_pairs_path=args.jaccard_pairs_path, diff --git a/nemo_curator/scripts/find_exact_duplicates.py b/nemo_curator/scripts/find_exact_duplicates.py index 7da01ea8e..16173861d 
100644 --- a/nemo_curator/scripts/find_exact_duplicates.py +++ b/nemo_curator/scripts/find_exact_duplicates.py @@ -19,7 +19,8 @@ from nemo_curator.datasets import DocumentDataset from nemo_curator.gpu_deduplication.ioutils import strip_trailing_sep -from nemo_curator.gpu_deduplication.utils import create_logger, parse_nc_args +from nemo_curator.gpu_deduplication.utils import parse_nc_args +from nemo_curator.log import create_logger from nemo_curator.modules import ExactDuplicates from nemo_curator.utils.distributed_utils import get_client, read_data from nemo_curator.utils.file_utils import get_all_files_paths_under diff --git a/nemo_curator/scripts/jaccard_compute.py b/nemo_curator/scripts/jaccard_compute.py index f59157164..d16e95654 100644 --- a/nemo_curator/scripts/jaccard_compute.py +++ b/nemo_curator/scripts/jaccard_compute.py @@ -15,13 +15,13 @@ import os import time -from nemo_curator.gpu_deduplication.utils import enable_spilling, parse_nc_args +from nemo_curator.gpu_deduplication.utils import parse_nc_args from nemo_curator.modules.fuzzy_dedup import JaccardSimilarity from nemo_curator.utils.distributed_utils import get_client, get_num_workers def main(args): - description = """Computes the Jaccard similarity between document pairs + """Computes the Jaccard similarity between document pairs from partitioned parquet dataset. Result is a parquet dataset consisting of document id pairs along with their Jaccard similarity score. """ @@ -30,9 +30,9 @@ def main(args): output_final_results_path = os.path.join( OUTPUT_PATH, "jaccard_similarity_results.parquet" ) + args.enable_spilling = True client = get_client(args, "gpu") - enable_spilling() - client.run(enable_spilling) + print(f"Num Workers = {get_num_workers(client)}", flush=True) print("Connected to dask cluster", flush=True) print("Running jaccard compute script", flush=True) diff --git a/nemo_curator/scripts/jaccard_shuffle.py b/nemo_curator/scripts/jaccard_shuffle.py index dc5d20f9b..c01935a61 100644 --- a/nemo_curator/scripts/jaccard_shuffle.py +++ b/nemo_curator/scripts/jaccard_shuffle.py @@ -15,12 +15,9 @@ import os import time -from nemo_curator.gpu_deduplication.utils import ( - get_client, - get_num_workers, - parse_nc_args, -) +from nemo_curator.gpu_deduplication.utils import get_num_workers, parse_nc_args from nemo_curator.modules.fuzzy_dedup import _Shuffle +from nemo_curator.utils.distributed_utils import get_client from nemo_curator.utils.fuzzy_dedup_utils.io_utils import ( get_text_ddf_from_json_path_with_blocksize, ) @@ -38,7 +35,7 @@ def main(args): OUTPUT_PATH = args.output_dir output_shuffled_docs_path = os.path.join(OUTPUT_PATH, "shuffled_docs.parquet") - client = get_client(args) + client = get_client(args, "gpu") client.run(func) print(f"Num Workers = {get_num_workers(client)}", flush=True) print("Connected to dask cluster", flush=True) diff --git a/nemo_curator/scripts/map_buckets.py b/nemo_curator/scripts/map_buckets.py index 522e4f417..9e3f71a51 100644 --- a/nemo_curator/scripts/map_buckets.py +++ b/nemo_curator/scripts/map_buckets.py @@ -15,12 +15,9 @@ import os import time -from nemo_curator.gpu_deduplication.utils import ( - get_client, - get_num_workers, - parse_nc_args, -) +from nemo_curator.gpu_deduplication.utils import get_num_workers, parse_nc_args from nemo_curator.modules.fuzzy_dedup import _MapBuckets +from nemo_curator.utils.distributed_utils import get_client from nemo_curator.utils.fuzzy_dedup_utils.io_utils import ( get_bucket_ddf_from_parquet_path, get_text_ddf_from_json_path_with_blocksize, 
@@ -157,7 +154,7 @@ def main(args): output_anchor_docs_with_bk_path = os.path.join( OUTPUT_PATH, "anchor_docs_with_bk.parquet" ) - client = get_client(args) + client = get_client(args, "gpu") print(f"Num Workers = {get_num_workers(client)}", flush=True) print("Connected to dask cluster", flush=True) print("Running jaccard map buckets script", flush=True) diff --git a/nemo_curator/scripts/minhash_lsh.py b/nemo_curator/scripts/minhash_lsh.py index fb2c6a90d..ec206dc10 100644 --- a/nemo_curator/scripts/minhash_lsh.py +++ b/nemo_curator/scripts/minhash_lsh.py @@ -24,7 +24,8 @@ from nemo_curator.gpu_deduplication.jaccard_utils.doc_id_mapping import ( convert_str_id_to_int, ) -from nemo_curator.gpu_deduplication.utils import create_logger, parse_nc_args +from nemo_curator.gpu_deduplication.utils import parse_nc_args +from nemo_curator.log import create_logger from nemo_curator.utils.distributed_utils import get_client diff --git a/nemo_curator/utils/distributed_utils.py b/nemo_curator/utils/distributed_utils.py index 71fa1cdca..2d7dc9213 100644 --- a/nemo_curator/utils/distributed_utils.py +++ b/nemo_curator/utils/distributed_utils.py @@ -11,20 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations import os os.environ["RAPIDS_NO_INITIALIZE"] = "1" import warnings +from contextlib import nullcontext from pathlib import Path from typing import Union -import cudf import dask.dataframe as dd -import dask_cudf import pandas as pd -from dask.distributed import Client, LocalCluster, get_worker -from dask_cuda import LocalCUDACluster +from dask.distributed import Client, LocalCluster, get_worker, performance_report + +from nemo_curator.utils.gpu_utils import GPU_INSTALL_STRING, is_cudf_type +from nemo_curator.utils.import_utils import gpu_only_import, gpu_only_import_from + +cudf = gpu_only_import("cudf") +LocalCUDACluster = gpu_only_import_from("dask_cuda", "LocalCUDACluster") class DotDict: @@ -48,7 +53,6 @@ def start_dask_gpu_local_cluster(args) -> Client: GPUs present on the machine. """ - # Setting conservative defaults # which should work across most systems nvlink_only = getattr(args, "nvlink_only", False) @@ -166,6 +170,8 @@ def _enable_spilling(): i.e., computing on objects that occupy more memory than is available on the GPU. """ + import cudf + cudf.set_option("spill", True) @@ -265,6 +271,10 @@ def read_data( A Dask-cuDF or a Dask-pandas DataFrame. """ + if backend == "cudf": + # Try using cuDF. If it is not available, this will throw an error. 
+ test_obj = cudf.Series + if file_type == "pickle": df = read_pandas_pickle(input_files[0], add_filename=add_filename) df = dd.from_pandas(df, npartitions=16) @@ -369,10 +379,12 @@ def single_partition_write_with_filename(df, output_file_dir, output_type="jsonl warnings.warn(f"Empty partition found") empty_partition = False - if isinstance(df, pd.DataFrame): - success_ser = pd.Series([empty_partition]) - else: + if is_cudf_type(df): + import cudf + success_ser = cudf.Series([empty_partition]) + else: + success_ser = pd.Series([empty_partition]) if empty_partition: filename = df.filename.iloc[0] @@ -425,10 +437,13 @@ def write_to_disk(df, output_file_dir, write_to_filename=False, output_type="jso ) if write_to_filename: - if isinstance(df, dd.DataFrame): - output_meta = pd.Series([True], dtype="bool") - else: + if is_cudf_type(df): + import cudf + output_meta = cudf.Series([True]) + else: + output_meta = pd.Series([True], dtype="bool") + os.makedirs(output_file_dir, exist_ok=True) output = df.map_partitions( single_partition_write_with_filename, @@ -440,7 +455,7 @@ def write_to_disk(df, output_file_dir, write_to_filename=False, output_type="jso output = output.compute() else: if output_type == "jsonl": - if isinstance(df, dask_cudf.DataFrame): + if is_cudf_type(df): # See open issue here: https://github.com/rapidsai/cudf/issues/15211 # df.to_json(output_file_dir, orient="records", lines=True, engine="cudf", force_ascii=False) df.to_json( @@ -521,3 +536,10 @@ def get_current_client(): return Client.current() except ValueError: return None + + +def performance_report_if(path=None, report_name="dask-profile.html"): + if path is not None: + return performance_report(os.path.join(path, report_name)) + else: + return nullcontext() diff --git a/nemo_curator/utils/gpu_utils.py b/nemo_curator/utils/gpu_utils.py index de1c23dfe..86ba888fc 100644 --- a/nemo_curator/utils/gpu_utils.py +++ b/nemo_curator/utils/gpu_utils.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +GPU_INSTALL_STRING = """Install GPU packages via `pip install --extra-index-url https://pypi.nvidia.com nemo_curator[cuda12x]` +or use `pip install --extra-index-url https://pypi.nvidia.com ".[cuda12x]"` if installing from source""" + def is_cudf_type(obj): """ diff --git a/nemo_curator/utils/import_utils.py b/nemo_curator/utils/import_utils.py new file mode 100644 index 000000000..ea78e4597 --- /dev/null +++ b/nemo_curator/utils/import_utils.py @@ -0,0 +1,384 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# This file is adapted from cuML's safe_imports module: +# https://github.com/rapidsai/cuml/blob/e93166ea0dddfa8ef2f68c6335012af4420bc8ac/python/cuml/internals/safe_imports.py + + +import importlib +import logging +import traceback +from contextlib import contextmanager + +from nemo_curator.utils.gpu_utils import GPU_INSTALL_STRING + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(logging.StreamHandler()) + + +class UnavailableError(Exception): + """Error thrown if a symbol is unavailable due to an issue importing it""" + + +@contextmanager +def null_decorator(*args, **kwargs): + if len(kwargs) == 0 and len(args) == 1 and callable(args[0]): + return args[0] + else: + + def inner(func): + return func + + return inner + + +class UnavailableMeta(type): + """A metaclass for generating placeholder objects for unavailable symbols + + This metaclass allows errors to be deferred from import time to the time + that a symbol is actually used in order to streamline the usage of optional + dependencies. This is particularly useful for attempted imports of GPU-only + modules which will only be invoked if GPU-only functionality is + specifically used. + + If an attempt to import a symbol fails, this metaclass is used to generate + a class which stands in for that symbol. Any attempt to call the symbol + (instantiate the class) or access its attributes will throw an + UnavailableError exception. Furthermore, this class can be used in + e.g. isinstance checks, since it will (correctly) fail to match any + instance it is compared against. + + In addition to calls and attribute access, a number of dunder methods are + implemented so that other common usages of imported symbols (e.g. + arithmetic) throw an UnavailableError, but this is not guaranteed for + all possible uses. In such cases, other exception types (typically + TypeErrors) will be thrown instead. 
+ """ + + def __new__(meta, name, bases, dct): + if dct.get("_msg", None) is None: + dct["_msg"] = f"{name} could not be imported" + name = f"MISSING{name}" + return super(UnavailableMeta, meta).__new__(meta, name, bases, dct) + + def __call__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __getattr__(cls, name): + raise UnavailableError(cls._msg) + + def __eq__(cls, other): + raise UnavailableError(cls._msg) + + def __lt__(cls, other): + raise UnavailableError(cls._msg) + + def __gt__(cls, other): + raise UnavailableError(cls._msg) + + def __ne__(cls, other): + raise UnavailableError(cls._msg) + + def __abs__(cls, other): + raise UnavailableError(cls._msg) + + def __add__(cls, other): + raise UnavailableError(cls._msg) + + def __radd__(cls, other): + raise UnavailableError(cls._msg) + + def __iadd__(cls, other): + raise UnavailableError(cls._msg) + + def __floordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __rfloordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __ifloordiv__(cls, other): + raise UnavailableError(cls._msg) + + def __lshift__(cls, other): + raise UnavailableError(cls._msg) + + def __rlshift__(cls, other): + raise UnavailableError(cls._msg) + + def __mul__(cls, other): + raise UnavailableError(cls._msg) + + def __rmul__(cls, other): + raise UnavailableError(cls._msg) + + def __imul__(cls, other): + raise UnavailableError(cls._msg) + + def __ilshift__(cls, other): + raise UnavailableError(cls._msg) + + def __pow__(cls, other): + raise UnavailableError(cls._msg) + + def __rpow__(cls, other): + raise UnavailableError(cls._msg) + + def __ipow__(cls, other): + raise UnavailableError(cls._msg) + + def __rshift__(cls, other): + raise UnavailableError(cls._msg) + + def __rrshift__(cls, other): + raise UnavailableError(cls._msg) + + def __irshift__(cls, other): + raise UnavailableError(cls._msg) + + def __sub__(cls, other): + raise UnavailableError(cls._msg) + + def __rsub__(cls, other): + raise UnavailableError(cls._msg) + + def __isub__(cls, other): + raise UnavailableError(cls._msg) + + def __truediv__(cls, other): + raise UnavailableError(cls._msg) + + def __rtruediv__(cls, other): + raise UnavailableError(cls._msg) + + def __itruediv__(cls, other): + raise UnavailableError(cls._msg) + + def __divmod__(cls, other): + raise UnavailableError(cls._msg) + + def __rdivmod__(cls, other): + raise UnavailableError(cls._msg) + + def __neg__(cls): + raise UnavailableError(cls._msg) + + def __invert__(cls): + raise UnavailableError(cls._msg) + + def __hash__(cls): + raise UnavailableError(cls._msg) + + def __index__(cls): + raise UnavailableError(cls._msg) + + def __iter__(cls): + raise UnavailableError(cls._msg) + + def __delitem__(cls, name): + raise UnavailableError(cls._msg) + + def __setitem__(cls, name, value): + raise UnavailableError(cls._msg) + + def __enter__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __get__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __delete__(cls, *args, **kwargs): + raise UnavailableError(cls._msg) + + def __len__(cls): + raise UnavailableError(cls._msg) + + +def is_unavailable(obj): + """Helper to check if given symbol is actually a placeholder""" + return type(obj) is UnavailableMeta + + +class UnavailableNullContext: + """A placeholder class for unavailable context managers + + This context manager will return a value which will throw an + UnavailableError if used in any way, but the context manager itself can be + safely invoked. 
+ """ + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return UnavailableMeta( + "MissingContextValue", + (), + {"_msg": "Attempted to make use of placeholder context return value."}, + ) + + def __exit__(self, *args, **kwargs): + pass + + +def safe_import(module, *, msg=None, alt=None): + """A function used to import modules that may not be available + + This function will attempt to import a module with the given name, but it + will not throw an ModuleNotFoundError if the module is not found. Instead, it will + return a placeholder object which will raise an exception only if used. + + Parameters + ---------- + module: str + The name of the module to import. + msg: str or None + An optional error message to be displayed if this module is used + after a failed import. + alt: object + An optional module to be used in place of the given module if it + fails to import + + Returns + ------- + object + The imported module, the given alternate, or a class derived from + UnavailableMeta. + """ + try: + return importlib.import_module(module) + except ModuleNotFoundError: + exception_text = traceback.format_exc() + logger.debug(f"Import of {module} failed with: {exception_text}") + except Exception: + exception_text = traceback.format_exc() + raise + if msg is None: + msg = f"{module} could not be imported" + if alt is None: + return UnavailableMeta(module.rsplit(".")[-1], (), {"_msg": msg}) + else: + return alt + + +def safe_import_from(module, symbol, *, msg=None, alt=None): + """A function used to import symbols from modules that may not be available + + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used. + + Parameters + ---------- + module: str + The name of the module in which the symbol is defined. + symbol: str + The name of the symbol to import. + msg: str or None + An optional error message to be displayed if this symbol is used + after a failed import. + alt: object + An optional object to be used in place of the given symbol if it fails + to import + + Returns + ------- + object + The imported symbol, the given alternate, or a class derived from + UnavailableMeta. + """ + try: + imported_module = importlib.import_module(module) + return getattr(imported_module, symbol) + except ModuleNotFoundError: + exception_text = traceback.format_exc() + logger.debug(f"Import of {module} failed with: {exception_text}") + except AttributeError: + exception_text = traceback.format_exc() + logger.info(f"Import of {symbol} from {module} failed with: {exception_text}") + except Exception: + exception_text = traceback.format_exc() + raise + if msg is None: + msg = f"{module}.{symbol} could not be imported" + if alt is None: + return UnavailableMeta(symbol, (), {"_msg": msg}) + else: + return alt + + +def gpu_only_import(module, *, alt=None): + """A function used to import modules required only in GPU installs + + This function will attempt to import a module with the given name. + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used with instructions on installing a GPU build. + + Parameters + ---------- + module: str + The name of the module to import. 
+ alt: object + An optional module to be used in place of the given module if it + fails to import in a non-GPU-enabled install + + Returns + ------- + object + The imported module, the given alternate, or a class derived from + UnavailableMeta. + """ + + return safe_import( + module, + msg=f"{module} is not installed in non GPU-enabled installations. {GPU_INSTALL_STRING}", + alt=alt, + ) + + +def gpu_only_import_from(module, symbol, *, alt=None): + """A function used to import symbols required only in GPU installs + + This function will attempt to import a symbol with the given name from + the given module, but it will not throw an ImportError if the symbol is not + found. Instead, it will return a placeholder object which will raise an + exception only if used, with instructions on installing a GPU build. + + Parameters + ---------- + module: str + The name of the module to import. + symbol: str + The name of the symbol to import. + alt: object + An optional object to be used in place of the given symbol if it fails + to import in a non-GPU-enabled install + + Returns + ------- + object + The imported symbol, the given alternate, or a class derived from + UnavailableMeta. + """ + return safe_import_from( + module, + symbol, + msg=f"{module}.{symbol} is not installed in non GPU-enabled installations. {GPU_INSTALL_STRING}", + alt=alt, + ) diff --git a/setup.py b/setup.py index b47ef5c95..8fc60e926 100644 --- a/setup.py +++ b/setup.py @@ -55,10 +55,6 @@ "comment_parser", "beautifulsoup4", "mwparserfromhell @ git+https://github.com/earwig/mwparserfromhell.git@0f89f44", - "cudf-cu12>=24.2", - "dask-cudf-cu12>=24.2", - "cugraph-cu12>=24.2", - "dask-cuda>=24.2", "spacy>=3.6.0, <4.0.0", "presidio-analyzer==2.2.351", "presidio-anonymizer==2.2.351", @@ -68,6 +64,15 @@ # due to this: https://github.com/miso-belica/jusText/issues/47 "lxml[html_clean]", ], + extras_require={ + "cuda12x": [ + "cudf-cu12>=24.2", + "dask-cudf-cu12>=24.2", + "cugraph-cu12>=24.2", + "dask-cuda>=24.2", + "spacy[cuda12x]>=3.6.0, <4.0.0", + ] + }, entry_points={ "console_scripts": [ "get_common_crawl_urls=nemo_curator.scripts.get_common_crawl_urls:console_script", diff --git a/tests/test_filters.py b/tests/test_filters.py index 11bf57388..4ab11c21a 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -149,7 +149,9 @@ def test_retain_score_filter(self, letter_count_data): filtered_data = filter_step(letter_count_data) expected_indices = [2, 3] - expected_data = DocumentDataset(letter_count_data.df.loc[expected_indices]) + # Compute before loc due to https://github.com/dask/dask-expr/issues/1036 + expected_data = letter_count_data.df.compute().loc[expected_indices] + expected_data = DocumentDataset(dd.from_pandas(expected_data, 2)) expected_data.df[score_field] = pd.Series([5, 7], index=expected_data.df.index) assert all_equal( expected_data, filtered_data @@ -168,7 +170,9 @@ def test_filter(self, letter_count_data): filtered_data = filter_step(scored_data) expected_indices = [2, 3] - expected_data = letter_count_data.df.loc[expected_indices] + # Compute before loc due to https://github.com/dask/dask-expr/issues/1036 + expected_data = letter_count_data.df.compute().loc[expected_indices] + expected_data = dd.from_pandas(expected_data, 2) expected_data[score_field] = pd.Series([5, 7], index=expected_data.index) expected_data = DocumentDataset(expected_data) assert all_equal( diff --git a/tests/test_fuzzy_dedup.py b/tests/test_fuzzy_dedup.py index 
3c6a32754..a1acb901f 100644 --- a/tests/test_fuzzy_dedup.py +++ b/tests/test_fuzzy_dedup.py @@ -16,14 +16,16 @@ from itertools import combinations from typing import Iterable -import cudf -import dask_cudf import numpy as np import pytest from dask.dataframe.utils import assert_eq from nemo_curator.datasets import DocumentDataset from nemo_curator.modules import LSH, MinHash +from nemo_curator.utils.import_utils import gpu_only_import + +cudf = gpu_only_import("cudf") +dask_cudf = gpu_only_import("dask_cudf") @pytest.fixture
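
For anyone who wants to sanity-check the optional-import machinery this diff introduces, here is a minimal sketch of how it behaves. It is an illustration, not code from this PR: the CPU-only environment is an assumption, but `gpu_only_import`, `is_unavailable`, and `UnavailableError` are the helpers added in `nemo_curator/utils/import_utils.py` above.

```python
# Minimal sketch: optional-import behavior on a CPU-only install
# (i.e., NeMo Curator installed without the [cuda12x] extra).
from nemo_curator.utils.import_utils import (
    UnavailableError,
    gpu_only_import,
    is_unavailable,
)

# No ImportError is raised here even if cudf is missing; on a CPU-only
# install this returns an UnavailableMeta placeholder instead of the module.
cudf = gpu_only_import("cudf")

if is_unavailable(cudf):
    # Any use of the placeholder (attribute access, call, arithmetic, ...)
    # raises UnavailableError with instructions to install the GPU extra.
    try:
        cudf.Series([1, 2, 3])
    except UnavailableError as err:
        print(err)
else:
    # On a GPU install, this is the real cudf module.
    print(cudf.Series([1, 2, 3]))
```

This is the same trick `read_data` in `distributed_utils.py` relies on: the `test_obj = cudf.Series` line simply touches an attribute of the (possibly placeholder) module so that requesting `backend="cudf"` without the GPU packages fails early with the install instructions from `GPU_INSTALL_STRING`.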