Skip to content

Commit

Permalink
feat: add PASS_CONTEXTVARS
Browse files Browse the repository at this point in the history
  • Loading branch information
r1b committed Jan 29, 2021
1 parent f5ce994 commit d7578ad
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 9 deletions.
20 changes: 20 additions & 0 deletions pynetdicom/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,23 @@
>>> from pynetdicom import _config
>>> _config.STORE_RECV_CHUNKED_DATASET = True
"""

PASS_CONTEXTVARS = False
"""Pass context-local state to concurrent pynetdicom code.
.. versionadded:: 2.0
If ``True``, then any ``contextvars.ContextVar`` instances defined in the
calling context will be made available to pynetdicom's concurrent contexts.
This allows the caller to define contextual behavior without modifying
pynetdicom. For example, one could add a logging filter to the pynetdicom
logger that references an externally defined ``contextvars.ContextVar``.
Default: ``False``.
Examples
--------
>>> from pynetdicom import _config
>>> _config.PASS_CONTEXTVARS = True
"""
4 changes: 2 additions & 2 deletions pynetdicom/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from pynetdicom.transport import (
AssociationSocket, AssociationServer, ThreadedAssociationServer
)
from pynetdicom.utils import validate_ae_title
from pynetdicom.utils import make_target, validate_ae_title
from pynetdicom._globals import (
MODE_REQUESTOR,
DEFAULT_MAX_LENGTH,
Expand Down Expand Up @@ -1216,7 +1216,7 @@ def start_server(self, address, block=True, ssl_context=None,
)

thread = threading.Thread(
target=server.serve_forever,
target=make_target(server.serve_forever),
name=f"AcceptorServer@{timestamp}"
)
thread.daemon = True
Expand Down
5 changes: 3 additions & 2 deletions pynetdicom/association.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
UnifiedProcedureStepQuerySOPClass
)
from pynetdicom.status import code_to_category, STORAGE_SERVICE_CLASS_STATUS
from pynetdicom.utils import make_target


# pylint: enable=no-name-in-module
Expand Down Expand Up @@ -150,7 +151,7 @@ def __init__(self, ae, mode):
self._is_paused = False

# Thread setup
threading.Thread.__init__(self)
threading.Thread.__init__(self, target=make_target(self.run_reactor))
self.daemon = True

def abort(self):
Expand Down Expand Up @@ -596,7 +597,7 @@ def request(self):
LOGGER.info("Requesting Association")
self.acse.negotiate_association()

def run(self):
def run_reactor(self):
"""The main :class:`Association` reactor."""
# Start the DUL thread if not already started
if not self._started_dul:
Expand Down
3 changes: 2 additions & 1 deletion pynetdicom/dimse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
C_STORE, C_FIND, C_GET, C_MOVE, C_ECHO, C_CANCEL,
N_EVENT_REPORT, N_GET, N_SET, N_ACTION, N_CREATE, N_DELETE,
)
from pynetdicom.utils import make_target


LOGGER = logging.getLogger('pynetdicom.dimse')
Expand Down Expand Up @@ -269,7 +270,7 @@ def receive_primitive(self, primitive):
# N-EVENT-REPORT service requests are handled immediately
# Ugly hack, but would block the DUL otherwise
t = threading.Thread(
target=self.assoc._serve_request,
target=make_target(self.assoc._serve_request),
args=(primitive, context_id)
)
t.start()
Expand Down
5 changes: 3 additions & 2 deletions pynetdicom/dul.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
A_ASSOCIATE, A_RELEASE, A_ABORT, A_P_ABORT, P_DATA
)
from pynetdicom.timer import Timer
from pynetdicom.utils import make_target


LOGGER = logging.getLogger('pynetdicom.dul')
Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(self, assoc):
# TODO: try and make this event based rather than running loops
self._run_loop_delay = 0.001

Thread.__init__(self)
Thread.__init__(self, target=make_target(self.run_reactor))
self.daemon = False
self._kill_thread = False

Expand Down Expand Up @@ -350,7 +351,7 @@ def receive_pdu(self, wait=False, timeout=None):
except queue.Empty:
return None

def run(self):
def run_reactor(self):
"""Run the DUL reactor.
The main :class:`threading.Thread` run loop. Runs constantly, checking
Expand Down
53 changes: 52 additions & 1 deletion pynetdicom/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"""Unit tests for the pynetdicom.utils module."""

from io import BytesIO
from threading import Thread
import logging
import sys

import pytest

from pydicom.uid import UID

from pynetdicom import _config, debug_logger
from pynetdicom.utils import validate_ae_title, pretty_bytes, validate_uid
from pynetdicom.utils import validate_ae_title, pretty_bytes, validate_uid, make_target
from .encoded_pdu_items import a_associate_rq


Expand Down Expand Up @@ -186,3 +188,52 @@ def test_bytesio(self):
result = pretty_bytes(bytestream, prefix='', delimiter='',
items_per_line=10)
assert isinstance(result[0], str)


class TestMakeTarget(object):
"""Tests for utils.make_target()."""
@pytest.mark.skipif(sys.version_info[:2] < (3, 7), reason="Branch uncovered in this Python version.")
def test_make_target(self):
"""Context Setup"""
from contextvars import ContextVar
foo = ContextVar("foo")
token = foo.set("foo")

"""Test for ``_config.PASS_CONTEXTVARS = False`` (the default)."""
assert _config.PASS_CONTEXTVARS is False

def target_without_context():
with pytest.raises(LookupError):
foo.get()

thread_without_context = Thread(target=make_target(target_without_context))
thread_without_context.start()
thread_without_context.join()

"""Test for ``_config.PASS_CONTEXTVARS = True``."""
_config.PASS_CONTEXTVARS = True

def target_with_context():
assert foo.get() == "foo"

thread_with_context = Thread(target=make_target(target_with_context))
thread_with_context.start()
thread_with_context.join()

_config.PASS_CONTEXTVARS = False

"""Context Teardown"""
foo.reset(token)

@pytest.mark.skipif(sys.version_info[:2] >= (3, 7), reason="Branch uncovered in this Python version.")
def test_invalid_python_version(self):
"""Test for ``_config.PASS_CONTEXTVARS = True`` and Python < 3.7"""
def noop():
pass

_config.PASS_CONTEXTVARS = True

with pytest.raises(RuntimeError, match="PASS_CONTEXTVARS requires Python >=3.7."):
make_target(noop)

_config.PASS_CONTEXTVARS = False
32 changes: 31 additions & 1 deletion pynetdicom/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from io import BytesIO
import logging
import sys
import unicodedata

from pynetdicom import _config
Expand Down Expand Up @@ -183,3 +182,34 @@ def validate_uid(uid):
return True

return False


def make_target(target_fn):
"""Wraps `target_fn` in a thunk that passes all contextvars from the
current context. It is assumed that `target_fn` is the target of a new
``threading.Thread``.
Requires:
* Python >=3.7
* :attr:`~pynetdicom._config.PASS_CONTEXTVARS` set ``True``
If the requirements are not met, the original `target_fn` is returned.
Parameters
----------
target_fn : Callable
The function to wrap
Returns
-------
Callable
The wrapped `target_fn` if requirements are met, else the original `target_fn`.
"""
if _config.PASS_CONTEXTVARS:
try:
from contextvars import copy_context
except ImportError as e:
raise RuntimeError("PASS_CONTEXTVARS requires Python >=3.7") from e
ctx = copy_context()
return lambda: ctx.run(target_fn)
return target_fn

0 comments on commit d7578ad

Please sign in to comment.