From ac3a96ba6de862e35990c80950255d93d328eca1 Mon Sep 17 00:00:00 2001 From: Robert Jensen <5550520+r1b@users.noreply.github.com> Date: Fri, 15 Jan 2021 12:03:07 -0500 Subject: [PATCH] feat: add PASS_CONTEXTVARS --- pynetdicom/_config.py | 20 +++++++++++++ pynetdicom/ae.py | 4 +-- pynetdicom/association.py | 5 ++-- pynetdicom/dimse.py | 3 +- pynetdicom/dul.py | 5 ++-- pynetdicom/tests/test_utils.py | 53 +++++++++++++++++++++++++++++++++- pynetdicom/utils.py | 10 +++++++ 7 files changed, 92 insertions(+), 8 deletions(-) diff --git a/pynetdicom/_config.py b/pynetdicom/_config.py index 69e2839056..e7eec9dec0 100644 --- a/pynetdicom/_config.py +++ b/pynetdicom/_config.py @@ -151,3 +151,23 @@ >>> from pynetdicom import _config >>> _config.STORE_RECV_CHUNKED_DATASET = True """ + +PASS_CONTEXTVARS = False +"""Pass context-local state to concurrent pynetdicom code. + +.. versionadded:: 2.0 + +If ``True``, then any ``contextvars.ContextVar`` instances defined in the +calling context will be made available to pynetdicom's concurrent contexts. +This allows the caller to define contextual behavior without modifying +pynetdicom. For example, one could add a logging filter to the pynetdicom +logger that references an externally defined ``contextvars.ContextVar``. + +Default: ``False``. + +Examples +-------- + +>>> from pynetdicom import _config +>>> _config.PASS_CONTEXTVARS = True +""" diff --git a/pynetdicom/ae.py b/pynetdicom/ae.py index f490265d0a..401f95426a 100644 --- a/pynetdicom/ae.py +++ b/pynetdicom/ae.py @@ -13,7 +13,7 @@ from pynetdicom.transport import ( AssociationSocket, AssociationServer, ThreadedAssociationServer ) -from pynetdicom.utils import validate_ae_title +from pynetdicom.utils import make_target, validate_ae_title from pynetdicom._globals import ( MODE_REQUESTOR, DEFAULT_MAX_LENGTH, @@ -1216,7 +1216,7 @@ def start_server(self, address, block=True, ssl_context=None, ) thread = threading.Thread( - target=server.serve_forever, + target=make_target(server.serve_forever), name=f"AcceptorServer@{timestamp}" ) thread.daemon = True diff --git a/pynetdicom/association.py b/pynetdicom/association.py index c7f267c75d..3a720ab28b 100644 --- a/pynetdicom/association.py +++ b/pynetdicom/association.py @@ -49,6 +49,7 @@ UnifiedProcedureStepQuerySOPClass ) from pynetdicom.status import code_to_category, STORAGE_SERVICE_CLASS_STATUS +from pynetdicom.utils import make_target # pylint: enable=no-name-in-module @@ -150,7 +151,7 @@ def __init__(self, ae, mode): self._is_paused = False # Thread setup - threading.Thread.__init__(self) + threading.Thread.__init__(self, target=make_target(self.run_reactor)) self.daemon = True def abort(self): @@ -596,7 +597,7 @@ def request(self): LOGGER.info("Requesting Association") self.acse.negotiate_association() - def run(self): + def run_reactor(self): """The main :class:`Association` reactor.""" # Start the DUL thread if not already started if not self._started_dul: diff --git a/pynetdicom/dimse.py b/pynetdicom/dimse.py index a8f333908d..40f25764ba 100644 --- a/pynetdicom/dimse.py +++ b/pynetdicom/dimse.py @@ -20,6 +20,7 @@ C_STORE, C_FIND, C_GET, C_MOVE, C_ECHO, C_CANCEL, N_EVENT_REPORT, N_GET, N_SET, N_ACTION, N_CREATE, N_DELETE, ) +from pynetdicom.utils import make_target LOGGER = logging.getLogger('pynetdicom.dimse') @@ -269,7 +270,7 @@ def receive_primitive(self, primitive): # N-EVENT-REPORT service requests are handled immediately # Ugly hack, but would block the DUL otherwise t = threading.Thread( - target=self.assoc._serve_request, + target=make_target(self.assoc._serve_request), args=(primitive, context_id) ) t.start() diff --git a/pynetdicom/dul.py b/pynetdicom/dul.py index 2e95170378..0d4ad8b7f8 100644 --- a/pynetdicom/dul.py +++ b/pynetdicom/dul.py @@ -20,6 +20,7 @@ A_ASSOCIATE, A_RELEASE, A_ABORT, A_P_ABORT, P_DATA ) from pynetdicom.timer import Timer +from pynetdicom.utils import make_target LOGGER = logging.getLogger('pynetdicom.dul') @@ -87,7 +88,7 @@ def __init__(self, assoc): # TODO: try and make this event based rather than running loops self._run_loop_delay = 0.001 - Thread.__init__(self) + Thread.__init__(self, target=make_target(self.run_reactor)) self.daemon = False self._kill_thread = False @@ -350,7 +351,7 @@ def receive_pdu(self, wait=False, timeout=None): except queue.Empty: return None - def run(self): + def run_reactor(self): """Run the DUL reactor. The main :class:`threading.Thread` run loop. Runs constantly, checking diff --git a/pynetdicom/tests/test_utils.py b/pynetdicom/tests/test_utils.py index f697f1b0bc..cbdc597129 100644 --- a/pynetdicom/tests/test_utils.py +++ b/pynetdicom/tests/test_utils.py @@ -1,14 +1,16 @@ """Unit tests for the pynetdicom.utils module.""" from io import BytesIO +from threading import Thread import logging +import sys import pytest from pydicom.uid import UID from pynetdicom import _config, debug_logger -from pynetdicom.utils import validate_ae_title, pretty_bytes, validate_uid +from pynetdicom.utils import validate_ae_title, pretty_bytes, validate_uid, make_target from .encoded_pdu_items import a_associate_rq @@ -186,3 +188,52 @@ def test_bytesio(self): result = pretty_bytes(bytestream, prefix='', delimiter='', items_per_line=10) assert isinstance(result[0], str) + + +class TestMakeTarget(object): + """Tests for utils.make_target().""" + @pytest.mark.skipif(sys.version_info[:2] < (3, 7), reason="Branch uncovered in this Python version.") + def test_make_target(self): + """Context Setup""" + from contextvars import ContextVar + foo = ContextVar("foo") + token = foo.set("foo") + + """Test for ``_config.PASS_CONTEXTVARS = False`` (the default).""" + assert _config.PASS_CONTEXTVARS is False + + def target_without_context(): + with pytest.raises(LookupError): + foo.get() + + thread_without_context = Thread(target=make_target(target_without_context)) + thread_without_context.start() + thread_without_context.join() + + """Test for ``_config.PASS_CONTEXTVARS = True``.""" + _config.PASS_CONTEXTVARS = True + + def target_with_context(): + assert foo.get() == "foo" + + thread_with_context = Thread(target=make_target(target_with_context)) + thread_with_context.start() + thread_with_context.join() + + _config.PASS_CONTEXTVARS = False + + """Context Teardown""" + foo.reset(token) + + @pytest.mark.skipif(sys.version_info[:2] >= (3, 7), reason="Branch uncovered in this Python version.") + def test_invalid_python_version(self): + """Test for ``_config.PASS_CONTEXTVARS = True`` and Python < 3.7""" + def noop(): + pass + + _config.PASS_CONTEXTVARS = True + + with pytest.raises(RuntimeError, match="PASS_CONTEXTVARS requires Python >=3.7."): + make_target(noop) + + _config.PASS_CONTEXTVARS = False diff --git a/pynetdicom/utils.py b/pynetdicom/utils.py index 239b3532f8..c0191c93ed 100644 --- a/pynetdicom/utils.py +++ b/pynetdicom/utils.py @@ -183,3 +183,13 @@ def validate_uid(uid): return True return False + + +def make_target(target_fn): + if _config.PASS_CONTEXTVARS: + if sys.version_info[:2] < (3, 7): + raise RuntimeError("PASS_CONTEXTVARS requires Python >=3.7.") + from contextvars import copy_context + ctx = copy_context() + return lambda: ctx.run(target_fn) + return target_fn