Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add context manager to enable/disable patching #48

Open
wants to merge 1 commit into
base: branch-24.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion rapids_dask_dependency/dask_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
from contextlib import contextmanager

from rapids_dask_dependency.utils import patch_warning_stacklevel
from rapids_dask_dependency.utils import _patching_enabled, patch_warning_stacklevel


class DaskLoader(importlib.machinery.SourceFileLoader):
Expand Down Expand Up @@ -59,6 +59,8 @@ def disable(self, name):
def find_spec(self, fullname: str, _, __=None):
if fullname in self._blocklist:
return None
if not _patching_enabled():
return None
if (
fullname in ("dask", "distributed")
or fullname.startswith("dask.")
Expand Down
28 changes: 28 additions & 0 deletions rapids_dask_dependency/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION.

import os
import warnings
from contextlib import contextmanager
from functools import lru_cache
Expand All @@ -24,3 +25,30 @@ def patch_warning_stacklevel(level):
warnings.warn = _make_warning_func(level)
yield
warnings.warn = previous_warn


# Default patching behavior depends on the value of the
# `RAPIDS_DASK_PATCHING` environment variable. If this
# environment variable does not exist, patching will be
# enabled. Otherwise, this variable must be set to
# `'True'` for patching to be enabled.


_env = "RAPIDS_DASK_PATCHING"


def _patching_enabled() -> bool:
return os.environ.get(_env, "True") == "True"


@contextmanager
def patching_context(enabled: bool = True):
original = os.environ.get(_env)
os.environ[_env] = "True" if enabled else "False"
try:
yield
finally:
if original is None:
os.environ.pop(_env, None)
else:
os.environ[_env] = "True" if original else "False"
12 changes: 12 additions & 0 deletions tests/test_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,15 @@ def test_distributed_cli_dask_spec_as_module():
print(e.stdout.decode())
print(e.stderr.decode())
raise


@run_test_in_subprocess
def test_dask_patching_disabled():
from rapids_dask_dependency.utils import patching_context

with patching_context(enabled=False):
import dask
import distributed

assert not hasattr(dask, "_rapids_patched")
assert not hasattr(distributed, "_rapids_patched")