diff --git a/.coveragerc b/.coveragerc deleted file mode 100644 index 604c14775f..0000000000 --- a/.coveragerc +++ /dev/null @@ -1,37 +0,0 @@ -[run] -branch=True -source=trio -omit= - setup.py -# These are run in subprocesses, but still don't work. We follow -# coverage's documentation to no avail. - */trio/_core/_tests/test_multierror_scripts/* -# Omit the generated files in trio/_core starting with _generated_ - */trio/_core/_generated_* -# Script used to check type completeness that isn't run in tests - */trio/_tests/check_type_completeness.py -# The test suite spawns subprocesses to test some stuff, so make sure -# this doesn't corrupt the coverage files -parallel=True - -[report] -precision = 1 -skip_covered = True -exclude_lines = - pragma: no cover - abc.abstractmethod - if TYPE_CHECKING.*: - if _t.TYPE_CHECKING: - if t.TYPE_CHECKING: - @overload - class .*\bProtocol\b.*\): - raise NotImplementedError - -partial_branches = - pragma: no branch - if not TYPE_CHECKING: - if not _t.TYPE_CHECKING: - if not t.TYPE_CHECKING: - if .* or not TYPE_CHECKING: - if .* or not _t.TYPE_CHECKING: - if .* or not t.TYPE_CHECKING: diff --git a/.github/workflows/autodeps.yml b/.github/workflows/autodeps.yml index 40cf05726c..0e0655c5aa 100644 --- a/.github/workflows/autodeps.yml +++ b/.github/workflows/autodeps.yml @@ -7,6 +7,7 @@ on: jobs: Autodeps: + if: github.repository_owner == 'python-trio' name: Autodeps timeout-minutes: 10 runs-on: 'ubuntu-latest' @@ -18,23 +19,24 @@ jobs: contents: write steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.8" - name: Bump dependencies run: | - python -m pip install -U pip + python -m pip install -U pip pre-commit python -m pip install -r test-requirements.txt - pip-compile -U test-requirements.in - pip-compile -U docs-requirements.in + uv pip compile --universal --python-version=3.8 --upgrade test-requirements.in -o test-requirements.txt + uv pip compile --universal --python-version=3.8 --upgrade docs-requirements.in -o docs-requirements.txt + pre-commit autoupdate --jobs 0 - name: Black run: | # The new dependencies may contain a new black version. # Commit any changes immediately. 
python -m pip install -r test-requirements.txt - black setup.py trio + black src/trio - name: Commit changes and create automerge PR env: GH_TOKEN: ${{ github.token }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 71c5e221a3..60f8b79c03 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,18 +18,11 @@ jobs: strategy: fail-fast: false matrix: - # pypy 3.9 and 3.10 are failing, see https://github.com/python-trio/trio/issues/2678 and https://github.com/python-trio/trio/issues/2776 respectively - python: ['3.8', '3.9', '3.10'] #, 'pypy-3.9-nightly', 'pypy-3.10-nightly'] + python: ['3.8', '3.9', '3.10', '3.11', '3.12'] arch: ['x86', 'x64'] lsp: [''] lsp_extract_file: [''] extra_name: [''] - exclude: - # pypy does not release 32-bit binaries - - python: 'pypy-3.9-nightly' - arch: 'x86' - #- python: 'pypy-3.10-nightly' - # arch: 'x86' include: - python: '3.8' arch: 'x64' @@ -57,9 +50,9 @@ jobs: }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: # This allows the matrix to specify just the major.minor version while still # expanding it to get the latest patch version including alpha releases. @@ -94,7 +87,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['pypy-3.9', 'pypy-3.10', '3.8', '3.9', '3.10', '3.11', '3.12-dev', 'pypy-3.9-nightly', 'pypy-3.10-nightly'] + python: ['pypy-3.9', 'pypy-3.10', '3.8', '3.9', '3.10', '3.11', '3.12', '3.13'] check_formatting: ['0'] no_test_requirements: ['0'] extra_name: [''] @@ -117,9 +110,9 @@ jobs: }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 if: "!endsWith(matrix.python, '-dev')" with: python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} @@ -150,7 +143,7 @@ jobs: strategy: fail-fast: false matrix: - python: ['3.8', '3.9', '3.10', 'pypy-3.9-nightly', 'pypy-3.10-nightly'] + python: ['3.8', '3.9', '3.10', '3.11', '3.12'] continue-on-error: >- ${{ ( @@ -162,9 +155,9 @@ jobs: }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ fromJSON(format('["{0}", "{1}"]', format('{0}.0-alpha - {0}.X', matrix.python), matrix.python))[startsWith(matrix.python, 'pypy')] }} cache: pip @@ -179,6 +172,65 @@ jobs: name: macOS (${{ matrix.python }}) flags: macOS,${{ matrix.python }} + # run CI on a musl linux + Alpine: + name: "Alpine" + runs-on: ubuntu-latest + container: alpine + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Install necessary packages + # can't use setup-python because that python doesn't seem to work; + # `python3-dev` (rather than `python:alpine`) for some ctypes reason, + # `nodejs` for pyright (`node-env` pulls in nodejs but that takes a while and can time out the test). 
+ run: apk update && apk add python3-dev bash nodejs + - name: Enter virtual environment + run: python -m venv .venv + - name: Run tests + run: source .venv/bin/activate && ./ci.sh + - if: always() + uses: codecov/codecov-action@v3 + with: + directory: empty + token: 87cefb17-c44b-4f2f-8b30-1fff5769ce46 + name: Alpine + flags: Alpine,3.12 + + Cython: + name: "Cython" + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python: ['3.8', '3.12'] + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup python + uses: actions/setup-python@v5 + with: + python-version: '${{ matrix.python }}' + cache: pip + # setuptools is needed to get distutils on 3.12, which cythonize requires + - name: install trio and setuptools + run: python -m pip install --upgrade pip . setuptools + + - name: install cython<3 + run: python -m pip install "cython<3" + - name: compile pyx file + run: cythonize -i tests/cython/test_cython.pyx + - name: import & run module + run: python -c 'import tests.cython.test_cython' + + - name: install cython>=3 + run: python -m pip install "cython>=3" + - name: compile pyx file + # different cython version should trigger a re-compile, but --force just in case + run: cythonize --inplace --force tests/cython/test_cython.pyx + - name: import & run module + run: python -c 'import tests.cython.test_cython' + # https://github.com/marketplace/actions/alls-green#why check: # This job does nothing and is only used for the branch protection @@ -188,6 +240,8 @@ jobs: - Windows - Ubuntu - macOS + - Alpine + - Cython runs-on: ubuntu-latest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d2e06e0dc3..810c28e4a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -18,18 +18,18 @@ repos: - id: check-case-conflict - id: sort-simple-yaml files: .pre-commit-config.yaml - - repo: https://github.com/psf/black - rev: 23.10.1 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 24.4.2 hooks: - id: black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.1 + rev: v0.5.5 hooks: - id: ruff types: [file] types_or: [python, pyi, toml] args: ["--show-fixes"] - repo: https://github.com/codespell-project/codespell - rev: v2.2.6 + rev: v2.3.0 hooks: - id: codespell diff --git a/MANIFEST.in b/MANIFEST.in index eb9c0173da..440994e43a 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -2,7 +2,7 @@ include LICENSE LICENSE.MIT LICENSE.APACHE2 include README.rst include CODE_OF_CONDUCT.md CONTRIBUTING.md include test-requirements.txt -include trio/py.typed -recursive-include trio/_tests/test_ssl_certs *.pem +include src/trio/py.typed +recursive-include src/trio/_tests/test_ssl_certs *.pem recursive-include docs * prune docs/build diff --git a/README.rst b/README.rst index 6d8fa7633b..65f6df8946 100644 --- a/README.rst +++ b/README.rst @@ -18,7 +18,7 @@ :target: https://anaconda.org/conda-forge/trio :alt: Latest conda-forge version -.. image:: https://codecov.io/gh/python-trio/trio/branch/master/graph/badge.svg +.. image:: https://codecov.io/gh/python-trio/trio/branch/main/graph/badge.svg :target: https://codecov.io/gh/python-trio/trio :alt: Test coverage @@ -31,7 +31,7 @@ Trio – a friendly Python library for async concurrency and I/O The Trio project aims to produce a production-quality, `permissively licensed -`__, +`__, async/await-native I/O library for Python. 
Like all async libraries, its main purpose is to help you write programs that do **multiple things at the same time** with **parallelized I/O**. A web spider that @@ -134,7 +134,7 @@ choices **I want to make sure my company's lawyers won't get angry at me!** No worries, Trio is permissively licensed under your choice of MIT or Apache 2. See `LICENSE -`__ for details. +`__ for details. Code of conduct diff --git a/check.sh b/check.sh index 58b00aa1ce..d6efb8749a 100755 --- a/check.sh +++ b/check.sh @@ -13,7 +13,7 @@ fi # Test if the generated code is still up to date echo "::group::Generate Exports" -python ./trio/_tools/gen_exports.py --test \ +python ./src/trio/_tools/gen_exports.py --test \ || EXIT_STATUS=$? echo "::endgroup::" @@ -23,10 +23,10 @@ echo "::endgroup::" # autoflake --recursive --in-place . # pyupgrade --py3-plus $(find . -name "*.py") echo "::group::Black" -if ! black --check setup.py trio; then +if ! black --check src/trio; then echo "* Black found issues" >> "$GITHUB_STEP_SUMMARY" EXIT_STATUS=1 - black --diff setup.py trio + black --diff src/trio echo "::endgroup::" echo "::error:: Black found issues" else @@ -57,16 +57,16 @@ echo "::group::Mypy" rm -f mypy_annotate.dat # Pipefail makes these pipelines fail if mypy does, even if mypy_annotate.py succeeds. set -o pipefail -mypy trio --show-error-end --platform linux | python ./trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat --platform Linux \ +mypy --show-error-end --platform linux | python ./src/trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat --platform Linux \ || { echo "* Mypy (Linux) found type errors." >> "$GITHUB_STEP_SUMMARY"; MYPY=1; } # Darwin tests FreeBSD too -mypy trio --show-error-end --platform darwin | python ./trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat --platform Mac \ +mypy --show-error-end --platform darwin | python ./src/trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat --platform Mac \ || { echo "* Mypy (Mac) found type errors." >> "$GITHUB_STEP_SUMMARY"; MYPY=1; } -mypy trio --show-error-end --platform win32 | python ./trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat --platform Windows \ +mypy --show-error-end --platform win32 | python ./src/trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat --platform Windows \ || { echo "* Mypy (Windows) found type errors." >> "$GITHUB_STEP_SUMMARY"; MYPY=1; } set +o pipefail # Re-display errors using Github's syntax, read out of mypy_annotate.dat -python ./trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat +python ./src/trio/_tools/mypy_annotate.py --dumpfile mypy_annotate.dat # Then discard. rm -f mypy_annotate.dat echo "::endgroup::" @@ -78,10 +78,10 @@ fi # Check pip compile is consistent echo "::group::Pip Compile - Tests" -pip-compile test-requirements.in +uv pip compile --universal --python-version=3.8 test-requirements.in -o test-requirements.txt echo "::endgroup::" echo "::group::Pip Compile - Docs" -pip-compile docs-requirements.in +uv pip compile --universal --python-version=3.8 docs-requirements.in -o docs-requirements.txt echo "::endgroup::" if git status --porcelain | grep -q "requirements.txt"; then @@ -97,15 +97,10 @@ fi codespell || EXIT_STATUS=$? echo "::group::Pyright interface tests" -python trio/_tests/check_type_completeness.py --overwrite-file || EXIT_STATUS=$? -if git status --porcelain trio/_tests/verify_types*.json | grep -q "M"; then - echo "* Type completeness changed, please update!" >> "$GITHUB_STEP_SUMMARY" - echo "::error::Type completeness changed, please update!" 
- git --no-pager diff --color trio/_tests/verify_types*.json - EXIT_STATUS=1 -fi +python src/trio/_tests/check_type_completeness.py || EXIT_STATUS=$? -pyright trio/_tests/type_tests || EXIT_STATUS=$? +pyright src/trio/_tests/type_tests || EXIT_STATUS=$? +pyright src/trio/_core/_tests/type_tests || EXIT_STATUS=$? echo "::endgroup::" # Finally, leave a really clear warning of any issues and exit @@ -118,8 +113,8 @@ Problems were found by static analysis (listed above). To fix formatting and see remaining errors, run pip install -r test-requirements.txt - black setup.py trio - isort setup.py trio + black src/trio + ruff check src/trio ./check.sh in your local checkout. diff --git a/ci.sh b/ci.sh index 157b3ce8b2..112ed04d7a 100755 --- a/ci.sh +++ b/ci.sh @@ -37,11 +37,11 @@ python -c "import sys, struct, ssl; print('python:', sys.version); print('versio echo "::endgroup::" echo "::group::Install dependencies" -python -m pip install -U pip setuptools wheel +python -m pip install -U pip build python -m pip --version -python setup.py sdist --formats=zip -python -m pip install dist/*.zip +python -m build +python -m pip install dist/*.whl if [ "$CHECK_FORMATTING" = "1" ]; then python -m pip install -r test-requirements.txt @@ -95,7 +95,7 @@ else # when installing, and then running 'certmgr.msc' and exporting the # certificate. See: # http://www.migee.com/2010/09/24/solution-for-unattendedsilent-installs-and-would-you-like-to-install-this-device-software/ - certutil -addstore "TrustedPublisher" trio/_tests/astrill-codesigning-cert.cer + certutil -addstore "TrustedPublisher" src/trio/_tests/astrill-codesigning-cert.cer # Double-slashes are how you tell windows-bash that you want a single # slash, and don't treat this as a unix-style filename that needs to # be replaced by a windows-style filename. @@ -128,7 +128,7 @@ else echo "::endgroup::" echo "::group:: Run Tests" - if COVERAGE_PROCESS_START=$(pwd)/../.coveragerc coverage run --rcfile=../.coveragerc -m pytest -r a -p trio._tests.pytest_plugin --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --verbose --durations=10 $flags; then + if COVERAGE_PROCESS_START=$(pwd)/../pyproject.toml coverage run --rcfile=../pyproject.toml -m pytest -ra --junitxml=../test-results.xml --run-slow ${INSTALLDIR} --verbose --durations=10 $flags; then PASSED=true else PASSED=false @@ -136,9 +136,9 @@ else echo "::endgroup::" echo "::group::Coverage" - coverage combine --rcfile ../.coveragerc - coverage report -m --rcfile ../.coveragerc - coverage xml --rcfile ../.coveragerc + coverage combine --rcfile ../pyproject.toml + coverage report -m --rcfile ../pyproject.toml + coverage xml --rcfile ../pyproject.toml # Remove the LSP again; again we want to do this ASAP to avoid # accidentally breaking other stuff. diff --git a/docs-requirements.in b/docs-requirements.in index 9239fe3fce..c4695fc688 100644 --- a/docs-requirements.in +++ b/docs-requirements.in @@ -1,14 +1,17 @@ -# RTD is currently installing 1.5.3, which has a bug in :lineno-match: -sphinx >= 4.0, < 6.2 +# RTD is currently installing 1.5.3, which has a bug in :lineno-match: (??) 
+# sphinx 5.3 doesn't work with our _NoValue workaround +sphinx >= 6.0 jinja2 sphinx_rtd_theme sphinxcontrib-jquery sphinxcontrib-trio towncrier +sphinx-hoverxref +sphinx-codeautolink # Trio's own dependencies cffi; os_name == "nt" -attrs >= 19.2.0 +attrs >= 23.2.0 sortedcontainers idna outcome diff --git a/docs-requirements.txt b/docs-requirements.txt index a6f9148063..461d6e3d93 100644 --- a/docs-requirements.txt +++ b/docs-requirements.txt @@ -1,38 +1,38 @@ -# -# This file is autogenerated by pip-compile with Python 3.8 -# by the following command: -# -# pip-compile docs-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --universal --python-version=3.8 docs-requirements.in -o docs-requirements.txt alabaster==0.7.13 # via sphinx -attrs==23.1.0 +attrs==23.2.0 # via # -r docs-requirements.in # outcome -babel==2.13.0 +babel==2.15.0 # via sphinx -certifi==2023.7.22 +beautifulsoup4==4.12.3 + # via sphinx-codeautolink +certifi==2024.7.4 # via requests -cffi==1.16.0 - # via cryptography -charset-normalizer==3.3.0 +cffi==1.16.0 ; os_name == 'nt' or platform_python_implementation != 'PyPy' + # via + # -r docs-requirements.in + # cryptography +charset-normalizer==3.3.2 # via requests click==8.1.7 - # via - # click-default-group - # towncrier -click-default-group==1.2.4 # via towncrier -cryptography==41.0.4 +colorama==0.4.6 ; platform_system == 'Windows' or sys_platform == 'win32' + # via + # click + # sphinx +cryptography==42.0.8 # via pyopenssl -docutils==0.18.1 +docutils==0.20.1 # via # sphinx # sphinx-rtd-theme -exceptiongroup==1.1.3 +exceptiongroup==1.2.1 # via -r docs-requirements.in -idna==3.4 +idna==3.7 # via # -r docs-requirements.in # requests @@ -40,46 +40,54 @@ imagesize==1.4.1 # via sphinx immutables==0.20 # via -r docs-requirements.in -importlib-metadata==6.8.0 +importlib-metadata==8.0.0 ; python_version < '3.10' # via sphinx -importlib-resources==6.1.0 +importlib-resources==6.4.0 ; python_version < '3.10' # via towncrier incremental==22.10.0 # via towncrier -jinja2==3.1.2 +jinja2==3.1.4 # via # -r docs-requirements.in # sphinx # towncrier -markupsafe==2.1.3 +markupsafe==2.1.5 # via jinja2 -outcome==1.3.0 +outcome==1.3.0.post0 # via -r docs-requirements.in -packaging==23.2 +packaging==24.1 # via sphinx -pycparser==2.21 +pycparser==2.22 ; os_name == 'nt' or platform_python_implementation != 'PyPy' # via cffi -pygments==2.16.1 +pygments==2.18.0 # via sphinx -pyopenssl==23.2.0 +pyopenssl==24.1.0 # via -r docs-requirements.in -pytz==2023.3.post1 +pytz==2024.1 ; python_version < '3.9' # via babel -requests==2.31.0 +requests==2.32.3 # via sphinx -sniffio==1.3.0 +sniffio==1.3.1 # via -r docs-requirements.in snowballstemmer==2.2.0 # via sphinx sortedcontainers==2.4.0 # via -r docs-requirements.in -sphinx==6.1.3 +soupsieve==2.5 + # via beautifulsoup4 +sphinx==7.1.2 # via # -r docs-requirements.in + # sphinx-codeautolink + # sphinx-hoverxref # sphinx-rtd-theme # sphinxcontrib-jquery # sphinxcontrib-trio -sphinx-rtd-theme==1.3.0 +sphinx-codeautolink==0.15.2 + # via -r docs-requirements.in +sphinx-hoverxref==1.4.0 + # via -r docs-requirements.in +sphinx-rtd-theme==2.0.0 # via -r docs-requirements.in sphinxcontrib-applehelp==1.0.4 # via sphinx @@ -90,6 +98,7 @@ sphinxcontrib-htmlhelp==2.0.1 sphinxcontrib-jquery==4.1 # via # -r docs-requirements.in + # sphinx-hoverxref # sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx @@ -99,13 +108,13 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-trio==1.1.2 # via -r 
docs-requirements.in -tomli==2.0.1 +tomli==2.0.1 ; python_version < '3.11' # via towncrier -towncrier==23.6.0 +towncrier==23.11.0 # via -r docs-requirements.in -urllib3==2.0.7 +urllib3==2.2.2 # via requests -zipp==3.17.0 +zipp==3.19.2 ; python_version < '3.10' # via # importlib-metadata # importlib-resources diff --git a/docs/source/awesome-trio-libraries.rst b/docs/source/awesome-trio-libraries.rst index b3174c97a2..823bf0779a 100644 --- a/docs/source/awesome-trio-libraries.rst +++ b/docs/source/awesome-trio-libraries.rst @@ -91,24 +91,27 @@ RPC Testing ------- * `pytest-trio `__ - Pytest plugin for trio. -* `hypothesis-trio `__ - Hypothesis plugin for trio. +* `hypothesis-trio `__ - Hypothesis supports Trio out of the box for + ``@given(...)`` tests; this extension provides Trio-compatible stateful testing. * `trustme `__ - #1 quality TLS certs while you wait, for the discerning tester. * `pytest-aio `_ - Pytest plugin with support for trio, curio, asyncio +* `logot `_ - Test whether your async code is logging correctly. Tools and Utilities ------------------- -* `trio-typing `__ - Type hints for Trio and related projects. * `trio-util `__ - An assortment of utilities for the Trio async/await framework. -* `flake8-trio `__ - Highly opinionated linter for various sorts of problems in Trio and/or AnyIO. Can run as a flake8 plugin, or standalone with support for autofixing some errors. +* `flake8-async `__ - Highly opinionated linter for various sorts of problems in Trio, AnyIO and/or asyncio. Can run as a flake8 plugin, or standalone with support for autofixing some errors. * `tricycle `__ - This is a library of interesting-but-maybe-not-yet-fully-proven extensions to Trio. * `tenacity `__ - Retrying library for Python with async/await support. * `perf-timer `__ - A code timer with Trio async support (see ``TrioPerfTimer``). Collects execution time of a block of code excluding time when the coroutine isn't scheduled, such as during blocking I/O and sleep. Also offers ``trio_perf_counter()`` for low-level timing. * `aiometer `__ - Execute lots of tasks concurrently while controlling concurrency limits * `triotp `__ - OTP framework for Python Trio +* `aioresult `__ - Get the return value of a background async function in Trio or anyio, along with a simple Future class and wait utilities + Trio/Asyncio Interoperability ----------------------------- -* `anyio `__ - AnyIO is a asynchronous compatibility API that allows applications and libraries written against it to run unmodified on asyncio, curio and trio. +* `anyio `__ - AnyIO is a asynchronous compatibility API that allows applications and libraries written against it to run unmodified on asyncio or trio. * `sniffio `__ - This is a tiny package whose only purpose is to let you detect which async library your code is running under. * `trio-asyncio `__ - Trio-Asyncio lets you use many asyncio libraries from your Trio app. diff --git a/docs/source/conf.py b/docs/source/conf.py index c56ce12925..7ea27de24b 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,13 +16,25 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
# +from __future__ import annotations + +import collections.abc import os import sys +import types +from typing import TYPE_CHECKING, cast + +if TYPE_CHECKING: + from sphinx.application import Sphinx + from sphinx.util.typing import Inventory # For our local_customization module sys.path.insert(0, os.path.abspath(".")) # For trio itself -sys.path.insert(0, os.path.abspath("../..")) +sys.path.insert(0, os.path.abspath("../../src")) + +# Enable reloading with `typing.TYPE_CHECKING` being True +os.environ["SPHINX_AUTODOC_RELOAD_MODULES"] = "1" # https://docs.readthedocs.io/en/stable/builds.html#build-environment if "READTHEDOCS" in os.environ: @@ -38,6 +50,21 @@ check=True, ) +# Sphinx is very finicky, and somewhat buggy, so we have several different +# methods to help it resolve links. +# 1. The ones that are not possible to fix are added to `nitpick_ignore` +# 2. some can be resolved with a simple alias in `autodoc_type_aliases`, +# even if that is primarily meant for TypeAliases +# 3. autodoc_process_signature is hooked up to an event, and we use it for +# whole-sale replacing types in signatures where internal details are not +# relevant or hard to read. +# 4. add_intersphinx manually modifies the intersphinx mappings after +# objects.inv has been parsed, to resolve bugs and version differences +# that causes some objects to be looked up incorrectly. +# 5. docs/source/typevars.py handles redirecting `typing_extensions` objects to `typing`, and linking `TypeVar`s to `typing.TypeVar` instead of sphinx wanting to link them to their individual definitions. +# It's possible there's better methods for resolving some of the above +# problems, but this works for now:tm: + # Warn about all references to unknown targets nitpicky = True # Except for these ones, which we expect to point to unknown targets: @@ -48,38 +75,48 @@ ("py:class", "trio.lowlevel.RunLocal"), # trio.abc is documented at random places scattered throughout the docs ("py:mod", "trio.abc"), - ("py:class", "math.inf"), ("py:exc", "Anything else"), ("py:class", "async function"), ("py:class", "sync function"), - # why aren't these found in stdlib? - ("py:class", "types.FrameType"), - # these are not defined in https://docs.python.org/3/objects.inv + # these do not have documentation on python.org + # nor entries in objects.inv ("py:class", "socket.AddressFamily"), ("py:class", "socket.SocketKind"), - ("py:class", "Buffer"), # collections.abc.Buffer, in 3.12 ] autodoc_inherit_docstrings = False default_role = "obj" -# These have incorrect __module__ set in stdlib and give the error -# `py:class reference target not found` -# Some of the nitpick_ignore's above can probably be fixed with this. -# See https://github.com/sphinx-doc/sphinx/issues/8315#issuecomment-751335798 + +# A dictionary for users defined type aliases that maps a type name to the full-qualified object name. It is used to keep type aliases not evaluated in the document. 
+# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_type_aliases +# but it can also be used to help resolve various linking problems autodoc_type_aliases = { - # aliasing doesn't actually fix the warning for types.FrameType, but displaying - # "types.FrameType" is more helpful than just "frame" - "FrameType": "types.FrameType", - "Context": "OpenSSL.SSL.Context", # SSLListener.accept's return type is seen as trio._ssl.SSLStream "SSLStream": "trio.SSLStream", } +# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#event-autodoc-process-signature def autodoc_process_signature( - app, what, name, obj, options, signature, return_annotation -): + app: Sphinx, + what: str, + name: str, + obj: object, + options: object, + signature: str, + return_annotation: str, +) -> tuple[str, str]: """Modify found signatures to fix various issues.""" + if name == "trio.testing._raises_group._ExceptionInfo.type": + # This has the type "type[E]", which gets resolved into the property itself. + # That means Sphinx can't resolve it. Fix the issue by overwriting with a fully-qualified + # name. + assert isinstance(obj, property), obj + assert isinstance(obj.fget, types.FunctionType), obj.fget + assert ( + obj.fget.__annotations__["return"] == "type[MatchE]" + ), obj.fget.__annotations__ + obj.fget.__annotations__["return"] = "type[~trio.testing._raises_group.MatchE]" if signature is not None: signature = signature.replace("~_contextvars.Context", "~contextvars.Context") if name == "trio.lowlevel.RunVar": # Typevar is not useful here. @@ -88,6 +125,19 @@ def autodoc_process_signature( # Strip the type from the union, make it look like = ... signature = signature.replace(" | type[trio._core._local._NoValue]", "") signature = signature.replace("", "...") + if name in ("trio.testing.RaisesGroup", "trio.testing.Matcher") and ( + "+E" in signature or "+MatchE" in signature + ): + # This typevar being covariant isn't handled correctly in some cases, strip the + + # and insert the fully-qualified name. + signature = signature.replace("+E", "~trio.testing._raises_group.E") + signature = signature.replace( + "+MatchE", "~trio.testing._raises_group.MatchE" + ) + if "DTLS" in name: + signature = signature.replace("SSL.Context", "OpenSSL.SSL.Context") + # Don't specify PathLike[str] | PathLike[bytes], this is just for humans. + signature = signature.replace("StrOrBytesPath", "str | bytes | os.PathLike") return signature, return_annotation @@ -97,7 +147,7 @@ def autodoc_process_signature( # is shipped (should be in the release after 0.2.4) # ...note that this has since grown to contain a bunch of other CSS hacks too # though. -def setup(app): +def setup(app: Sphinx) -> None: app.add_css_file("hackrtd.css") app.connect("autodoc-process-signature", autodoc_process_signature) # After Intersphinx runs, add additional mappings. 
@@ -120,6 +170,8 @@ def setup(app): "sphinx.ext.napoleon", "sphinxcontrib_trio", "sphinxcontrib.jquery", + "hoverxref.extension", + "sphinx_codeautolink", "local_customization", "typevars", ] @@ -129,18 +181,78 @@ def setup(app): "outcome": ("https://outcome.readthedocs.io/en/latest/", None), "pyopenssl": ("https://www.pyopenssl.org/en/stable/", None), "sniffio": ("https://sniffio.readthedocs.io/en/latest/", None), + "trio-util": ("https://trio-util.readthedocs.io/en/latest/", None), } +# See https://sphinx-hoverxref.readthedocs.io/en/latest/configuration.html +hoverxref_auto_ref = True +hoverxref_domains = ["py"] +# Set the default style (tooltip) for all types to silence logging. +# See https://github.com/readthedocs/sphinx-hoverxref/issues/211 +hoverxref_role_types = { + "attr": "tooltip", + "class": "tooltip", + "const": "tooltip", + "exc": "tooltip", + "func": "tooltip", + "meth": "tooltip", + "mod": "tooltip", + "obj": "tooltip", + "ref": "tooltip", + "data": "tooltip", +} + +# See https://sphinx-codeautolink.readthedocs.io/en/latest/reference.html#configuration +codeautolink_autodoc_inject = False +codeautolink_global_preface = """ +import trio +from trio import * +""" + + +def add_intersphinx(app: Sphinx) -> None: + """Add some specific intersphinx mappings. + + Hooked up to builder-inited. app.builder.env.interpshinx_inventory is not an official API, so this may break on new sphinx versions. + """ + + def add_mapping( + reftype: str, + library: str, + obj: str, + version: str = "3.12", + target: str | None = None, + ) -> None: + """helper function""" + url_version = "3" if version == "3.12" else version + if target is None: + target = f"{library}.{obj}" + + # sphinx doing fancy caching stuff makes this attribute invisible + # to type checkers + inventory = app.builder.env.intersphinx_inventory # type: ignore[attr-defined] + assert isinstance(inventory, dict) + inventory = cast("Inventory", inventory) + + inventory[f"py:{reftype}"][f"{target}"] = ( + "Python", + version, + f"https://docs.python.org/{url_version}/library/{library}.html/{obj}", + "-", + ) -def add_intersphinx(app) -> None: - """Add some specific intersphinx mappings.""" # This has been removed in Py3.12, so add a link to the 3.11 version with deprecation warnings. - app.builder.env.intersphinx_inventory["py:method"]["pathlib.Path.link_to"] = ( - "Python", - "3.11", - "https://docs.python.org/3.11/library/pathlib.html#pathlib.Path.link_to", - "-", - ) + add_mapping("method", "pathlib", "Path.link_to", "3.11") + # defined in py:data in objects.inv, but sphinx looks for a py:class + add_mapping("class", "math", "inf") + # `types.FrameType.__module__` is "builtins", so sphinx looks for + # builtins.FrameType. + # See https://github.com/sphinx-doc/sphinx/issues/11802 + add_mapping("class", "types", "FrameType") + # new in py3.12, and need target because sphinx is unable to look up + # the module of the object if compiling on <3.12 + if not hasattr(collections.abc, "Buffer"): + add_mapping("class", "collections.abc", "Buffer", target="Buffer") autodoc_member_order = "bysource" @@ -159,7 +271,7 @@ def add_intersphinx(app) -> None: # General information about the project. project = "Trio" -copyright = "2017, Nathaniel J. Smith" +copyright = "2017, Nathaniel J. Smith" # noqa: A001 # Name shadows builtin author = "Nathaniel J. 
Smith" # The version info for the project you're documenting, acts as replacement for @@ -187,7 +299,7 @@ def add_intersphinx(app) -> None: # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This patterns also effect to html_static_path and html_extra_path -exclude_patterns = [] +exclude_patterns: list[str] = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = "default" @@ -246,7 +358,7 @@ def add_intersphinx(app) -> None: # -- Options for LaTeX output --------------------------------------------- -latex_elements = { +latex_elements: dict[str, object] = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index 6189814b3f..f37f57d5dd 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -133,7 +133,7 @@ in separate sections below: adding a test to make sure it stays fixed. * :ref:`pull-request-formatting`: If you changed Python code, then did - you run ``black setup.py trio``? (Or for other packages, replace + you run ``black trio``? (Or for other packages, replace ``trio`` with the package name.) * :ref:`pull-request-release-notes`: If your change affects @@ -199,8 +199,12 @@ you'll have a chance to see and fix any remaining issues then. Every change should have 100% coverage for both code and tests. But, you can use ``# pragma: no cover`` to mark lines where lack-of-coverage isn't something that we'd want to fix (as opposed to -it being merely hard to fix). For example:: +it being merely hard to fix). For example: +.. code-block:: python + + if ...: + ... else: # pragma: no cover raise AssertionError("this can't happen!") @@ -289,7 +293,9 @@ Instead of wasting time arguing about code formatting, we use `black `__ as well as other tools to automatically format all our code to a standard style. While you're editing code you can be as sloppy as you like about whitespace; and then before you commit, -just run:: +just run: + +.. code-block:: pip install -U pre-commit pre-commit @@ -301,12 +307,16 @@ names, writing useful comments, and making sure your docstrings are nicely formatted. (black doesn't reformat comments or docstrings.) If you would like, you can even have pre-commit run before you commit by -running:: +running: + +.. code-block:: pre-commit install and now pre-commit will run before git commits. You can uninstall the -pre-commit hook at any time by running:: +pre-commit hook at any time by running: + +.. code-block:: pre-commit uninstall @@ -314,9 +324,11 @@ pre-commit hook at any time by running:: Very occasionally, you'll want to override black formatting. To do so, you can can add ``# fmt: off`` and ``# fmt: on`` comments. -If you want to see what changes black will make, you can use:: +If you want to see what changes black will make, you can use: + +.. code-block:: - black --diff setup.py trio + black --diff trio (``--diff`` displays a diff, versus the default mode which fixes files in-place.) @@ -338,7 +350,7 @@ Basically, every pull request that has a user visible effect should add a short file to the ``newsfragments/`` directory describing the change, with a name like ``..rst``. See `newsfragments/README.rst -`__ +`__ for details. 
This way we can keep a good list of changes as we go, which makes the release manager happy, which means we get more frequent releases, which means your change gets into users' hands @@ -379,7 +391,7 @@ Documentation is hosted at `Read the Docs rebuilding it after every commit. For docstrings, we use `the Google docstring format -`__. +`__. If you add a new function or class, there's no mechanism for automatically adding that to the docs: you'll have to at least add a line like ``.. autofunction:: `` in the appropriate @@ -396,7 +408,9 @@ whitelist in ``docs/source/conf.py``. To build the docs locally, use our handy ``docs-requirements.txt`` file to install all of the required packages (possibly using a virtualenv). After that, build the docs using ``make html`` in the -docs directory. The whole process might look something like this:: +docs directory. The whole process might look something like this: + +.. code-block:: cd path/to/project/checkout/ pip install -r docs-requirements.txt diff --git a/docs/source/design.rst b/docs/source/design.rst index c3a47ab30c..8fc4102050 100644 --- a/docs/source/design.rst +++ b/docs/source/design.rst @@ -312,7 +312,9 @@ mean. This is often a challenging rule to follow – for example, the call soon code has to jump through some hoops to make it happen – but its most dramatic influence can seen in Trio's task-spawning interface, -where it motivates the use of "nurseries":: +where it motivates the use of "nurseries": + +.. code-block:: python async def parent(): async with trio.open_nursery() as nursery: @@ -376,18 +378,22 @@ Specific style guidelines unconditionally act as cancel+schedule points. * Any function that takes a callable to run should have a signature - like:: + like: + + .. code-block:: python - def call_the_thing(fn, *args, kwonly1, kwonly2, ...):: + def call_the_thing(fn, *args, kwonly1, kwonly2): ... where ``fn(*args)`` is the thing to be called, and ``kwonly1``, - ``kwonly2``, ... are keyword-only arguments that belong to + ``kwonly2``, are keyword-only arguments that belong to ``call_the_thing``. This applies even if ``call_the_thing`` doesn't take any arguments of its own, i.e. in this case its signature looks - like:: + like: - def call_the_thing(fn, *args):: + .. code-block:: python + + def call_the_thing(fn, *args): ... This allows users to skip faffing about with @@ -410,12 +416,14 @@ Specific style guidelines worse, and you get used to the convention pretty quick. * If it's desirable to have both blocking and non-blocking versions of - a function, then they look like:: + a function, then they look like: + + .. code-block:: python - async def OPERATION(...): + async def OPERATION(arg1, arg2): ... - def OPERATION_nowait(...): + def OPERATION_nowait(arg1, arg2): ... and the ``nowait`` version raises :exc:`trio.WouldBlock` if it would block. diff --git a/docs/source/history.rst b/docs/source/history.rst index f1baffbea7..9cef5191e5 100644 --- a/docs/source/history.rst +++ b/docs/source/history.rst @@ -5,6 +5,215 @@ Release history .. towncrier release notes start +Trio 0.26.0 (2024-07-05) +------------------------ + +Features +~~~~~~~~ + +- Added an interactive interpreter ``python -m trio``. + + This makes it easier to try things and experiment with trio in the a Python repl. + Use the ``await`` keyword without needing to call ``trio.run()`` + + .. code-block:: console + + $ python -m trio + Trio 0.21.0+dev, Python 3.10.6 + Use "await" directly instead of "trio.run()". 
+ Type "help", "copyright", "credits" or "license" for more information. + >>> import trio + >>> await trio.sleep(1); print("hi") # prints after one second + hi + + See :ref:`interactive debugging` for further detail. (`#2972 `__) +- :class:`trio.testing.RaisesGroup` can now catch an unwrapped exception with ``unwrapped=True``. This means that the behaviour of :ref:`except* ` can be fully replicated in combination with ``flatten_subgroups=True`` (formerly ``strict=False``). (`#2989 `__) + + +Bugfixes +~~~~~~~~ + +- Fixed a bug where :class:`trio.testing.RaisesGroup(..., strict=False) ` would check the number of exceptions in the raised `ExceptionGroup` before flattening subgroups, leading to incorrectly failed matches. + It now properly supports end (``$``) regex markers in the ``match`` message, by no longer including " (x sub-exceptions)" in the string it matches against. (`#2989 `__) + + +Deprecations and removals +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Deprecated ``strict`` parameter from :class:`trio.testing.RaisesGroup`, previous functionality of ``strict=False`` is now in ``flatten_subgroups=True``. (`#2989 `__) + + +Trio 0.25.1 (2024-05-16) +------------------------ + +Bugfixes +~~~~~~~~ + +- Fix crash when importing trio in embedded Python on Windows, and other installs that remove docstrings. (`#2987 `__) + + +Trio 0.25.0 (2024-03-17) +------------------------ + +Breaking changes +~~~~~~~~~~~~~~~~ + +- The :ref:`strict_exception_groups ` parameter now defaults to `True` in `trio.run` and `trio.lowlevel.start_guest_run`. `trio.open_nursery` still defaults to the same value as was specified in `trio.run`/`trio.lowlevel.start_guest_run`, but if you didn't specify it there then all subsequent calls to `trio.open_nursery` will change. + This is unfortunately very tricky to change with a deprecation period, as raising a `DeprecationWarning` whenever :ref:`strict_exception_groups ` is not specified would raise a lot of unnecessary warnings. + + Notable side effects of changing code to run with ``strict_exception_groups==True`` + + * If an iterator raises `StopAsyncIteration` or `StopIteration` inside a nursery, then python will not recognize wrapped instances of those for stopping iteration. + * `trio.run_process` is now documented that it can raise an `ExceptionGroup`. It previously could do this in very rare circumstances, but with :ref:`strict_exception_groups ` set to `True` it will now do so whenever exceptions occur in ``deliver_cancel`` or with problems communicating with the subprocess. + + * Errors in opening the process is now done outside the internal nursery, so if code previously ran with ``strict_exception_groups=True`` there are cases now where an `ExceptionGroup` is *no longer* added. + * `trio.TrioInternalError` ``.__cause__`` might be wrapped in one or more `ExceptionGroups ` (`#2786 `__) + + +Features +~~~~~~~~ + +- Add `trio.testing.wait_all_threads_completed`, which blocks until no threads are running tasks. This is intended to be used in the same way as `trio.testing.wait_all_tasks_blocked`. (`#2937 `__) +- :class:`Path` is now a subclass of :class:`pathlib.PurePath`, allowing it to interoperate with other standard + :mod:`pathlib` types. + + Instantiating :class:`Path` now returns a concrete platform-specific subclass, one of :class:`PosixPath` or + :class:`WindowsPath`, matching the behavior of :class:`pathlib.Path`. (`#2959 `__) + + +Bugfixes +~~~~~~~~ + +- The pthread functions are now correctly found on systems using vanilla versions of musl libc. 
(`#2939 `__) + + +Miscellaneous internal changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- use the regular readme for the PyPI long_description (`#2866 `__) + + +Trio 0.24.0 (2024-01-10) +------------------------ + +Features +~~~~~~~~ + +- New helper classes: :class:`~.testing.RaisesGroup` and :class:`~.testing.Matcher`. + + In preparation for changing the default of ``strict_exception_groups`` to `True`, we're introducing a set of helper classes that can be used in place of `pytest.raises `_ in tests, to check for an expected `ExceptionGroup`. + These are provisional, and only planned to be supplied until there's a good solution in ``pytest``. See https://github.com/pytest-dev/pytest/issues/11538 (`#2785 `__) + + +Deprecations and removals +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``MultiError`` has been fully removed, and all relevant trio functions now raise ExceptionGroups instead. This should not affect end users that have transitioned to using ``except*`` or catching ExceptionGroup/BaseExceptionGroup. (`#2891 `__) + + +Trio 0.23.2 (2023-12-14) +------------------------ + +Features +~~~~~~~~ + +- `TypeVarTuple `_ is now used to fully type :meth:`nursery.start_soon() `, :func:`trio.run()`, :func:`trio.to_thread.run_sync()`, and other similar functions accepting ``(func, *args)``. This means type checkers will be able to verify types are used correctly. :meth:`nursery.start() ` is not fully typed yet however. (`#2881 `__) + + +Bugfixes +~~~~~~~~ + +- Make pyright recognize :func:`open_memory_channel` as generic. (`#2873 `__) + + +Miscellaneous internal changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Moved the metadata into :pep:`621`-compliant :file:`pyproject.toml`. (`#2860 `__) +- do not depend on exceptiongroup pre-release (`#2861 `__) +- Move .coveragerc into pyproject.toml (`#2867 `__) + + +Trio 0.23.1 (2023-11-04) +------------------------ + +Bugfixes +~~~~~~~~ + +- Don't crash on import in Anaconda interpreters. (`#2855 `__) + + +Trio 0.23.0 (2023-11-03) +------------------------ + +Headline features +~~~~~~~~~~~~~~~~~ + +- Add type hints. (`#543 `__) + + +Features +~~~~~~~~ + +- When exiting a nursery block, the parent task always waits for child + tasks to exit. This wait cannot be cancelled. However, previously, if + you tried to cancel it, it *would* inject a `Cancelled` exception, + even though it wasn't cancelled. Most users probably never noticed + either way, but injecting a `Cancelled` here is not really useful, and + in some rare cases caused confusion or problems, so Trio no longer + does that. (`#1457 `__) +- If called from a thread spawned by `trio.to_thread.run_sync`, `trio.from_thread.run` and + `trio.from_thread.run_sync` now reuse the task and cancellation status of the host task; + this means that context variables and cancel scopes naturally propagate 'through' + threads spawned by Trio. You can also use `trio.from_thread.check_cancelled` + to efficiently check for cancellation without reentering the Trio thread. (`#2392 `__) +- :func:`trio.lowlevel.start_guest_run` now does a bit more setup of the guest run + before it returns to its caller, so that the caller can immediately make calls to + :func:`trio.current_time`, :func:`trio.lowlevel.spawn_system_task`, + :func:`trio.lowlevel.current_trio_token`, etc. (`#2696 `__) + + +Bugfixes +~~~~~~~~ + +- When a starting function raises before calling :func:`trio.TaskStatus.started`, + :func:`trio.Nursery.start` will no longer wrap the exception in an undocumented + :exc:`ExceptionGroup`. 
Previously, :func:`trio.Nursery.start` would incorrectly + raise an :exc:`ExceptionGroup` containing it when using ``trio.run(..., + strict_exception_groups=True)``. (`#2611 `__) + + +Deprecations and removals +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- To better reflect the underlying thread handling semantics, + the keyword argument for `trio.to_thread.run_sync` that was + previously called ``cancellable`` is now named ``abandon_on_cancel``. + It still does the same thing -- allow the thread to be abandoned + if the call to `trio.to_thread.run_sync` is cancelled -- but since we now + have other ways to propagate a cancellation without abandoning + the thread, "cancellable" has become somewhat of a misnomer. + The old ``cancellable`` name is now deprecated. (`#2841 `__) +- Deprecated support for ``math.inf`` for the ``backlog`` argument in ``open_tcp_listeners``, making its docstring correct in the fact that only ``TypeError`` is raised if invalid arguments are passed. (`#2842 `__) + + +Removals without deprecations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Drop support for Python3.7 and PyPy3.7/3.8. (`#2668 `__) +- Removed special ``MultiError`` traceback handling for IPython. As of `version 8.15 `_ `ExceptionGroup` is handled natively. (`#2702 `__) + + +Miscellaneous internal changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Trio now indicates its presence to `sniffio` using the ``sniffio.thread_local`` + interface that is preferred since sniffio v1.3.0. This should be less likely + than the previous approach to cause :func:`sniffio.current_async_library` to + return incorrect results due to unintended inheritance of contextvars. (`#2700 `__) +- On windows, if SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a different socket, runtime error will now raise from the OSError that indicated the issue so that in the event it does happen it might help with debugging. (`#2807 `__) + + Trio 0.22.2 (2023-07-13) ------------------------ @@ -1086,7 +1295,9 @@ Highlights * The new nursery :meth:`~Nursery.start` method makes it easy to perform controlled start-up of long-running tasks. For example, given an appropriate ``http_server_on_random_open_port`` - function, you could write:: + function, you could write: + + .. code-block:: python port = await nursery.start(http_server_on_random_open_port) @@ -1328,7 +1539,9 @@ Other changes functions, if you're using asyncio you have to use asyncio functions, and so forth. (See the discussion of the "async sandwich" in the Trio tutorial for more details.) So for example, this isn't - going to work:: + going to work: + + .. code-block:: python async def main(): # asyncio here diff --git a/docs/source/index.rst b/docs/source/index.rst index fc13227c3a..1caf5c043b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,7 +9,7 @@ Trio: a friendly Python library for async concurrency and I/O The Trio project's goal is to produce a production-quality, `permissively licensed -`__, +`__, async/await-native I/O library for Python. Like all async libraries, its main purpose is to help you write programs that do **multiple things at the same time** with **parallelized I/O**. 
A web spider that diff --git a/docs/source/local_customization.py b/docs/source/local_customization.py index 96014f46f9..4cc115f993 100644 --- a/docs/source/local_customization.py +++ b/docs/source/local_customization.py @@ -1,3 +1,7 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + from docutils.parsers.rst import directives as directives from sphinx import addnodes from sphinx.domains.python import PyClasslike @@ -8,6 +12,10 @@ Options as Options, ) +if TYPE_CHECKING: + from sphinx.addnodes import desc_signature + from sphinx.application import Sphinx + """ .. interface:: The nursery interface @@ -18,13 +26,13 @@ class Interface(PyClasslike): - def handle_signature(self, sig, signode): + def handle_signature(self, sig: str, signode: desc_signature) -> tuple[str, str]: signode += addnodes.desc_name(sig, sig) return sig, "" - def get_index_text(self, modname, name_cls): + def get_index_text(self, modname: str, name_cls: tuple[str, str]) -> str: return f"{name_cls[0]} (interface in {modname})" -def setup(app): +def setup(app: Sphinx) -> None: app.add_directive_to_domain("py", "interface", Interface) diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index f601846701..37c5c05feb 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -102,6 +102,8 @@ them. Here are the rules: only that one will act as a checkpoint. This is documented on a case-by-case basis. + * :func:`trio.open_nursery` is a further exception to this rule. + * Third-party async functions / iterators / context managers can act as checkpoints; if you see ``await `` or one of its friends, then that *might* be a checkpoint. So to be safe, you @@ -113,7 +115,9 @@ code. Checkpoint-ness is a transitive property: if function A acts as a checkpoint, and you write a function that calls function A, then your function also acts as a checkpoint. If you don't, then it isn't. So there's nothing stopping someone from writing a function -like:: +like: + +.. code-block:: python # technically legal, but bad style: async def why_is_this_async(): @@ -136,7 +140,9 @@ technical requirement that Python imposes, but since it exactly matches the transitivity of checkpoint-ness, we're able to exploit it to help you keep track of checkpoints. Pretty sneaky, eh?) -A slightly trickier case is a function like:: +A slightly trickier case is a function like: + +.. code-block:: python async def sleep_or_not(should_sleep): if should_sleep: @@ -157,7 +163,9 @@ Inside Trio, we're very picky about this, because Trio is the foundation of the whole system so we think it's worth the extra effort to make things extra predictable. It's up to you how picky you want to be in your code. To give you a more realistic example of what this -kind of issue looks like in real life, consider this function:: +kind of issue looks like in real life, consider this function: + +.. code-block:: python async def recv_exactly(sock, nbytes): data = bytearray() @@ -249,7 +257,9 @@ explicitly or when a timeout expires. A simple timeout example ~~~~~~~~~~~~~~~~~~~~~~~~ -In the simplest case, you can apply a timeout to a block of code:: +In the simplest case, you can apply a timeout to a block of code: + +.. code-block:: python with trio.move_on_after(30): result = await do_http_get("https://...") @@ -350,7 +360,9 @@ exception "knows" which block it belongs to. So long as you don't stop it, the exception will keep propagating until it reaches the block that raised it, at which point it will stop automatically. 
-Here's an example:: +Here's an example: + +.. code-block:: python print("starting...") with trio.move_on_after(5): @@ -379,7 +391,9 @@ block timed out – perhaps you want to do something different, like try a fallback procedure or report a failure to our caller. To make this easier, :func:`move_on_after`\´s ``__enter__`` function returns an object representing this cancel scope, which we can use to check -whether this scope caught a :exc:`Cancelled` exception:: +whether this scope caught a :exc:`Cancelled` exception: + +.. code-block:: python with trio.move_on_after(5) as cancel_scope: await trio.sleep(10) @@ -397,7 +411,9 @@ has been cancelled, *all* cancellable operations in that block will keep raising :exc:`Cancelled`. This helps avoid some pitfalls around resource clean-up. For example, imagine that we have a function that connects to a remote server and sends some messages, and then cleans -up on the way out:: +up on the way out: + +.. code-block:: python with trio.move_on_after(TIMEOUT): conn = make_connection() @@ -424,7 +440,9 @@ cleanup handler, Trio will let you; it's trying to prevent you from accidentally shooting yourself in the foot. Intentional foot-shooting is no problem (or at least – it's not Trio's problem). To do this, create a new scope, and set its :attr:`~CancelScope.shield` -attribute to :data:`True`:: +attribute to :data:`True`: + +.. code-block:: python with trio.move_on_after(TIMEOUT): conn = make_connection() @@ -547,14 +565,18 @@ situation of just wanting to impose a timeout on some code: Cheat sheet: * If you want to impose a timeout on a function, but you don't care - whether it timed out or not:: + whether it timed out or not: + + .. code-block:: python with trio.move_on_after(TIMEOUT): await do_whatever() # carry on! * If you want to impose a timeout on a function, and then do some - recovery if it timed out:: + recovery if it timed out: + + .. code-block:: python with trio.move_on_after(TIMEOUT) as cancel_scope: await do_whatever() @@ -564,7 +586,9 @@ Cheat sheet: * If you want to impose a timeout on a function, and then if it times out then just give up and raise an error for your caller to deal - with:: + with: + + .. code-block:: python with trio.fail_after(TIMEOUT): await do_whatever() @@ -598,13 +622,17 @@ Most libraries for concurrent programming let you start new child tasks (or threads, or whatever) willy-nilly, whenever and where-ever you feel like it. Trio is a bit different: you can't start a child task unless you're prepared to be a responsible parent. The way you -demonstrate your responsibility is by creating a nursery:: +demonstrate your responsibility is by creating a nursery: + +.. code-block:: python async with trio.open_nursery() as nursery: ... And once you have a reference to a nursery object, you can start -children in that nursery:: +children in that nursery: + +.. code-block:: python async def child(): ... @@ -639,8 +667,7 @@ crucial things to keep in mind: * The nursery is marked as "closed", meaning that no new tasks can be started inside it. - * Any unhandled exceptions are re-raised inside the parent task. If - there are multiple exceptions, then they're collected up into a + * Any unhandled exceptions are re-raised inside the parent task, grouped into a single :exc:`BaseExceptionGroup` or :exc:`ExceptionGroup` exception. Since all tasks are descendents of the initial task, one consequence @@ -649,7 +676,9 @@ finished. .. 
note:: - A return statement will not cancel the nursery if it still has tasks running:: + A return statement will not cancel the nursery if it still has tasks running: + + .. code-block:: python async def main(): async with trio.open_nursery() as nursery: @@ -665,7 +694,9 @@ Child tasks and cancellation In Trio, child tasks inherit the parent nursery's cancel scopes. So in this example, both the child tasks will be cancelled when the timeout -expires:: +expires: + +.. code-block:: python with trio.move_on_after(TIMEOUT): async with trio.open_nursery() as nursery: @@ -675,7 +706,9 @@ expires:: Note that what matters here is the scopes that were active when :func:`open_nursery` was called, *not* the scopes active when ``start_soon`` is called. So for example, the timeout block below does -nothing at all:: +nothing at all: + +.. code-block:: python async with trio.open_nursery() as nursery: with trio.move_on_after(TIMEOUT): # don't do this! @@ -694,7 +727,9 @@ Errors in multiple child tasks Normally, in Python, only one thing happens at a time, which means that only one thing can go wrong at a time. Trio has no such -limitation. Consider code like:: +limitation. Consider code like: + +.. code-block:: python async def broken1(): d = {} @@ -716,7 +751,9 @@ what? The answer is that both exceptions are grouped in an :exc:`ExceptionGroup` encapsulate multiple exceptions being raised at once. To catch individual exceptions encapsulated in an exception group, the ``except*`` -clause was introduced in Python 3.11 (:pep:`654`). Here's how it works:: +clause was introduced in Python 3.11 (:pep:`654`). Here's how it works: + +.. code-block:: python try: async with trio.open_nursery() as nursery: @@ -733,9 +770,11 @@ If you want to reraise exceptions, or raise new ones, you can do so, but be awar exceptions raised in ``except*`` sections will be raised together in a new exception group. -But what if you can't use ``except*`` just yet? Well, for that there is the handy -exceptiongroup_ library which lets you approximate this behavior with exception handler -callbacks:: +But what if you can't use Python 3.11, and therefore ``except*``, just yet? +The same exceptiongroup_ library which backports `ExceptionGroup` also lets +you approximate this behavior with exception handler callbacks: + +.. code-block:: python from exceptiongroup import catch @@ -757,7 +796,9 @@ callbacks:: The semantics for the handler functions are equal to ``except*`` blocks, except for setting local variables. If you need to set local variables, you need to declare them -inside the handler function(s) with the ``nonlocal`` keyword:: +inside the handler function(s) with the ``nonlocal`` keyword: + +.. code-block:: python def handle_keyerrors(excgroup): nonlocal myflag @@ -768,43 +809,91 @@ inside the handler function(s) with the ``nonlocal`` keyword:: async with trio.open_nursery() as nursery: nursery.start_soon(broken1) -For reasons of backwards compatibility, nurseries raise ``trio.MultiError`` and -``trio.NonBaseMultiError`` which inherit from :exc:`BaseExceptionGroup` and -:exc:`ExceptionGroup`, respectively. Users should refrain from attempting to raise or -catch the Trio specific exceptions themselves, and treat them as if they were standard -:exc:`BaseExceptionGroup` or :exc:`ExceptionGroup` instances instead. 
- -"Strict" versus "loose" ExceptionGroup semantics -++++++++++++++++++++++++++++++++++++++++++++++++ - -Ideally, in some abstract sense we'd want everything that *can* raise an -`ExceptionGroup` to *always* raise an `ExceptionGroup` (rather than, say, a single -`ValueError`). Otherwise, it would be easy to accidentally write something like ``except -ValueError:`` (not ``except*``), which works if a single exception is raised but fails to -catch _anything_ in the case of multiple simultaneous exceptions (even if one of them is -a ValueError). However, this is not how Trio worked in the past: as a concession to -practicality when the ``except*`` syntax hadn't been dreamed up yet, the old -``trio.MultiError`` was raised only when at least two exceptions occurred -simultaneously. Adding a layer of `ExceptionGroup` around every nursery, while -theoretically appealing, would probably break a lot of existing code in practice. - -Therefore, we've chosen to gate the newer, "stricter" behavior behind a parameter -called ``strict_exception_groups``. This is accepted as a parameter to -:func:`open_nursery`, to set the behavior for that nursery, and to :func:`trio.run`, -to set the default behavior for any nursery in your program that doesn't override it. - -* With ``strict_exception_groups=True``, the exception(s) coming out of a nursery will - always be wrapped in an `ExceptionGroup`, so you'll know that if you're handling - single errors correctly, multiple simultaneous errors will work as well. - -* With ``strict_exception_groups=False``, a nursery in which only one task has failed - will raise that task's exception without an additional layer of `ExceptionGroup` - wrapping, so you'll get maximum compatibility with code that was written to - support older versions of Trio. - -To maintain backwards compatibility, the default is ``strict_exception_groups=False``. -The default will eventually change to ``True`` in a future version of Trio, once -Python 3.11 and later versions are in wide use. +.. _handling_exception_groups: + +Designing for multiple errors ++++++++++++++++++++++++++++++ + +Structured concurrency is still a young design pattern, but there are a few patterns +we've identified for how you (or your users) might want to handle groups of exceptions. +Note that the final pattern, simply raising an `ExceptionGroup`, is the most common - +and nurseries automatically do that for you. + +**First**, you might want to 'defer to' a particular exception type, raising just that if +there is any such instance in the group. For example: `KeyboardInterrupt` has a clear +meaning for the surrounding code, could reasonably take priority over errors of other +types, and whether you have one or several of them doesn't really matter. + +This pattern can often be implemented using a decorator or a context manager, such +as :func:`trio_util.multi_error_defer_to` or :func:`trio_util.defer_to_cancelled`. +Note however that re-raising a 'leaf' exception will discard whatever part of the +traceback is attached to the `ExceptionGroup` itself, so we don't recommend this for +errors that will be presented to humans. + +.. + TODO: what about `Cancelled`? It's relevantly similar to `KeyboardInterrupt`, + but if you have multiple Cancelleds destined for different scopes, it seems + like it might be bad to abandon all-but-one of those - we might try to execute + some more code which then itself gets cancelled again, and incur more cleanup. + That's only a mild inefficiency though, and the semantics are fine overall. 
+ +**Second**, you might want to treat the concurrency inside your code as an implementation +detail which is hidden from your users - for example, abstracting a protocol which +involves sending and receiving data to a simple receive-only interface, or implementing +a context manager which maintains some background tasks for the length of the +``async with`` block. + +The simple option here is to ``raise MySpecificError from group``, allowing users to +handle your library-specific error. This is simple and reliable, but doesn't completely +hide the nursery. *Do not* unwrap single exceptions if there could ever be multiple +exceptions though; that always ends in latent bugs and then tears. + +The more complex option is to ensure that only one exception can in fact happen at a time. +This is *very hard*; for example, you'll need to handle `KeyboardInterrupt` somehow, and +we strongly recommend having a ``raise PleaseReportBug from group`` fallback just in case +you get a group containing more than one exception. +This is useful when writing a context manager which starts some background tasks, and then +yields to user code which effectively runs 'inline' in the body of the nursery block. +In this case, the background tasks can be wrapped with e.g. the `outcome +`__ library to ensure that only one exception +can be raised (from end-user code); and then you can either ``raise SomeInternalError`` +if a background task failed, or unwrap the user exception if that was the only error. + +.. + For more on this pattern, see https://github.com/python-trio/trio/issues/2929 + and the linked issue on trio-websocket. We may want to provide a nursery mode + which handles this automatically; it's annoying but not too complicated and + seems like it might be a good feature to ship for such cases. + +**Third and most often**, the existence of a nursery in your code is not just an +implementation detail, and callers *should* be prepared to handle multiple exceptions +in the form of an `ExceptionGroup`, whether with ``except*`` or manual inspection +or by just letting it propagate to *their* callers. Because this is so common, +it's nurseries' default behavior and you don't need to do anything. + +.. _strict_exception_groups: + +Historical Note: "non-strict" ExceptionGroups ++++++++++++++++++++++++++++++++++++++++++++++ + +In early versions of Trio, the ``except*`` syntax hadn't been dreamt up yet, and we +hadn't worked with structured concurrency for long or in large codebases. +As a concession to convenience, some APIs would therefore raise single exceptions, +and only wrap concurrent exceptions in the old ``trio.MultiError`` type if there +were two or more. + +Unfortunately, the results were not good: calling code often didn't realize that +some function *could* raise a ``MultiError``, and therefore handled only the common +case - with the result that things would work well in testing, and then crash under +heavier load (typically in production). `asyncio.TaskGroup` learned from this +experience and *always* wraps errors into an `ExceptionGroup`, as does ``anyio``, +and as of Trio 0.25 that's our default behavior too. + +We currently support a compatibility argument ``strict_exception_groups=False`` to +`trio.run` and `trio.open_nursery`, which restores the old behavior (although +``MultiError`` itself has been fully removed).
We strongly advise against it for +new code, and encourage existing uses to migrate - we consider the option deprecated +and plan to remove it after a period of documented and then runtime warnings. .. _exceptiongroup: https://pypi.org/project/exceptiongroup/ @@ -818,7 +907,9 @@ connections and supervise children at the same time. The solution here is simple once you see it: there's no requirement that a nursery object stay in the task that created it! We can write -code like this:: +code like this: + +.. code-block:: python async def new_connection_listener(handler, nursery): while True: @@ -832,7 +923,9 @@ code like this:: Notice that ``server`` opens a nursery and passes it to ``new_connection_listener``, and then ``new_connection_listener`` is able to start new tasks as "siblings" of itself. Of course, in this -case, we could just as well have written:: +case, we could just as well have written: + +.. code-block:: python async def server(handler): async with trio.open_nursery() as nursery: @@ -846,7 +939,9 @@ handy. One thing to remember, though: cancel scopes are inherited from the nursery, **not** from the task that calls ``start_soon``. So in this example, the timeout does *not* apply to ``child`` (or to anything -else):: +else): + +.. code-block:: python async def do_spawn(nursery): with trio.move_on_after(TIMEOUT): # don't do this, it has no effect @@ -874,7 +969,9 @@ no reason everyone should have to write their own. For example, here's a function that takes a list of functions, runs them all concurrently, and returns the result from the one that -finishes first:: +finishes first: + +.. code-block:: python async def race(*async_fns): if not async_fns: @@ -1072,7 +1169,9 @@ releasing the lock will call :meth:`~Lock.acquire` before the other task wakes up; in Trio releasing a lock is not a checkpoint.) With an unfair lock, this would result in the same task holding the lock forever and the other task being starved out. But if you run this, -you'll see that the two tasks politely take turns:: +you'll see that the two tasks politely take turns: + +.. code-block:: python # fairness-demo.py @@ -1273,7 +1372,9 @@ Notice a small trick we use: the code in ``main`` creates clone objects to pass into all the child tasks, and then closes the original objects using ``async with``. Another option is to pass clones into all-but-one of the child tasks, and then pass the original object into -the last task, like:: +the last task, like: + +.. code-block:: python # Also works, but is more finicky: send_channel, receive_channel = trio.open_memory_channel(0) @@ -1285,7 +1386,9 @@ the last task, like:: But this is more error-prone, especially if you use a loop to spawn the producers/consumers. -Just make sure that you don't write:: +Just make sure that you don't write: + +.. code-block:: python # Broken, will cause program to hang: send_channel, receive_channel = trio.open_memory_channel(0) @@ -1489,7 +1592,9 @@ statements. As you might expect, you use ``async for`` to iterate over them. :pep:`525` has many more details if you want them. For example, the following is a roundabout way to print -the numbers 0 through 9 with a 1-second delay before each one:: +the numbers 0 through 9 with a 1-second delay before each one: + +.. 
code-block:: python async def range_slowly(*args): """Like range(), but adds a 1-second sleep before each value.""" @@ -1537,7 +1642,9 @@ If you don't like that ambiguity, and you want to ensure that a generator's ``finally`` blocks and ``__aexit__`` handlers execute as soon as you're done using it, then you'll need to wrap your use of the generator in something like `async_generator.aclosing() -`__:: +`__: + +.. code-block:: python # Instead of this: async for value in my_generator(): @@ -1595,7 +1702,9 @@ Cancel scopes and nurseries .. warning:: You may not write a ``yield`` statement that suspends an async generator inside a `CancelScope` or `Nursery` that was entered within the generator. -That is, this is OK:: +That is, this is OK: + +.. code-block:: python async def some_agen(): with trio.move_on_after(1): @@ -1607,7 +1716,9 @@ That is, this is OK:: yield "second" ... -But this is not:: +But this is not: + +.. code-block:: python async def some_agen(): with trio.move_on_after(1): @@ -1752,7 +1863,9 @@ it's just a matter of creating two separate :class:`CapacityLimiter` objects and passing them in when running these jobs. Or here's an example of defining a custom policy that respects the global thread limit, while making sure that no individual user can use more than 3 -threads at a time:: +threads at a time: + +.. code-block:: python class CombinedLimiter: def __init__(self, first, second): @@ -1833,8 +1946,8 @@ to spawn a child thread, and then use a :ref:`memory channel The ``from_thread.run*`` functions reuse the host task that called :func:`trio.to_thread.run_sync` to run your provided function, as long as you're - using the default ``cancellable=False`` so Trio can be sure that the task will remain - around to perform the work. If you pass ``cancellable=True`` at the outset, or if + using the default ``abandon_on_cancel=False`` so Trio can be sure that the task will remain + around to perform the work. If you pass ``abandon_on_cancel=True`` at the outset, or if you provide a :class:`~trio.lowlevel.TrioToken` when calling back in to Trio, your functions will be executed in a new system task. Therefore, the :func:`~trio.lowlevel.current_task`, :func:`current_effective_deadline`, or other @@ -1842,7 +1955,7 @@ to spawn a child thread, and then use a :ref:`memory channel You can also use :func:`trio.from_thread.check_cancelled` to check for cancellation from a thread that was spawned by :func:`trio.to_thread.run_sync`. If the call to -:func:`~trio.to_thread.run_sync` was cancelled (even if ``cancellable=False``!), then +:func:`~trio.to_thread.run_sync` was cancelled, then :func:`~trio.from_thread.check_cancelled` will raise :func:`trio.Cancelled`. It's like ``trio.from_thread.run(trio.sleep, 0)``, but much faster. @@ -1909,6 +2022,66 @@ explicit and might be easier to reason about. ``contextvars``. +.. _interactive debugging: + + +Interactive debugging +--------------------- + +When you start an interactive Python session to debug any async program +(whether it's based on ``asyncio``, Trio, or something else), every await +expression needs to be inside an async function: + +.. code-block:: console + + $ python + Python 3.10.6 + Type "help", "copyright", "credits" or "license" for more information. + >>> import trio + >>> await trio.sleep(1) + File "", line 1 + SyntaxError: 'await' outside function + >>> async def main(): + ... print("hello...") + ... await trio.sleep(1) + ... print("world!") + ... + >>> trio.run(main) + hello... + world! 
+ +This can make it difficult to iterate quickly since you have to redefine the +whole function body whenever you make a tweak. + +Trio provides a modified interactive console that lets you ``await`` at the top +level. You can access this console by running ``python -m trio``: + +.. code-block:: console + + $ python -m trio + Trio 0.21.0+dev, Python 3.10.6 + Use "await" directly instead of "trio.run()". + Type "help", "copyright", "credits" or "license" for more information. + >>> import trio + >>> print("hello..."); await trio.sleep(1); print("world!") + hello... + world! + +If you are an IPython user, you can use IPython's `autoawait +`__ +function. This can be enabled within the IPython shell by running the magic command +``%autoawait trio``. To have ``autoawait`` enabled whenever Trio is installed, you can +add the following to your IPython startup files +(e.g. ``~/.ipython/profile_default/startup/10-async.py``). + +.. code-block:: + + try: + import trio + get_ipython().run_line_magic("autoawait", "trio") + except ImportError: + pass + Exceptions and warnings ----------------------- diff --git a/docs/source/reference-io.rst b/docs/source/reference-io.rst index d0525e39d4..e8a967bf17 100644 --- a/docs/source/reference-io.rst +++ b/docs/source/reference-io.rst @@ -30,7 +30,9 @@ create complex transport configurations. Here's some examples: stdout. If for some reason you wanted to speak SSL to a subprocess, you could use a :class:`StapledStream` to combine its stdin/stdout into a single bidirectional :class:`~trio.abc.Stream`, and then wrap - that in an :class:`~trio.SSLStream`:: + that in an :class:`~trio.SSLStream`: + + .. code-block:: python ssl_context = ssl.create_default_context() ssl_context.check_hostname = False @@ -42,7 +44,9 @@ create complex transport configurations. Here's some examples: `__. In Trio this is trivial – just wrap your first :class:`~trio.SSLStream` in a second - :class:`~trio.SSLStream`:: + :class:`~trio.SSLStream`: + + .. code-block:: python # Get a raw SocketStream connection to the proxy: s0 = await open_tcp_stream("proxy", 443) @@ -370,12 +374,16 @@ broken features: :func:`~socket.getaddrinfo` and :func:`~socket.getnameinfo` instead. * :func:`~socket.getservbyport`: obsolete and `buggy - `__; instead, do:: + `__; instead, do: + + .. code-block:: python - _, service_name = await getnameinfo((127.0.0.1, port), NI_NUMERICHOST)) + _, service_name = await getnameinfo(('127.0.0.1', port), NI_NUMERICHOST) * :func:`~socket.getservbyname`: obsolete and `buggy - `__; instead, do:: + `__; instead, do: + + .. code-block:: python await getaddrinfo(None, service_name) @@ -631,6 +639,11 @@ Asynchronous path objects .. autoclass:: Path :members: + :inherited-members: + +.. autoclass:: PosixPath + +.. autoclass:: WindowsPath .. _async-file-objects: @@ -690,7 +703,9 @@ Asynchronous file objects `__. * Async file objects can be used as async iterators to iterate over - the lines of the file:: + the lines of the file: + + .. code-block:: python async with await trio.open_file(...) as f: async for line in f: @@ -728,6 +743,8 @@ task and interact with it while it's running: .. automethod:: fileno +.. autoclass:: trio._subprocess.StrOrBytesPath + .. autoclass:: trio.Process() .. autoattribute:: returncode @@ -843,19 +860,25 @@ shell doesn't provide any way to write a double quote inside a double-quoted string. Outside double quotes, any character (including a double quote) can be escaped using a leading ``^``.
But since a pipeline is processed by running each command in the pipeline in a -subshell, multiple layers of escaping can be needed:: +subshell, multiple layers of escaping can be needed: + +.. code-block:: sh echo ^^^&x | find "x" | find "x" # prints: &x And if you combine pipelines with () grouping, you can need even more -levels of escaping:: +levels of escaping: + +.. code-block:: sh (echo ^^^^^^^&x | find "x") | find "x" # prints: &x Since process creation takes a single arguments string, ``CMD.EXE``\'s quoting does not influence word splitting, and double quotes are not removed during CMD.EXE's expansion pass. Double quotes are troublesome -because CMD.EXE handles them differently from the MSVC runtime rules; in:: +because CMD.EXE handles them differently from the MSVC runtime rules; in: + +.. code-block:: sh prog.exe "foo \"bar\" baz" diff --git a/docs/source/reference-lowlevel.rst b/docs/source/reference-lowlevel.rst index 712a36ad04..70133b9839 100644 --- a/docs/source/reference-lowlevel.rst +++ b/docs/source/reference-lowlevel.rst @@ -203,7 +203,9 @@ a stream. If you have two different file descriptors for sending and receiving, and want to bundle them together into a single bidirectional -`~trio.abc.Stream`, then use `trio.StapledStream`:: +`~trio.abc.Stream`, then use `trio.StapledStream`: + +.. code-block:: python bidirectional_stream = trio.StapledStream( trio.lowlevel.FdStream(write_fd), @@ -257,6 +259,12 @@ anything real. See `#26 .. function:: wait_overlapped(handle, lpOverlapped) :async: +.. function:: write_overlapped(handle, data) + :async: + +.. function:: readinto_overlapped(handle, data) + :async: + .. function:: current_iocp() .. function:: monitor_completion_key() @@ -397,7 +405,9 @@ The next two functions are used *together* to make up a checkpoint: These are commonly used in cases where you have an operation that might-or-might-not block, and you want to implement Trio's standard -checkpoint semantics. Example:: +checkpoint semantics. Example: + +.. code-block:: python async def operation_that_maybe_blocks(): await checkpoint_if_cancelled() @@ -458,7 +468,9 @@ non-blocking path, etc. If you really want to implement your own lock, then you should study the implementation of :class:`trio.Lock` and use :class:`ParkingLot`, which handles some of these issues for you. But this does serve to illustrate the basic structure of the -:func:`wait_task_rescheduled` API:: +:func:`wait_task_rescheduled` API: + +.. code-block:: python class NotVeryGoodLock: def __init__(self): @@ -590,7 +602,9 @@ like Qt. Its advantages are: from the host, and call sync host APIs from Trio. For example, if you're making a GUI app with Qt as the host loop, then making a `cancel button `__ and - connecting it to a `trio.CancelScope` is as easy as writing:: + connecting it to a `trio.CancelScope` is as easy as writing: + + .. code-block:: python # Trio code can create Qt objects without any special ceremony... my_cancel_button = QPushButton("Cancel") @@ -693,9 +707,12 @@ with your favorite event loop. Treat this section like a checklist. **Getting started:** The first step is to get something basic working. Here's a minimal example of running Trio on top of asyncio, that you -can use as a model:: +can use as a model: - import asyncio, trio +.. 
code-block:: python + + import asyncio + import trio # A tiny Trio program async def trio_main(): @@ -805,7 +822,9 @@ Here's how we'd extend our asyncio example to implement this pattern: return trio_main_outcome.unwrap() And then you can encapsulate all this machinery in a utility function -that exposes a `trio.run`-like API, but runs both loops together:: +that exposes a `trio.run`-like API, but runs both loops together: + +.. code-block:: python def trio_run_with_asyncio(trio_main, *args, **trio_run_kwargs): async def asyncio_main(): diff --git a/docs/source/reference-testing.rst b/docs/source/reference-testing.rst index 76ecd4a2d4..3b061a32db 100644 --- a/docs/source/reference-testing.rst +++ b/docs/source/reference-testing.rst @@ -72,6 +72,10 @@ Inter-task ordering .. autofunction:: wait_all_tasks_blocked +.. autofunction:: wait_all_threads_completed + +.. autofunction:: active_thread_count + .. _testing-streams: @@ -219,3 +223,16 @@ Testing checkpoints .. autofunction:: assert_no_checkpoints :with: + + +ExceptionGroup helpers +---------------------- + +.. autoclass:: RaisesGroup + :members: + +.. autoclass:: Matcher + :members: + +.. autoclass:: trio.testing._raises_group._ExceptionInfo + :members: diff --git a/docs/source/releasing.rst b/docs/source/releasing.rst index 0fe51370d5..e4cb70685d 100644 --- a/docs/source/releasing.rst +++ b/docs/source/releasing.rst @@ -35,16 +35,18 @@ Things to do for releasing: * push to your personal repository -* create pull request to ``python-trio/trio``'s "master" branch +* create pull request to ``python-trio/trio``'s "main" branch * verify that all checks succeeded * tag with vVERSION, push tag on ``python-trio/trio`` (not on your personal repository) -* push to PyPI:: +* push to PyPI: + + .. code-block:: git clean -xdf # maybe run 'git clean -xdn' first to see what it will delete - python3 setup.py sdist bdist_wheel + python3 -m build twine upload dist/* * update version number in the same pull request diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index 40eafd2833..c7218d873b 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -116,7 +116,9 @@ Python 3.5 added a major new feature: async functions. Using Trio is all about writing async functions, so let's start there. An async function is defined like a normal function, except you write -``async def`` instead of ``def``:: +``async def`` instead of ``def``: + +.. code-block:: python # A regular function def regular_double(x): @@ -138,12 +140,16 @@ async function and a regular function: ``await async_double(3)``. 2. You can't use the ``await`` keyword inside the body of a regular - function. If you try it, you'll get a syntax error:: + function. If you try it, you'll get a syntax error: + + .. code-block:: python def print_double(x): print(await async_double(x)) # <-- SyntaxError here - But inside an async function, ``await`` is allowed:: + But inside an async function, ``await`` is allowed: + + .. code-block:: python async def print_double(x): print(await async_double(x)) # <-- OK! @@ -183,7 +189,9 @@ things: 1. A runner function, which is a special *synchronous* function that takes and calls an *asynchronous* function. In Trio, this is - ``trio.run``:: + ``trio.run``: + + .. code-block:: python import trio @@ -208,7 +216,7 @@ things: :func:`trio.sleep`. (:func:`trio.sleep` is like :func:`time.sleep`, but with more async.) - .. code-block:: python3 + .. 
code-block:: python import trio @@ -254,7 +262,9 @@ little with writing simple async functions and running them with At some point in this process, you'll probably write some code like this, that tries to call an async function but leaves out the -``await``:: +``await``: + +.. code-block:: python import time import trio @@ -278,7 +288,7 @@ argument, then we would get a nice :exc:`TypeError` saying so. But unfortunately, if you forget an ``await``, you don't get that. What you actually get is: -.. code-block:: none +.. code-block:: pycon >>> trio.run(broken_double_sleep, 3) *yawn* Going to sleep @@ -295,21 +305,20 @@ depends on the whims of the garbage collector. If you're using PyPy, you might not even get a warning at all until the next GC collection runs: -.. code-block:: none +.. code-block:: pycon # On PyPy: - >>>> trio.run(broken_double_sleep, 3) + >>> trio.run(broken_double_sleep, 3) *yawn* Going to sleep Woke up after 0.00 seconds, feeling well rested! - >>>> # what the ... ?? not even a warning! + >>> # what the ... ?? not even a warning! - >>>> # but forcing a garbage collection gives us a warning: - >>>> import gc - >>>> gc.collect() + >>> # but forcing a garbage collection gives us a warning: + >>> import gc + >>> gc.collect() /home/njs/pypy-3.8-nightly/lib-python/3/importlib/_bootstrap.py:191: RuntimeWarning: coroutine 'sleep' was never awaited if _module_locks.get(name) is wr: # XXX PyPy fix? 0 - >>>> (If you can't see the warning above, try scrolling right.) @@ -335,7 +344,9 @@ use ``await``. But Python's trying to keep its options open for other libraries that are *ahem* a little less organized about things. So while for our purposes we can think of ``await trio.sleep(...)`` as a single piece of syntax, Python thinks of it as two things: first a -function call that returns this weird "coroutine" object:: +function call that returns this weird "coroutine" object: + +.. code-block:: pycon >>> trio.sleep(3) @@ -343,7 +354,9 @@ function call that returns this weird "coroutine" object:: and then that object gets passed to ``await``, which actually runs the function. So if you forget ``await``, then two bad things happen: your function doesn't actually get called, and you get a "coroutine" object -where you might have been expecting something else, like a number:: +where you might have been expecting something else, like a number: + +.. code-block:: pycon >>> async_double(3) + 1 TypeError: unsupported operand type(s) for +: 'coroutine' and 'int' @@ -1024,7 +1037,9 @@ Flow control in our echo client and server Here's a question you might be wondering about: why does our client use two separate tasks for sending and receiving, instead of a single task that alternates between them – like the server has? For example, -our client could use a single task like:: +our client could use a single task like: + +.. code-block:: python # Can you spot the two problems with this code? async def send_and_receive(client_stream): @@ -1060,7 +1075,9 @@ backed up in the network, until eventually something breaks. a limit on how many bytes you read each time, and see what happens. We could fix this by keeping track of how much data we're expecting at -each moment, and then keep calling ``receive_some`` until we get it all:: +each moment, and then keep calling ``receive_some`` until we get it all: + +.. code-block:: python expected = len(data) while expected > 0: @@ -1154,7 +1171,9 @@ TODO: maybe a brief discussion of :exc:`KeyboardInterrupt` handling? 
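Returning to the partial-read discussion above: here is that fix written out as a self-contained helper. The name ``receive_exactly`` is illustrative, not part of Trio:

.. code-block:: python

    async def receive_exactly(stream, expected):
        # A single receive_some() call may return fewer bytes than requested,
        # so keep reading until we've accumulated exactly `expected` bytes.
        buf = bytearray()
        while len(buf) < expected:
            chunk = await stream.receive_some(expected - len(buf))
            if not chunk:  # b"" means the peer closed the connection
                raise EOFError("connection closed mid-message")
            buf += chunk
        return bytes(buf)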
XX todo - timeout example:: + timeout example: + + .. code-block:: python async def counter(): for i in range(100000): @@ -1168,7 +1187,7 @@ TODO: maybe a brief discussion of :exc:`KeyboardInterrupt` handling? you can stick anything inside a timeout block, even child tasks [show something like the first example but with a timeout – they - both get cancelled, the cancelleds get packed into a multierror, and + both get cancelled, the cancelleds get packed into an ExceptionGroup, and then the timeout block catches the cancelled] brief discussion of KI? diff --git a/docs/source/typevars.py b/docs/source/typevars.py index c98f995a7d..17115c9298 100644 --- a/docs/source/typevars.py +++ b/docs/source/typevars.py @@ -2,17 +2,21 @@ See https://github.com/sphinx-doc/sphinx/issues/7722 also. """ + from __future__ import annotations import re from pathlib import Path +from typing import TYPE_CHECKING import trio -from sphinx.addnodes import Element, pending_xref -from sphinx.application import Sphinx -from sphinx.environment import BuildEnvironment from sphinx.errors import NoUri +if TYPE_CHECKING: + from sphinx.addnodes import Element, pending_xref + from sphinx.application import Sphinx + from sphinx.environment import BuildEnvironment + def identify_typevars(trio_folder: Path) -> None: """Record all typevars in trio.""" @@ -57,7 +61,7 @@ def lookup_reference( new_node["reftarget"] = f"typing.{target[18:]}" # This fires off this same event, with our new modified node in order to fetch the right # URL to use. - return app.emit_firstresult( + return app.emit_firstresult( # type: ignore[no-any-return] "missing-reference", env, new_node, diff --git a/newsfragments/1457.feature.rst b/newsfragments/1457.feature.rst deleted file mode 100644 index fc4250cb20..0000000000 --- a/newsfragments/1457.feature.rst +++ /dev/null @@ -1,7 +0,0 @@ -When exiting a nursery block, the parent task always waits for child -tasks to exit. This wait cannot be cancelled. However, previously, if -you tried to cancel it, it *would* inject a `Cancelled` exception, -even though it wasn't cancelled. Most users probably never noticed -either way, but injecting a `Cancelled` here is not really useful, and -in some rare cases caused confusion or problems, so Trio no longer -does that. diff --git a/newsfragments/2392.feature.rst b/newsfragments/2392.feature.rst deleted file mode 100644 index 985d3235af..0000000000 --- a/newsfragments/2392.feature.rst +++ /dev/null @@ -1,5 +0,0 @@ -If called from a thread spawned by `trio.to_thread.run_sync`, `trio.from_thread.run` and -`trio.from_thread.run_sync` now reuse the task and cancellation status of the host task; -this means that context variables and cancel scopes naturally propagate 'through' -threads spawned by Trio. You can also use `trio.from_thread.check_cancelled` -to efficiently check for cancellation without reentering the Trio thread. diff --git a/newsfragments/2611.bugfix.rst b/newsfragments/2611.bugfix.rst deleted file mode 100644 index 2af824a7d7..0000000000 --- a/newsfragments/2611.bugfix.rst +++ /dev/null @@ -1 +0,0 @@ -With ``strict_exception_groups=True``, when you ran a function in a nursery which raised an exception before calling ``task_status.started()``, it previously got wrapped twice over in ``ExceptionGroup`` in some cases. It no longer does that, and also won't wrap any ``ExceptionGroup`` raised by the function itself. 
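The thread-reentry feature described in the 2392 fragment above can be sketched like this; the sleep loop is just a stand-in for real blocking work:

.. code-block:: python

    import time
    import trio

    def blocking_work():
        # Runs in a worker thread spawned by trio.to_thread.run_sync.
        for _ in range(100):
            time.sleep(0.1)  # a chunk of blocking work
            # Cheap check that raises trio.Cancelled if the host task's
            # scope was cancelled, without re-entering the Trio thread.
            trio.from_thread.check_cancelled()

    async def main():
        with trio.move_on_after(0.5):
            # The Cancelled raised in the thread propagates back here and
            # is absorbed by the surrounding cancel scope.
            await trio.to_thread.run_sync(blocking_work)

    trio.run(main)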
diff --git a/newsfragments/2668.removal.rst b/newsfragments/2668.removal.rst deleted file mode 100644 index 512f681077..0000000000 --- a/newsfragments/2668.removal.rst +++ /dev/null @@ -1 +0,0 @@ -Drop support for Python3.7 and PyPy3.7/3.8. diff --git a/newsfragments/2696.feature.rst b/newsfragments/2696.feature.rst deleted file mode 100644 index 560cf3b365..0000000000 --- a/newsfragments/2696.feature.rst +++ /dev/null @@ -1,4 +0,0 @@ -:func:`trio.lowlevel.start_guest_run` now does a bit more setup of the guest run -before it returns to its caller, so that the caller can immediately make calls to -:func:`trio.current_time`, :func:`trio.lowlevel.spawn_system_task`, -:func:`trio.lowlevel.current_trio_token`, etc. diff --git a/newsfragments/2700.misc.rst b/newsfragments/2700.misc.rst deleted file mode 100644 index a70924816e..0000000000 --- a/newsfragments/2700.misc.rst +++ /dev/null @@ -1,4 +0,0 @@ -Trio now indicates its presence to `sniffio` using the ``sniffio.thread_local`` -interface that is preferred since sniffio v1.3.0. This should be less likely -than the previous approach to cause :func:`sniffio.current_async_library` to -return incorrect results due to unintended inheritance of contextvars. diff --git a/newsfragments/2702.removal.rst b/newsfragments/2702.removal.rst deleted file mode 100644 index 900da04498..0000000000 --- a/newsfragments/2702.removal.rst +++ /dev/null @@ -1 +0,0 @@ -Removed special ``MultiError`` traceback handling for IPython. As of `version 8.15 `_ `ExceptionGroup` is handled natively. diff --git a/newsfragments/2807.misc.rst b/newsfragments/2807.misc.rst deleted file mode 100644 index 3d7857d79e..0000000000 --- a/newsfragments/2807.misc.rst +++ /dev/null @@ -1 +0,0 @@ -On windows, if SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a different socket, runtime error will now raise from the OSError that indicated the issue so that in the event it does happen it might help with debugging. diff --git a/newsfragments/2842.removal.rst b/newsfragments/2842.removal.rst deleted file mode 100644 index c249956bae..0000000000 --- a/newsfragments/2842.removal.rst +++ /dev/null @@ -1 +0,0 @@ -Removed support for ``math.inf`` for the ``backlog`` argument in ``open_tcp_listeners``, making its docstring correct in the fact that only ``TypeError``s are raised if invalid arguments are passed. diff --git a/newsfragments/543.headline.rst b/newsfragments/543.headline.rst deleted file mode 100644 index bcd4fbe2b7..0000000000 --- a/newsfragments/543.headline.rst +++ /dev/null @@ -1 +0,0 @@ -Add type hints. 
diff --git a/notes-to-self/aio-guest-test.py b/notes-to-self/aio-guest-test.py index 7bd92aa309..3c607d0281 100644 --- a/notes-to-self/aio-guest-test.py +++ b/notes-to-self/aio-guest-test.py @@ -37,6 +37,7 @@ async def trio_main(): from_trio.put_nowait(n + 1) if n >= 10: return + del _task_ref async def aio_pingpong(from_trio, to_trio): diff --git a/notes-to-self/blocking-read-hack.py b/notes-to-self/blocking-read-hack.py index f4a73f876d..56bcd03df9 100644 --- a/notes-to-self/blocking-read-hack.py +++ b/notes-to-self/blocking-read-hack.py @@ -11,7 +11,9 @@ class BlockingReadTimeoutError(Exception): pass -async def blocking_read_with_timeout(fd, count, timeout): +async def blocking_read_with_timeout( + fd, count, timeout # noqa: ASYNC109 # manual timeout +): print("reading from fd", fd) cancel_requested = False diff --git a/notes-to-self/fbsd-pipe-close-notify.py b/notes-to-self/fbsd-pipe-close-notify.py index ab17f94c3f..ef60d6900e 100644 --- a/notes-to-self/fbsd-pipe-close-notify.py +++ b/notes-to-self/fbsd-pipe-close-notify.py @@ -12,11 +12,11 @@ os.set_blocking(w, False) print("filling pipe buffer") -while True: - try: +try: + while True: os.write(w, b"x") - except BlockingIOError: - break +except BlockingIOError: + pass _, wfds, _ = select.select([], [w], [], 0) print("select() says the write pipe is", "writable" if w in wfds else "NOT writable") diff --git a/notes-to-self/loopy.py b/notes-to-self/loopy.py index 070068015c..99f6e050b9 100644 --- a/notes-to-self/loopy.py +++ b/notes-to-self/loopy.py @@ -6,10 +6,9 @@ async def loopy(): try: while True: - time.sleep( # noqa: ASYNC101 # synchronous sleep to avoid maxing out CPU - 0.01 - ) - await trio.sleep(0) + # synchronous sleep to avoid maxing out CPU + time.sleep(0.01) # noqa: ASYNC251 + await trio.lowlevel.checkpoint() except KeyboardInterrupt: print("KI!") diff --git a/notes-to-self/print-task-tree.py b/notes-to-self/print-task-tree.py index 38e545853e..54b97ec014 100644 --- a/notes-to-self/print-task-tree.py +++ b/notes-to-self/print-task-tree.py @@ -55,8 +55,7 @@ def _render_subtree(name, rendered_children): first_prefix = MID_PREFIX rest_prefix = MID_CONTINUE lines.append(first_prefix + child_lines[0]) - for child_line in child_lines[1:]: - lines.append(rest_prefix + child_line) + lines.extend(rest_prefix + child_line for child_line in child_lines[1:]) return lines diff --git a/notes-to-self/schedule-timing.py b/notes-to-self/schedule-timing.py index c84ec9a436..11594b7cc7 100644 --- a/notes-to-self/schedule-timing.py +++ b/notes-to-self/schedule-timing.py @@ -11,7 +11,7 @@ async def reschedule_loop(depth): global LOOPS while RUNNING: LOOPS += 1 - await trio.sleep(0) + await trio.lowlevel.checkpoint() # await trio.lowlevel.cancel_shielded_checkpoint() else: await reschedule_loop(depth - 1) diff --git a/notes-to-self/time-wait.py b/notes-to-self/time-wait.py index 772f6c2727..edc1b39172 100644 --- a/notes-to-self/time-wait.py +++ b/notes-to-self/time-wait.py @@ -29,16 +29,16 @@ import errno import socket -import attr +import attrs -@attr.s(repr=False) +@attrs.define(repr=False, slots=False) class Options: - listen1_early = attr.ib(default=None) - listen1_middle = attr.ib(default=None) - listen1_late = attr.ib(default=None) - server = attr.ib(default=None) - listen2 = attr.ib(default=None) + listen1_early = None + listen1_middle = None + listen1_late = None + server = None + listen2 = None def set(self, which, sock): value = getattr(self, which) @@ -47,7 +47,7 @@ def set(self, which, sock): def describe(self): info = [] - for f in 
attr.fields(self.__class__): + for f in attrs.fields(self.__class__): value = getattr(self, f.name) if value is not None: info.append(f"{f.name}={value}") diff --git a/notes-to-self/trace.py b/notes-to-self/trace.py index 32b190e993..046412d3ae 100644 --- a/notes-to-self/trace.py +++ b/notes-to-self/trace.py @@ -94,17 +94,17 @@ def task_scheduled(self, task): except RuntimeError: pass else: - id = next(self.ids) + id_ = next(self.ids) self._write( ph="s", cat="wakeup", - id=id, + id=id_, tid=waker._counter, ) self._write( cat="wakeup", ph="f", - id=id, + id=id_, tid=task._counter, ) diff --git a/pyproject.toml b/pyproject.toml index 32900ff72f..0e26fea83a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,79 @@ +[build-system] +requires = ["setuptools >= 64"] +build-backend = "setuptools.build_meta" + +[project] +name = "trio" +description = "A friendly Python library for async concurrency and I/O" +authors = [{name = "Nathaniel J. Smith", email = "njs@pobox.com"}] +license = {text = "MIT OR Apache-2.0"} +keywords = [ + "async", + "io", + "networking", + "trio", +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Framework :: Trio", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS :: MacOS X", + "Operating System :: POSIX :: BSD", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Topic :: System :: Networking", + "Typing :: Typed", +] +requires-python = ">=3.8" +dependencies = [ + # attrs 19.2.0 adds `eq` option to decorators + # attrs 20.1.0 adds @frozen + # attrs 21.1.0 adds a dataclass transform for type-checkers + # attrs 21.3.0 adds `import addrs` + "attrs >= 23.2.0", + "sortedcontainers", + "idna", + "outcome", + "sniffio >= 1.3.0", + # cffi 1.12 adds from_buffer(require_writable=True) and ffi.release() + # cffi 1.14 fixes memory leak inside ffi.getwinerror() + # cffi is required on Windows, except on PyPy where it is built-in + "cffi>=1.14; os_name == 'nt' and implementation_name != 'pypy'", + "exceptiongroup; python_version < '3.11'", +] +dynamic = ["version"] + +[project.readme] +file = "README.rst" +content-type = "text/x-rst" + +[project.urls] +Homepage = "https://github.com/python-trio/trio" +Documentation = "https://trio.readthedocs.io/" +Changelog = "https://trio.readthedocs.io/en/latest/history.html" + +[project.entry-points.hypothesis] +trio = "trio._core._run:_hypothesis_plugin_setup" + +[tool.setuptools] +# This means, just install *everything* you see under trio/, even if it +# doesn't look like a source file, so long as it appears in MANIFEST.in: +include-package-data = true + +[tool.setuptools.dynamic] +version = {attr = "trio._version.__version__"} + [tool.black] target-version = ['py38'] force-exclude = ''' @@ -11,36 +87,12 @@ force-exclude = ''' ignore-words-list = 'astroid,crasher,asend' [tool.ruff] -target-version = "py38" respect-gitignore = true fix = true -allowed-confusables = ["–"] - # The directories to consider when resolving first vs. third-party imports. 
# Does not control what files to include/exclude! -src = ["trio", "notes-to-self"] - -select = [ - "RUF", # Ruff-specific rules - "F", # pyflakes - "E", # Error - "W", # Warning - "I", # isort - "UP", # pyupgrade - "B", # flake8-bugbear - "YTT", # flake8-2020 - "ASYNC", # flake8-async - "PYI", # flake8-pyi - "SIM", # flake8-simplify -] -extend-ignore = [ - 'F403', # undefined-local-with-import-star - 'F405', # undefined-local-with-import-star-usage - 'E402', # module-import-not-at-top-of-file (usually OS-specific) - 'E501', # line-too-long - 'SIM117', # multiple-with-statements (messes up lots of context-based stuff and looks bad) -] +src = ["src/trio", "notes-to-self"] include = ["*.py", "*.pyi", "**/pyproject.toml"] @@ -49,20 +101,59 @@ extend-exclude = [ "docs/source/tutorial/*", ] -[tool.ruff.per-file-ignores] -'trio/__init__.py' = ['F401'] -'trio/_core/__init__.py' = ['F401'] -'trio/_core/_tests/test_multierror_scripts/*' = ['F401'] -'trio/abc.py' = ['F401'] -'trio/lowlevel.py' = ['F401'] -'trio/socket.py' = ['F401'] -'trio/testing/__init__.py' = ['F401'] +[tool.ruff.lint] +allowed-confusables = ["–"] -[tool.ruff.isort] +select = [ + "A", # flake8-builtins + "ASYNC", # flake8-async + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "E", # Error + "F", # pyflakes + "FA", # flake8-future-annotations + "I", # isort + "PERF", # Perflint + "PT", # flake8-pytest-style + "PYI", # flake8-pyi + "RUF", # Ruff-specific rules + "SIM", # flake8-simplify + "TCH", # flake8-type-checking + "UP", # pyupgrade + "W", # Warning + "YTT", # flake8-2020 +] +extend-ignore = [ + 'A002', # builtin-argument-shadowing + 'E402', # module-import-not-at-top-of-file (usually OS-specific) + 'E501', # line-too-long + 'F403', # undefined-local-with-import-star + 'F405', # undefined-local-with-import-star-usage + 'PERF203', # try-except-in-loop (not always possible to refactor) + 'PT012', # multiple statements in pytest.raises block + 'SIM117', # multiple-with-statements (messes up lots of context-based stuff and looks bad) +] + +[tool.ruff.lint.per-file-ignores] +# F401 is ignoring unused imports. For these particular files, +# these are public APIs where we are importing everything we want +# to export for public use. +'src/trio/__init__.py' = ['F401'] +'src/trio/_core/__init__.py' = ['F401'] +'src/trio/abc.py' = ['F401'] +'src/trio/lowlevel.py' = ['F401'] +'src/trio/socket.py' = ['F401'] +'src/trio/testing/__init__.py' = ['F401'] + +[tool.ruff.lint.isort] combine-as-imports = true +[tool.ruff.lint.flake8-pytest-style] +fixture-parentheses = false + [tool.mypy] python_version = "3.8" +files = ["src/trio/", "docs/source/*.py"] # Be flexible about dependencies that don't have stubs yet (like pytest) ignore_missing_imports = true @@ -85,21 +176,20 @@ disallow_untyped_decorators = true disallow_untyped_defs = true check_untyped_defs = true +[tool.pyright] +pythonVersion = "3.8" +reportUnnecessaryTypeIgnoreComment = true +typeCheckingMode = "strict" + [tool.pytest.ini_options] -addopts = ["--strict-markers", "--strict-config"] +addopts = ["--strict-markers", "--strict-config", "-p trio._tests.pytest_plugin"] faulthandler_timeout = 60 filterwarnings = [ "error", # https://gitter.im/python-trio/general?at=63bb8d0740557a3d5c688d67 'ignore:You are using cryptography on a 32-bit Python on a 64-bit Windows Operating System. 
Cryptography will be significantly faster if you switch to using a 64-bit Python.:UserWarning', - # this should remain until https://github.com/pytest-dev/pytest/pull/10894 is merged - 'ignore:ast.Str is deprecated:DeprecationWarning', - 'ignore:Attribute s is deprecated and will be removed:DeprecationWarning', - 'ignore:ast.NameConstant is deprecated:DeprecationWarning', - 'ignore:ast.Num is deprecated:DeprecationWarning', - # https://github.com/python/mypy/issues/15330 - 'ignore:ast.Ellipsis is deprecated:DeprecationWarning', - 'ignore:ast.Bytes is deprecated:DeprecationWarning' + # https://github.com/berkerpeksag/astor/issues/217 + 'ignore:ast.Num is deprecated:DeprecationWarning:astor', ] junit_family = "xunit2" markers = ["redistributors_should_skip: tests that should be skipped by downstream redistributors"] @@ -116,6 +206,7 @@ issue_format = "`#{issue} `_ # - At release time after bumping version number, run: towncrier # (or towncrier --draft) package = "trio" +package_dir = "src" underlines = ["-", "~", "^"] [[tool.towncrier.type]] @@ -148,7 +239,50 @@ directory = "deprecated" name = "Deprecations and removals" showcontent = true +[[tool.towncrier.type]] +directory = "removal" +name = "Removals without deprecations" +showcontent = true + [[tool.towncrier.type]] directory = "misc" name = "Miscellaneous internal changes" showcontent = true + +[tool.coverage.run] +branch = true +source_pkgs = ["trio"] +omit = [ + # Omit the generated files in trio/_core starting with _generated_ + "*/trio/_core/_generated_*", + # Type tests aren't intended to be run, just passed to type checkers. + "*/type_tests/*", + # Script used to check type completeness that isn't run in tests + "*/trio/_tests/check_type_completeness.py", +] +# The test suite spawns subprocesses to test some stuff, so make sure +# this doesn't corrupt the coverage files +parallel = true + +[tool.coverage.report] +precision = 1 +skip_covered = true +exclude_lines = [ + "pragma: no cover", + "abc.abstractmethod", + "if TYPE_CHECKING.*:", + "if _t.TYPE_CHECKING:", + "if t.TYPE_CHECKING:", + "@overload", + 'class .*\bProtocol\b.*\):', + "raise NotImplementedError", +] +partial_branches = [ + "pragma: no branch", + "if not TYPE_CHECKING:", + "if not _t.TYPE_CHECKING:", + "if not t.TYPE_CHECKING:", + "if .* or not TYPE_CHECKING:", + "if .* or not _t.TYPE_CHECKING:", + "if .* or not t.TYPE_CHECKING:", +] diff --git a/setup.py b/setup.py deleted file mode 100644 index aaf0ab09cf..0000000000 --- a/setup.py +++ /dev/null @@ -1,131 +0,0 @@ -from setuptools import find_packages, setup - -__version__ = "0.0.0" # Overwritten from _version.py below, needed for linter to identify that this variable is defined. - -with open("trio/_version.py", encoding="utf-8") as version_code: - exec(version_code.read()) - -LONG_DESC = """\ -.. image:: https://raw.githubusercontent.com/python-trio/trio/9b0bec646a31e0d0f67b8b6ecc6939726faf3e17/logo/logo-with-background.svg - :width: 200px - :align: right - -The Trio project's goal is to produce a production-quality, `permissively -licensed `__, -async/await-native I/O library for Python. Like all async libraries, -its main purpose is to help you write programs that do **multiple -things at the same time** with **parallelized I/O**. A web spider that -wants to fetch lots of pages in parallel, a web server that needs to -juggle lots of downloads and websocket connections at the same time, a -process supervisor monitoring multiple subprocesses... that sort of -thing. 
Compared to other libraries, Trio attempts to distinguish -itself with an obsessive focus on **usability** and -**correctness**. Concurrency is complicated; we try to make it *easy* -to get things *right*. - -Trio was built from the ground up to take advantage of the `latest -Python features `__, and -draws inspiration from `many sources -`__, in -particular Dave Beazley's `Curio `__. -The resulting design is radically simpler than older competitors like -`asyncio `__ and -`Twisted `__, yet just as capable. Trio is -the Python I/O library I always wanted; I find it makes building -I/O-oriented programs easier, less error-prone, and just plain more -fun. `Perhaps you'll find the same -`__. - -This project is young and still somewhat experimental: the overall -design is solid and the existing features are fully tested and -documented, but you may encounter missing functionality or rough -edges. We *do* encourage you do use it, but you should `read and -subscribe to issue #1 -`__ to get warning and a -chance to give feedback about any compatibility-breaking changes. - -Vital statistics: - -* Supported environments: Linux, macOS, or Windows running some kind of Python - 3.8-or-better (either CPython or PyPy3 is fine). \\*BSD and illumos likely - work too, but are not tested. - -* Install: ``python3 -m pip install -U trio`` (or on Windows, maybe - ``py -3 -m pip install -U trio``). No compiler needed. - -* Tutorial and reference manual: https://trio.readthedocs.io - -* Changelog: https://trio.readthedocs.io/en/latest/history.html - -* Bug tracker and source code: https://github.com/python-trio/trio - -* Real-time chat: https://gitter.im/python-trio/general - -* Discussion forum: https://trio.discourse.group - -* License: MIT or Apache 2, your choice - -* Contributor guide: https://trio.readthedocs.io/en/latest/contributing.html - -* Code of conduct: Contributors are requested to follow our `code of - conduct - `_ - in all project spaces. -""" - -setup( - name="trio", - version=__version__, - description="A friendly Python library for async concurrency and I/O", - long_description=LONG_DESC, - long_description_content_type="text/x-rst", - author="Nathaniel J. 
Smith", - author_email="njs@pobox.com", - url="https://github.com/python-trio/trio", - license="MIT OR Apache-2.0", - packages=find_packages(), - install_requires=[ - # attrs 19.2.0 adds `eq` option to decorators - # attrs 20.1.0 adds @frozen - "attrs >= 20.1.0", - "sortedcontainers", - "idna", - "outcome", - "sniffio >= 1.3.0", - # cffi 1.12 adds from_buffer(require_writable=True) and ffi.release() - # cffi 1.14 fixes memory leak inside ffi.getwinerror() - # cffi is required on Windows, except on PyPy where it is built-in - "cffi>=1.14; os_name == 'nt' and implementation_name != 'pypy'", - "exceptiongroup >= 1.0.0rc9; python_version < '3.11'", - ], - # This means, just install *everything* you see under trio/, even if it - # doesn't look like a source file, so long as it appears in MANIFEST.in: - include_package_data=True, - python_requires=">=3.8", - keywords=["async", "io", "networking", "trio"], - classifiers=[ - "Development Status :: 3 - Alpha", - "Framework :: Trio", - "Intended Audience :: Developers", - "License :: OSI Approved :: MIT License", - "License :: OSI Approved :: Apache Software License", - "Operating System :: POSIX :: Linux", - "Operating System :: MacOS :: MacOS X", - "Operating System :: POSIX :: BSD", - "Operating System :: Microsoft :: Windows", - "Programming Language :: Python :: Implementation :: CPython", - "Programming Language :: Python :: Implementation :: PyPy", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Topic :: System :: Networking", - "Typing :: Typed", - ], - project_urls={ - "Documentation": "https://trio.readthedocs.io/", - "Changelog": "https://trio.readthedocs.io/en/latest/history.html", - }, -) diff --git a/trio/__init__.py b/src/trio/__init__.py similarity index 83% rename from trio/__init__.py rename to src/trio/__init__.py index 0574186c5d..d2151677b1 100644 --- a/trio/__init__.py +++ b/src/trio/__init__.py @@ -1,7 +1,10 @@ """Trio - A friendly Python library for async concurrency and I/O """ + from __future__ import annotations +from typing import TYPE_CHECKING + # General layout: # # trio/_core/... is the self-contained core library. It does various @@ -43,10 +46,6 @@ open_nursery as open_nursery, run as run, ) -from ._core._multierror import ( - MultiError as _MultiError, - NonBaseMultiError as _NonBaseMultiError, -) from ._deprecate import TrioDeprecationWarning as TrioDeprecationWarning from ._dtls import ( DTLSChannel as DTLSChannel, @@ -74,7 +73,7 @@ open_ssl_over_tcp_stream as open_ssl_over_tcp_stream, serve_ssl_over_tcp as serve_ssl_over_tcp, ) -from ._path import Path as Path +from ._path import Path as Path, PosixPath as PosixPath, WindowsPath as WindowsPath from ._signals import open_signal_receiver as open_signal_receiver from ._ssl import ( NeedHandshakeError as NeedHandshakeError, @@ -111,39 +110,14 @@ # Not imported by default, but mentioned here so static analysis tools like # pylint will know that it exists. -if False: +if TYPE_CHECKING: from . import testing from . 
import _deprecate as _deprecate _deprecate.enable_attribute_deprecations(__name__) -__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = { - "open_process": _deprecate.DeprecatedAttribute( - value=lowlevel.open_process, - version="0.20.0", - issue=1104, - instead="trio.lowlevel.open_process", - ), - "MultiError": _deprecate.DeprecatedAttribute( - value=_MultiError, - version="0.22.0", - issue=2211, - instead=( - "BaseExceptionGroup (on Python 3.11 and later) or " - "exceptiongroup.BaseExceptionGroup (earlier versions)" - ), - ), - "NonBaseMultiError": _deprecate.DeprecatedAttribute( - value=_NonBaseMultiError, - version="0.22.0", - issue=2211, - instead=( - "ExceptionGroup (on Python 3.11 and later) or " - "exceptiongroup.ExceptionGroup (earlier versions)" - ), - ), -} +__deprecated_attributes__: dict[str, _deprecate.DeprecatedAttribute] = {} # Having the public path in .__module__ attributes is important for: # - exception names in printed tracebacks @@ -160,3 +134,4 @@ fixup_module_metadata(from_thread.__name__, from_thread.__dict__) fixup_module_metadata(to_thread.__name__, to_thread.__dict__) del fixup_module_metadata +del TYPE_CHECKING diff --git a/src/trio/__main__.py b/src/trio/__main__.py new file mode 100644 index 0000000000..3b7c898ad5 --- /dev/null +++ b/src/trio/__main__.py @@ -0,0 +1,3 @@ +from trio._repl import main + +main(locals()) diff --git a/trio/_abc.py b/src/trio/_abc.py similarity index 98% rename from trio/_abc.py rename to src/trio/_abc.py index 76b1fdfba4..20f1614cc6 100644 --- a/trio/_abc.py +++ b/src/trio/_abc.py @@ -1,7 +1,7 @@ from __future__ import annotations import socket -from abc import ABCMeta, abstractmethod +from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Generic, TypeVar import trio @@ -16,9 +16,7 @@ from .lowlevel import Task -# We use ABCMeta instead of ABC, plus set __slots__=(), so as not to force a -# __dict__ onto subclasses. -class Clock(metaclass=ABCMeta): +class Clock(ABC): """The interface for custom run loop clocks.""" __slots__ = () @@ -68,7 +66,7 @@ def deadline_to_sleep_time(self, deadline: float) -> float: """ -class Instrument(metaclass=ABCMeta): +class Instrument(ABC): # noqa: B024 # conceptually is ABC """The interface for run loop instrumentation. Instruments don't have to inherit from this abstract base class, and all @@ -155,7 +153,7 @@ def after_io_wait(self, timeout: float) -> None: return -class HostnameResolver(metaclass=ABCMeta): +class HostnameResolver(ABC): """If you have a custom hostname resolver, then implementing :class:`HostnameResolver` allows you to register this to be used by Trio. @@ -168,7 +166,7 @@ class HostnameResolver(metaclass=ABCMeta): @abstractmethod async def getaddrinfo( self, - host: bytes | str | None, + host: bytes | None, port: bytes | str | int | None, family: int = 0, type: int = 0, @@ -209,7 +207,7 @@ async def getnameinfo( """ -class SocketFactory(metaclass=ABCMeta): +class SocketFactory(ABC): """If you write a custom class implementing the Trio socket interface, then you can use a :class:`SocketFactory` to get Trio to use it. @@ -217,6 +215,8 @@ class SocketFactory(metaclass=ABCMeta): """ + __slots__ = () + @abstractmethod def socket( self, @@ -240,7 +240,7 @@ def socket( """ -class AsyncResource(metaclass=ABCMeta): +class AsyncResource(ABC): """A standard interface for resources that needs to be cleaned up, and where that cleanup may require blocking operations. 
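As an aside on the `Instrument` interface touched in this hunk: subclasses only need to override the hooks they care about, since the base class provides no-op defaults. A sketch (the hook names are real; the logging is our own):

.. code-block:: python

    import trio

    class TaskLogger(trio.abc.Instrument):
        def task_spawned(self, task):
            print("spawned:", task.name)

        def task_exited(self, task):
            print("exited:", task.name)

    async def main():
        async with trio.open_nursery() as nursery:
            nursery.start_soon(trio.sleep, 0)

    trio.run(main, instruments=[TaskLogger()])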
@@ -698,3 +698,5 @@ class Channel(SendChannel[T], ReceiveChannel[T]): `ReceiveChannel` interfaces, so you can both send and receive objects. """ + + __slots__ = () diff --git a/trio/_channel.py b/src/trio/_channel.py similarity index 94% rename from trio/_channel.py rename to src/trio/_channel.py index cf7520c846..3d8445bc59 100644 --- a/trio/_channel.py +++ b/src/trio/_channel.py @@ -4,7 +4,7 @@ from collections.abc import Iterable from math import inf from operator import itemgetter -from types import TracebackType + from typing import ( TYPE_CHECKING, Generic, @@ -12,7 +12,7 @@ Tuple, # only needed for typechecking on <3.9 ) -import attr +import attrs from outcome import Error, Value import trio @@ -22,6 +22,8 @@ from ._util import NoPublicConstructor, final, generic_function if TYPE_CHECKING: + from types import TracebackType + from typing_extensions import Self @@ -113,27 +115,27 @@ def __init__(self, max_buffer_size: int | float): # noqa: PYI041 open_memory_channel = generic_function(_open_memory_channel) -@attr.s(frozen=True, slots=True) +@attrs.frozen class MemoryChannelStats: - current_buffer_used: int = attr.ib() - max_buffer_size: int | float = attr.ib() - open_send_channels: int = attr.ib() - open_receive_channels: int = attr.ib() - tasks_waiting_send: int = attr.ib() - tasks_waiting_receive: int = attr.ib() + current_buffer_used: int + max_buffer_size: int | float + open_send_channels: int + open_receive_channels: int + tasks_waiting_send: int + tasks_waiting_receive: int -@attr.s(slots=True) +@attrs.define class MemoryChannelState(Generic[T]): - max_buffer_size: int | float = attr.ib() - data: deque[T] = attr.ib(factory=deque) + max_buffer_size: int | float + data: deque[T] = attrs.Factory(deque) # Counts of open endpoints using this state - open_send_channels: int = attr.ib(default=0) - open_receive_channels: int = attr.ib(default=0) + open_send_channels: int = 0 + open_receive_channels: int = 0 # {task: value} - send_tasks: OrderedDict[Task, T] = attr.ib(factory=OrderedDict) + send_tasks: OrderedDict[Task, T] = attrs.Factory(OrderedDict) # {task: None} - receive_tasks: OrderedDict[Task, None] = attr.ib(factory=OrderedDict) + receive_tasks: OrderedDict[Task, None] = attrs.Factory(OrderedDict) def statistics(self) -> MemoryChannelStats: return MemoryChannelStats( @@ -147,18 +149,17 @@ def statistics(self) -> MemoryChannelStats: @final -@attr.s(eq=False, repr=False) +@attrs.define(eq=False, repr=False, slots=False) class MemorySendChannel( SendChannel[SendType], Generic[SendType], metaclass=NoPublicConstructor, ): - _state: MemoryChannelState[SendType] = attr.ib() - _closed: bool = attr.ib(default=False) + _state: MemoryChannelState[SendType] # This is just the tasks waiting on *this* object. As compared to # self._state.send_tasks, which includes tasks from this object and # all clones. 
- _tasks: set[Task] = attr.ib(factory=set) + _tasks: set[Task] = attrs.Factory(set) def __attrs_post_init__(self) -> None: self._state.open_send_channels += 1 @@ -294,15 +295,15 @@ async def aclose(self) -> None: @final -@attr.s(eq=False, repr=False) +@attrs.define(eq=False, repr=False, slots=False) class MemoryReceiveChannel( ReceiveChannel[ReceiveType], Generic[ReceiveType], metaclass=NoPublicConstructor, ): - _state: MemoryChannelState[ReceiveType] = attr.ib() - _closed: bool = attr.ib(default=False) - _tasks: set[trio._core._run.Task] = attr.ib(factory=set) + _state: MemoryChannelState[ReceiveType] + _closed: bool = False + _tasks: set[trio._core._run.Task] = attrs.Factory(set) def __attrs_post_init__(self) -> None: self._state.open_receive_channels += 1 @@ -311,8 +312,8 @@ def statistics(self) -> MemoryChannelStats: return self._state.statistics() def __repr__(self) -> str: - return "".format( - id(self), id(self._state) + return ( + f"" ) @enable_ki_protection diff --git a/trio/_core/__init__.py b/src/trio/_core/__init__.py similarity index 100% rename from trio/_core/__init__.py rename to src/trio/_core/__init__.py diff --git a/trio/_core/_asyncgens.py b/src/trio/_core/_asyncgens.py similarity index 96% rename from trio/_core/_asyncgens.py rename to src/trio/_core/_asyncgens.py index 4261328278..1a622dadfc 100644 --- a/trio/_core/_asyncgens.py +++ b/src/trio/_core/_asyncgens.py @@ -4,10 +4,9 @@ import sys import warnings import weakref -from types import AsyncGeneratorType from typing import TYPE_CHECKING, NoReturn -import attr +import attrs from .. import _core from .._util import name_asyncgen @@ -17,6 +16,7 @@ ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors") if TYPE_CHECKING: + from types import AsyncGeneratorType from typing import Set _WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]] @@ -26,7 +26,7 @@ _ASYNC_GEN_SET = set -@attr.s(eq=False, slots=True) +@attrs.define(eq=False) class AsyncGenerators: # Async generators are added to this set when first iterated. Any # left after the main task exits will be closed before trio.run() @@ -35,14 +35,14 @@ class AsyncGenerators: # asyncgens after the system nursery has been closed, it's a # regular set so we don't have to deal with GC firing at # unexpected times. - alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attr.ib(factory=_WEAK_ASYNC_GEN_SET) + alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET) # This collects async generators that get garbage collected during # the one-tick window between the system nursery closing and the # init task starting end-of-run asyncgen finalization. 
- trailing_needs_finalize: _ASYNC_GEN_SET = attr.ib(factory=_ASYNC_GEN_SET) + trailing_needs_finalize: _ASYNC_GEN_SET = attrs.Factory(_ASYNC_GEN_SET) - prev_hooks = attr.ib(init=False) + prev_hooks: sys._asyncgen_hooks = attrs.field(init=False) def install_hooks(self, runner: _run.Runner) -> None: def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None: diff --git a/src/trio/_core/_concat_tb.py b/src/trio/_core/_concat_tb.py new file mode 100644 index 0000000000..497d37f8ad --- /dev/null +++ b/src/trio/_core/_concat_tb.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from types import TracebackType +from typing import Any, ClassVar, cast + +################################################################ +# concat_tb +################################################################ + +# We need to compute a new traceback that is the concatenation of two existing +# tracebacks. This requires copying the entries in 'head' and then pointing +# the final tb_next to 'tail'. +# +# NB: 'tail' might be None, which requires some special handling in the ctypes +# version. +# +# The complication here is that Python doesn't actually support copying or +# modifying traceback objects, so we have to get creative... +# +# On CPython, we use ctypes. On PyPy, we use "transparent proxies". +# +# Jinja2 is a useful source of inspiration: +# https://github.com/pallets/jinja/blob/main/src/jinja2/debug.py + +try: + import tputil +except ImportError: + # ctypes it is + # How to handle refcounting? I don't want to use ctypes.py_object because + # I don't understand or trust it, and I don't want to use + # ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code + # that also tries to use them but with different types. So private _ctypes + # APIs it is! + import _ctypes + import ctypes + + class CTraceback(ctypes.Structure): + _fields_: ClassVar = [ + ("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()), + ("tb_next", ctypes.c_void_p), + ("tb_frame", ctypes.c_void_p), + ("tb_lasti", ctypes.c_int), + ("tb_lineno", ctypes.c_int), + ] + + def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: + # TracebackType has no public constructor, so allocate one the hard way + try: + raise ValueError + except ValueError as exc: + new_tb = exc.__traceback__ + assert new_tb is not None + c_new_tb = CTraceback.from_address(id(new_tb)) + + # At the C level, tb_next either points to the next traceback or is + # NULL. c_void_p and the .tb_next accessor both convert NULL to None, + # but we shouldn't DECREF None just because we assigned to a NULL + # pointer! Here we know that our new traceback has only 1 frame in it, + # so we can assume the tb_next field is NULL. + assert c_new_tb.tb_next is None + # If tb_next is None, then we want to set c_new_tb.tb_next to NULL, + # which it already is, so we're done. 
Otherwise, we have to actually + # do some work: + if tb_next is not None: + _ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined] + c_new_tb.tb_next = id(tb_next) + + assert c_new_tb.tb_frame is not None + _ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined] + old_tb_frame = new_tb.tb_frame + c_new_tb.tb_frame = id(base_tb.tb_frame) + _ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined] + + c_new_tb.tb_lasti = base_tb.tb_lasti + c_new_tb.tb_lineno = base_tb.tb_lineno + + try: + return new_tb + finally: + # delete references from locals to avoid creating cycles + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + del new_tb, old_tb_frame + +else: + # http://doc.pypy.org/en/latest/objspace-proxies.html + def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: + # tputil.ProxyOperation is PyPy-only, and there's no way to specify + # cpython/pypy in current type checkers. + def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported] + # Rationale for pragma: I looked fairly carefully and tried a few + # things, and AFAICT it's not actually possible to get any + # 'opname' that isn't __getattr__ or __getattribute__. So there's + # no missing test we could add, and no value in coverage nagging + # us about adding one. + if ( + operation.opname + in { + "__getattribute__", + "__getattr__", + } + and operation.args[0] == "tb_next" + ): # pragma: no cover + return tb_next + return operation.delegate() # Delegate is reverting to original behaviour + + return cast( + TracebackType, tputil.make_proxy(controller, type(base_tb), base_tb) + ) # Returns proxy to traceback + + +# this is used for collapsing single-exception ExceptionGroups when using +# `strict_exception_groups=False`. Once that is retired this function and its helper can +# be removed as well. +def concat_tb( + head: TracebackType | None, tail: TracebackType | None +) -> TracebackType | None: + # We have to use an iterative algorithm here, because in the worst case + # this might be a RecursionError stack that is by definition too deep to + # process by recursion! + head_tbs = [] + pointer = head + while pointer is not None: + head_tbs.append(pointer) + pointer = pointer.tb_next + current_head = tail + for head_tb in reversed(head_tbs): + current_head = copy_tb(head_tb, tb_next=current_head) + return current_head diff --git a/trio/_core/_entry_queue.py b/src/trio/_core/_entry_queue.py similarity index 91% rename from trio/_core/_entry_queue.py rename to src/trio/_core/_entry_queue.py index cb91025fbb..7f1eea8e29 100644 --- a/trio/_core/_entry_queue.py +++ b/src/trio/_core/_entry_queue.py @@ -2,21 +2,24 @@ import threading from collections import deque -from typing import Callable, Iterable, NoReturn, Tuple +from typing import TYPE_CHECKING, Callable, NoReturn, Tuple -import attr +import attrs from .. import _core from .._util import NoPublicConstructor, final from ._wakeup_socketpair import WakeupSocketpair -# TODO: Type with TypeVarTuple, at least to an extent where it makes -# the public interface safe. +if TYPE_CHECKING: + from typing_extensions import TypeVarTuple, Unpack + + PosArgsT = TypeVarTuple("PosArgsT") + Function = Callable[..., object] -Job = Tuple[Function, Iterable[object]] +Job = Tuple[Function, Tuple[object, ...]] -@attr.s(slots=True) +@attrs.define class EntryQueue: # This used to use a queue.Queue. 
but that was broken, because Queues are # implemented in Python, and not reentrant -- so it was thread-safe, but @@ -25,11 +28,11 @@ class EntryQueue: # atomic WRT signal delivery (signal handlers can run on either side, but # not *during* a deque operation). dict makes similar guarantees - and # it's even ordered! - queue: deque[Job] = attr.ib(factory=deque) - idempotent_queue: dict[Job, None] = attr.ib(factory=dict) + queue: deque[Job] = attrs.Factory(deque) + idempotent_queue: dict[Job, None] = attrs.Factory(dict) - wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) - done: bool = attr.ib(default=False) + wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair) + done: bool = False # Must be a reentrant lock, because it's acquired from signal handlers. # RLock is signal-safe as of cpython 3.2. NB that this does mean that the # lock is effectively *disabled* when we enter from signal context. The @@ -38,7 +41,7 @@ class EntryQueue: # main thread -- it just might happen at some inconvenient place. But if # you look at the one place where the main thread holds the lock, it's # just to make 1 assignment, so that's atomic WRT a signal anyway. - lock: threading.RLock = attr.ib(factory=threading.RLock) + lock: threading.RLock = attrs.Factory(threading.RLock) async def task(self) -> None: assert _core.currently_ki_protected() @@ -122,7 +125,10 @@ def size(self) -> int: return len(self.queue) + len(self.idempotent_queue) def run_sync_soon( - self, sync_fn: Function, *args: object, idempotent: bool = False + self, + sync_fn: Callable[[Unpack[PosArgsT]], object], + *args: Unpack[PosArgsT], + idempotent: bool = False, ) -> None: with self.lock: if self.done: @@ -140,7 +146,7 @@ def run_sync_soon( @final -@attr.s(eq=False, hash=False, slots=True) +@attrs.define(eq=False, hash=False) class TrioToken(metaclass=NoPublicConstructor): """An opaque object representing a single call to :func:`trio.run`. @@ -160,10 +166,13 @@ class TrioToken(metaclass=NoPublicConstructor): """ - _reentry_queue: EntryQueue = attr.ib() + _reentry_queue: EntryQueue def run_sync_soon( - self, sync_fn: Function, *args: object, idempotent: bool = False + self, + sync_fn: Callable[[Unpack[PosArgsT]], object], + *args: Unpack[PosArgsT], + idempotent: bool = False, ) -> None: """Schedule a call to ``sync_fn(*args)`` to occur in the context of a Trio task. diff --git a/trio/_core/_exceptions.py b/src/trio/_core/_exceptions.py similarity index 100% rename from trio/_core/_exceptions.py rename to src/trio/_core/_exceptions.py diff --git a/trio/_core/_generated_instrumentation.py b/src/trio/_core/_generated_instrumentation.py similarity index 82% rename from trio/_core/_generated_instrumentation.py rename to src/trio/_core/_generated_instrumentation.py index c7fefc307a..568b76dffa 100644 --- a/trio/_core/_generated_instrumentation.py +++ b/src/trio/_core/_generated_instrumentation.py @@ -3,10 +3,17 @@ # ************************************************************* from __future__ import annotations -from ._instrumentation import Instrument +import sys +from typing import TYPE_CHECKING + from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT +if TYPE_CHECKING: + from ._instrumentation import Instrument + +__all__ = ["add_instrument", "remove_instrument"] + def add_instrument(instrument: Instrument) -> None: """Start instrumenting the current run loop with the given instrument. 
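`run_sync_soon` above trades `*args: object` for a `TypeVarTuple`-based signature, so type checkers match the positional arguments against `sync_fn`'s parameters. A self-contained sketch of the pattern; `greet` is a hypothetical callback, and `typing_extensions` is required below Python 3.11:

```python
from typing import Callable

from typing_extensions import TypeVarTuple, Unpack

PosArgsT = TypeVarTuple("PosArgsT")


def run_sync_soon(
    sync_fn: Callable[[Unpack[PosArgsT]], object],
    *args: Unpack[PosArgsT],
) -> None:
    sync_fn(*args)  # stand-in body; the real method enqueues a job


def greet(name: str, times: int) -> None:
    print(f"hello {name}! " * times)


run_sync_soon(greet, "trio", 2)  # accepted: (str, int) matches greet
# run_sync_soon(greet, 2, "trio")  # a type checker rejects the swap
```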
@@ -17,7 +24,7 @@ def add_instrument(instrument: Instrument) -> None: If ``instrument`` is already active, does nothing. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument) except AttributeError: @@ -37,7 +44,7 @@ def remove_instrument(instrument: Instrument) -> None: deactivated. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument) except AttributeError: diff --git a/src/trio/_core/_generated_io_epoll.py b/src/trio/_core/_generated_io_epoll.py new file mode 100644 index 0000000000..9f9ad59725 --- /dev/null +++ b/src/trio/_core/_generated_io_epoll.py @@ -0,0 +1,98 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING + +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + from .._file_io import _HasFileNo + +assert not TYPE_CHECKING or sys.platform == "linux" + + +__all__ = ["notify_closing", "wait_readable", "wait_writable"] + + +async def wait_readable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``fd`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``fd`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def wait_writable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``fd``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def notify_closing(fd: int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). 
This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/src/trio/_core/_generated_io_kqueue.py b/src/trio/_core/_generated_io_kqueue.py new file mode 100644 index 0000000000..e150fc21f9 --- /dev/null +++ b/src/trio/_core/_generated_io_kqueue.py @@ -0,0 +1,151 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Callable, ContextManager + +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + import select + + from .. import _core + from .._file_io import _HasFileNo + from ._traps import Abort, RaiseCancelT + +assert not TYPE_CHECKING or sys.platform == "darwin" + + +__all__ = [ + "current_kqueue", + "monitor_kevent", + "notify_closing", + "wait_kevent", + "wait_readable", + "wait_writable", +] + + +def current_kqueue() -> select.kqueue: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def monitor_kevent( + ident: int, filter: int +) -> ContextManager[_core.UnboundedQueue[select.kevent]]: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def wait_kevent( + ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] +) -> Abort: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__. 
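The `notify_closing` docstring above prescribes a strict three-step shutdown. A sketch of those steps in a hypothetical wrapper that owns a raw non-blocking socket; note there is no checkpoint between the steps, as the docs require:

```python
import trio


class SocketHolder:
    """Hypothetical owner of a non-blocking socket."""

    def __init__(self, sock):
        self._sock = sock
        self._closed = False

    def close(self) -> None:
        self._closed = True                       # 1. mark as closed
        trio.lowlevel.notify_closing(self._sock)  # 2. wake existing waiters
        self._sock.close()                        # 3. actually close it

    async def wait_readable(self) -> None:
        if self._closed:
            raise trio.ClosedResourceError("socket was closed")
        await trio.lowlevel.wait_readable(self._sock)
```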
+ """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent( + ident, filter, abort_func + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def wait_readable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``fd`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``fd`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def wait_writable(fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``fd``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def notify_closing(fd: int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. 
+ """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/src/trio/_core/_generated_io_windows.py b/src/trio/_core/_generated_io_windows.py new file mode 100644 index 0000000000..72264f599f --- /dev/null +++ b/src/trio/_core/_generated_io_windows.py @@ -0,0 +1,200 @@ +# *********************************************************** +# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** +# ************************************************************* +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, ContextManager + +from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED +from ._run import GLOBAL_RUN_CONTEXT + +if TYPE_CHECKING: + from typing_extensions import Buffer + + from .._file_io import _HasFileNo + from ._unbounded_queue import UnboundedQueue + from ._windows_cffi import CData, Handle + +assert not TYPE_CHECKING or sys.platform == "win32" + + +__all__ = [ + "current_iocp", + "monitor_completion_key", + "notify_closing", + "readinto_overlapped", + "register_with_iocp", + "wait_overlapped", + "wait_readable", + "wait_writable", + "write_overlapped", +] + + +async def wait_readable(sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``sock`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``sock`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def wait_writable(sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``sock``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def notify_closing(handle: Handle | int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). 
This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def register_with_iocp(handle: int | CData) -> None: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( + handle_, lpOverlapped + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def write_overlapped( + handle: int | CData, data: Buffer, file_offset: int = 0 +) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( + handle, data, file_offset + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +async def readinto_overlapped( + handle: int | CData, buffer: Buffer, file_offset: int = 0 +) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( + handle, buffer, file_offset + ) + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def current_iocp() -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. 
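Every generated wrapper above shares one dispatch shape: look up the per-run runner on `GLOBAL_RUN_CONTEXT` and translate the `AttributeError` raised outside a run into a clearer `RuntimeError`. A stripped-down sketch of that pattern; the `threading.local` base reflects how trio scopes the context to the thread running the loop:

```python
import threading


class _RunContext(threading.local):
    # A `runner` attribute exists only while trio.run() is active on
    # this thread; any other time, attribute access raises AttributeError.
    pass


GLOBAL_RUN_CONTEXT = _RunContext()


def current_statistics():
    try:
        return GLOBAL_RUN_CONTEXT.runner.current_statistics()
    except AttributeError:
        raise RuntimeError("must be called from async context") from None
```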
+ """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() + except AttributeError: + raise RuntimeError("must be called from async context") from None + + +def monitor_completion_key() -> ContextManager[tuple[int, UnboundedQueue[object]]]: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + try: + return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() + except AttributeError: + raise RuntimeError("must be called from async context") from None diff --git a/trio/_core/_generated_run.py b/src/trio/_core/_generated_run.py similarity index 88% rename from trio/_core/_generated_run.py rename to src/trio/_core/_generated_run.py index 399e1dba85..ac3e0f39d6 100644 --- a/trio/_core/_generated_run.py +++ b/src/trio/_core/_generated_run.py @@ -3,17 +3,35 @@ # ************************************************************* from __future__ import annotations -import contextvars -from collections.abc import Awaitable, Callable -from typing import Any +import sys +from typing import TYPE_CHECKING, Any -from outcome import Outcome - -from .._abc import Clock -from ._entry_queue import TrioToken from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task +if TYPE_CHECKING: + import contextvars + from collections.abc import Awaitable, Callable + + from outcome import Outcome + from typing_extensions import Unpack + + from .._abc import Clock + from ._entry_queue import TrioToken + from ._run import PosArgT + + +__all__ = [ + "current_clock", + "current_root_task", + "current_statistics", + "current_time", + "current_trio_token", + "reschedule", + "spawn_system_task", + "wait_all_tasks_blocked", +] + def current_statistics() -> RunStatistics: """Returns ``RunStatistics``, which contains run-loop-level debugging information. @@ -38,7 +56,7 @@ def current_statistics() -> RunStatistics: other attributes vary between backends. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_statistics() except AttributeError: @@ -55,7 +73,7 @@ def current_time() -> float: RuntimeError: if not inside a call to :func:`trio.run`. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_time() except AttributeError: @@ -64,7 +82,7 @@ def current_time() -> float: def current_clock() -> Clock: """Returns the current :class:`~trio.abc.Clock`.""" - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_clock() except AttributeError: @@ -77,7 +95,7 @@ def current_root_task() -> Task | None: This is the task that is the ultimate parent of all other tasks. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_root_task() except AttributeError: @@ -102,7 +120,7 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: raise) from :func:`wait_task_rescheduled`. 
""" - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send) except AttributeError: @@ -110,10 +128,10 @@ def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None: def spawn_system_task( - async_fn: Callable[..., Awaitable[object]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], name: object = None, - context: (contextvars.Context | None) = None, + context: contextvars.Context | None = None, ) -> Task: """Spawn a "system" task. @@ -166,7 +184,7 @@ def spawn_system_task( Task: the newly spawned task """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.spawn_system_task( async_fn, *args, name=name, context=context @@ -180,7 +198,7 @@ def current_trio_token() -> TrioToken: :func:`trio.run`. """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return GLOBAL_RUN_CONTEXT.runner.current_trio_token() except AttributeError: @@ -245,7 +263,7 @@ async def test_lock_fairness(): print("FAIL") """ - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion) except AttributeError: diff --git a/trio/_core/_instrumentation.py b/src/trio/_core/_instrumentation.py similarity index 100% rename from trio/_core/_instrumentation.py rename to src/trio/_core/_instrumentation.py diff --git a/trio/_core/_io_common.py b/src/trio/_core/_io_common.py similarity index 100% rename from trio/_core/_io_common.py rename to src/trio/_core/_io_common.py diff --git a/trio/_core/_io_epoll.py b/src/trio/_core/_io_epoll.py similarity index 80% rename from trio/_core/_io_epoll.py rename to src/trio/_core/_io_epoll.py index 12ae0e7028..1f4ae49f7a 100644 --- a/trio/_core/_io_epoll.py +++ b/src/trio/_core/_io_epoll.py @@ -6,7 +6,7 @@ from collections import defaultdict from typing import TYPE_CHECKING, Literal -import attr +import attrs from .. import _core from ._io_common import wake_all @@ -20,11 +20,11 @@ from .._file_io import _HasFileNo -@attr.s(slots=True, eq=False) +@attrs.define(eq=False) class EpollWaiters: - read_task: Task | None = attr.ib(default=None) - write_task: Task | None = attr.ib(default=None) - current_flags: int = attr.ib(default=0) + read_task: Task | None = None + write_task: Task | None = None + current_flags: int = 0 assert not TYPE_CHECKING or sys.platform == "linux" @@ -33,11 +33,11 @@ class EpollWaiters: EventResult: TypeAlias = "list[tuple[int, int]]" -@attr.s(slots=True, eq=False, frozen=True) +@attrs.frozen(eq=False) class _EpollStatistics: - tasks_waiting_read: int = attr.ib() - tasks_waiting_write: int = attr.ib() - backend: Literal["epoll"] = attr.ib(init=False, default="epoll") + tasks_waiting_read: int + tasks_waiting_write: int + backend: Literal["epoll"] = attrs.field(init=False, default="epoll") # Some facts about epoll @@ -198,15 +198,17 @@ class _EpollStatistics: # wanted to about how epoll works. 
-@attr.s(slots=True, eq=False, hash=False) +@attrs.define(eq=False, hash=False) class EpollIOManager: - _epoll: select.epoll = attr.ib(factory=select.epoll) + # Using lambda here because otherwise crash on import with gevent monkey patching + # See https://github.com/python-trio/trio/issues/2848 + _epoll: select.epoll = attrs.Factory(lambda: select.epoll()) # {fd: EpollWaiters} - _registered: defaultdict[int, EpollWaiters] = attr.ib( - factory=lambda: defaultdict(EpollWaiters) + _registered: defaultdict[int, EpollWaiters] = attrs.Factory( + lambda: defaultdict(EpollWaiters) ) - _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd: int | None = attr.ib(default=None) + _force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair) + _force_wakeup_fd: int | None = None def __attrs_post_init__(self) -> None: self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN) @@ -310,14 +312,70 @@ def abort(_: RaiseCancelT) -> Abort: @_public async def wait_readable(self, fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``fd`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``fd`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ await self._epoll_wait(fd, "read_task") @_public async def wait_writable(self, fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``fd``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ await self._epoll_wait(fd, "write_task") @_public def notify_closing(self, fd: int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. 
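The lambda-wrapped factory above is about binding time: `attrs.Factory(select.epoll)` resolves the attribute once, while the module is being imported, whereas `attrs.Factory(lambda: select.epoll())` defers the lookup to each instantiation, after any monkey patching has taken place. A Linux-only sketch of the difference:

```python
import select

import attrs


@attrs.define
class Eager:
    # `select.epoll` is resolved right now, as the class is created.
    poller: select.epoll = attrs.Factory(select.epoll)


@attrs.define
class Lazy:
    # Resolution is deferred until Lazy() is instantiated, so a library
    # that patches `select` after this module imports still takes effect.
    poller: select.epoll = attrs.Factory(lambda: select.epoll())
```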
+ + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + """ if not isinstance(fd, int): fd = fd.fileno() wake_all( diff --git a/trio/_core/_io_kqueue.py b/src/trio/_core/_io_kqueue.py similarity index 63% rename from trio/_core/_io_kqueue.py rename to src/trio/_core/_io_kqueue.py index 4faa382eca..3d0aed7d35 100644 --- a/trio/_core/_io_kqueue.py +++ b/src/trio/_core/_io_kqueue.py @@ -6,7 +6,7 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Iterator, Literal -import attr +import attrs import outcome from .. import _core @@ -24,22 +24,22 @@ EventResult: TypeAlias = "list[select.kevent]" -@attr.s(slots=True, eq=False, frozen=True) +@attrs.frozen(eq=False) class _KqueueStatistics: - tasks_waiting: int = attr.ib() - monitors: int = attr.ib() - backend: Literal["kqueue"] = attr.ib(init=False, default="kqueue") + tasks_waiting: int + monitors: int + backend: Literal["kqueue"] = attrs.field(init=False, default="kqueue") -@attr.s(slots=True, eq=False) +@attrs.define(eq=False) class KqueueIOManager: - _kqueue: select.kqueue = attr.ib(factory=select.kqueue) + _kqueue: select.kqueue = attrs.Factory(select.kqueue) # {(ident, filter): Task or UnboundedQueue} - _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = attr.ib( - factory=dict + _registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = ( + attrs.Factory(dict) ) - _force_wakeup: WakeupSocketpair = attr.ib(factory=WakeupSocketpair) - _force_wakeup_fd: int | None = attr.ib(default=None) + _force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair) + _force_wakeup_fd: int | None = None def __attrs_post_init__(self) -> None: force_wakeup_event = select.kevent( @@ -109,13 +109,23 @@ def process_events(self, events: EventResult) -> None: @_public def current_kqueue(self) -> select.kqueue: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__. + """ return self._kqueue @contextmanager @_public def monitor_kevent( - self, ident: int, filter: int + self, + ident: int, + filter: int, ) -> Iterator[_core.UnboundedQueue[select.kevent]]: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__. + """ key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( @@ -130,8 +140,15 @@ def monitor_kevent( @_public async def wait_kevent( - self, ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] + self, + ident: int, + filter: int, + abort_func: Callable[[RaiseCancelT], Abort], ) -> Abort: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__. 
+ """ key = (ident, filter) if key in self._registered: raise _core.BusyResourceError( @@ -148,7 +165,11 @@ def abort(raise_cancel: RaiseCancelT) -> Abort: # wait_task_rescheduled does not have its return type typed return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return] - async def _wait_common(self, fd: int | _HasFileNo, filter: int) -> None: + async def _wait_common( + self, + fd: int | _HasFileNo, + filter: int, + ) -> None: if not isinstance(fd, int): fd = fd.fileno() flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT @@ -181,26 +202,82 @@ def abort(_: RaiseCancelT) -> Abort: @_public async def wait_readable(self, fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``fd`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``fd`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ await self._wait_common(fd, select.KQ_FILTER_READ) @_public async def wait_writable(self, fd: int | _HasFileNo) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``fd``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ await self._wait_common(fd, select.KQ_FILTER_WRITE) @_public def notify_closing(self, fd: int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. 
+ """ if not isinstance(fd, int): fd = fd.fileno() - for filter in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]: - key = (fd, filter) + for filter_ in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]: + key = (fd, filter_) receiver = self._registered.get(key) if receiver is None: continue if type(receiver) is _core.Task: - event = select.kevent(fd, filter, select.KQ_EV_DELETE) + event = select.kevent(fd, filter_, select.KQ_EV_DELETE) self._kqueue.control([event], 0) exc = _core.ClosedResourceError("another task closed this fd") _core.reschedule(receiver, outcome.Error(exc)) diff --git a/trio/_core/_io_windows.py b/src/trio/_core/_io_windows.py similarity index 88% rename from trio/_core/_io_windows.py rename to src/trio/_core/_io_windows.py index dc939873a8..99cb7c76be 100644 --- a/trio/_core/_io_windows.py +++ b/src/trio/_core/_io_windows.py @@ -15,7 +15,7 @@ cast, ) -import attr +import attrs from outcome import Value from .. import _core @@ -242,22 +242,22 @@ class CKeys(enum.IntEnum): # To avoid this, we have to coalesce all the operations on a single socket # into one, and when the set of waiters changes we have to throw away the old # operation and start a new one. -@attr.s(slots=True, eq=False) +@attrs.define(eq=False) class AFDWaiters: - read_task: _core.Task | None = attr.ib(default=None) - write_task: _core.Task | None = attr.ib(default=None) - current_op: AFDPollOp | None = attr.ib(default=None) + read_task: _core.Task | None = None + write_task: _core.Task | None = None + current_op: AFDPollOp | None = None # We also need to bundle up all the info for a single op into a standalone # object, because we need to keep all these objects alive until the operation # finishes, even if we're throwing it away. -@attr.s(slots=True, eq=False, frozen=True) +@attrs.frozen(eq=False) class AFDPollOp: - lpOverlapped: CData = attr.ib() - poll_info: Any = attr.ib() - waiters: AFDWaiters = attr.ib() - afd_group: AFDGroup = attr.ib() + lpOverlapped: CData + poll_info: Any + waiters: AFDWaiters + afd_group: AFDGroup # The Windows kernel has a weird issue when using AFD handles. 
If you have N @@ -271,22 +271,22 @@ class AFDPollOp: MAX_AFD_GROUP_SIZE = 500 # at 1000, the cubic scaling is just starting to bite -@attr.s(slots=True, eq=False) +@attrs.define(eq=False) class AFDGroup: - size: int = attr.ib() - handle: Handle = attr.ib() + size: int + handle: Handle assert not TYPE_CHECKING or sys.platform == "win32" -@attr.s(slots=True, eq=False, frozen=True) +@attrs.frozen(eq=False) class _WindowsStatistics: - tasks_waiting_read: int = attr.ib() - tasks_waiting_write: int = attr.ib() - tasks_waiting_overlapped: int = attr.ib() - completion_key_monitors: int = attr.ib() - backend: Literal["windows"] = attr.ib(init=False, default="windows") + tasks_waiting_read: int + tasks_waiting_write: int + tasks_waiting_overlapped: int + completion_key_monitors: int + backend: Literal["windows"] = attrs.field(init=False, default="windows") # Maximum number of events to dequeue from the completion port on each pass @@ -405,10 +405,10 @@ def _afd_helper_handle() -> Handle: return handle -@attr.s(frozen=True) +@attrs.frozen(slots=False) class CompletionKeyEventInfo: - lpOverlapped: CData = attr.ib() - dwNumberOfBytesTransferred: int = attr.ib() + lpOverlapped: CData + dwNumberOfBytesTransferred: int class WindowsIOManager: @@ -728,14 +728,70 @@ def abort_fn(_: RaiseCancelT) -> Abort: @_public async def wait_readable(self, sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is readable. + + On Unix systems, ``sock`` must either be an integer file descriptor, + or else an object with a ``.fileno()`` method which returns an + integer file descriptor. Any kind of file descriptor can be passed, + though the exact semantics will depend on your kernel. For example, + this probably won't do anything useful for on-disk files. + + On Windows systems, ``sock`` must either be an integer ``SOCKET`` + handle, or else an object with a ``.fileno()`` method which returns + an integer ``SOCKET`` handle. File descriptors aren't supported, + and neither are handles that refer to anything besides a + ``SOCKET``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become readable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ await self._afd_poll(sock, "read_task") @_public async def wait_writable(self, sock: _HasFileNo | int) -> None: + """Block until the kernel reports that the given object is writable. + + See `wait_readable` for the definition of ``sock``. + + :raises trio.BusyResourceError: + if another task is already waiting for the given socket to + become writable. + :raises trio.ClosedResourceError: + if another task calls :func:`notify_closing` while this + function is still working. + """ await self._afd_poll(sock, "write_task") @_public def notify_closing(self, handle: Handle | int | _HasFileNo) -> None: + """Notify waiters of the given object that it will be closed. + + Call this before closing a file descriptor (on Unix) or socket (on + Windows). This will cause any `wait_readable` or `wait_writable` + calls on the given object to immediately wake up and raise + `~trio.ClosedResourceError`. + + This doesn't actually close the object – you still have to do that + yourself afterwards. Also, you want to be careful to make sure no + new tasks start waiting on the object in between when you call this + and when it's actually closed. So to close something properly, you + usually want to do these steps in order: + + 1. 
Explicitly mark the object as closed, so that any new attempts + to use it will abort before they start. + 2. Call `notify_closing` to wake up any already-existing users. + 3. Actually close the object. + + It's also possible to do them in a different order if that's more + convenient, *but only if* you make sure not to have any checkpoints in + between the steps. This way they all happen in a single atomic + step, so other tasks won't be able to tell what order they happened + in anyway. + """ handle = _get_base_socket(handle) waiters = self._afd_waiters.get(handle) if waiters is not None: @@ -748,12 +804,22 @@ def notify_closing(self, handle: Handle | int | _HasFileNo) -> None: @_public def register_with_iocp(self, handle: int | CData) -> None: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ self._register_with_iocp(handle, CKeys.WAIT_OVERLAPPED) @_public async def wait_overlapped( self, handle_: int | CData, lpOverlapped: CData | int ) -> object: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ handle = _handle(handle_) if isinstance(lpOverlapped, int): lpOverlapped = ffi.cast("LPOVERLAPPED", lpOverlapped) @@ -845,6 +911,11 @@ async def _perform_overlapped( async def write_overlapped( self, handle: int | CData, data: Buffer, file_offset: int = 0 ) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ with ffi.from_buffer(data) as cbuf: def submit_write(lpOverlapped: _Overlapped) -> None: @@ -870,6 +941,11 @@ def submit_write(lpOverlapped: _Overlapped) -> None: async def readinto_overlapped( self, handle: int | CData, buffer: Buffer, file_offset: int = 0 ) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ with ffi.from_buffer(buffer, require_writable=True) as cbuf: def submit_read(lpOverlapped: _Overlapped) -> None: @@ -895,12 +971,22 @@ def submit_read(lpOverlapped: _Overlapped) -> None: @_public def current_iocp(self) -> int: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. + """ assert self._iocp is not None return int(ffi.cast("uintptr_t", self._iocp)) @contextmanager @_public def monitor_completion_key(self) -> Iterator[tuple[int, UnboundedQueue[object]]]: + """TODO: these are implemented, but are currently more of a sketch than + anything real. See `#26 + `__ and `#52 + `__. 
+ """ key = next(self._completion_key_counter) queue = _core.UnboundedQueue[object]() self._completion_key_queues[key] = queue diff --git a/trio/_core/_ki.py b/src/trio/_core/_ki.py similarity index 94% rename from trio/_core/_ki.py rename to src/trio/_core/_ki.py index 0ea34619b5..a8431f89db 100644 --- a/trio/_core/_ki.py +++ b/src/trio/_core/_ki.py @@ -3,12 +3,10 @@ import inspect import signal import sys -import types -from collections.abc import Callable from functools import wraps from typing import TYPE_CHECKING, Final, Protocol, TypeVar -import attr +import attrs from .._util import is_main_thread @@ -16,6 +14,9 @@ RetT = TypeVar("RetT") if TYPE_CHECKING: + import types + from collections.abc import Callable + from typing_extensions import ParamSpec, TypeGuard ArgsT = ParamSpec("ArgsT") @@ -131,6 +132,7 @@ def _ki_protection_decorator( ) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: # The "ignore[return-value]" below is because the inspect functions cast away the # original return type of fn, making it just CoroutineType[Any, Any, Any] etc. + # ignore[misc] is because @wraps() is passed a callable with Any in the return type. def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: # In some version of Python, isgeneratorfunction returns true for # coroutine functions, so we have to check for coroutine functions @@ -138,9 +140,10 @@ def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: if inspect.iscoroutinefunction(fn): @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] # See the comment for regular generators below coro = fn(*args, **kwargs) + assert coro.cr_frame is not None, "Coroutine frame should exist" coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return coro # type: ignore[return-value] @@ -148,7 +151,7 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: elif inspect.isgeneratorfunction(fn): @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] # It's important that we inject this directly into the # generator's locals, as opposed to setting it here and then # doing 'yield from'. 
The reason is, if a generator is @@ -164,8 +167,8 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: return wrapper elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn): - @wraps(fn) - def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: + @wraps(fn) # type: ignore[arg-type] + def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc] # See the comment for regular generators above agen = fn(*args, **kwargs) agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled @@ -176,7 +179,7 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: @wraps(fn) def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled return fn(*args, **kwargs) return wrapper @@ -200,11 +203,9 @@ def __call__(self, f: CallableT, /) -> CallableT: disable_ki_protection.__name__ = "disable_ki_protection" -@attr.s +@attrs.define(slots=False) class KIManager: - handler: Callable[[int, types.FrameType | None], None] | None = attr.ib( - default=None - ) + handler: Callable[[int, types.FrameType | None], None] | None = None def install( self, diff --git a/trio/_core/_local.py b/src/trio/_core/_local.py similarity index 88% rename from trio/_core/_local.py rename to src/trio/_core/_local.py index 7a367eceef..dd20776c54 100644 --- a/trio/_core/_local.py +++ b/src/trio/_core/_local.py @@ -3,7 +3,7 @@ from typing import Generic, TypeVar, cast # Runvar implementations -import attr +import attrs from .._util import NoPublicConstructor, final from . import _run @@ -12,16 +12,15 @@ @final -class _NoValue: - ... +class _NoValue: ... @final -@attr.s(eq=False, hash=False, slots=True) +@attrs.define(eq=False, hash=False) class RunVarToken(Generic[T], metaclass=NoPublicConstructor): - _var: RunVar[T] = attr.ib() - previous_value: T | type[_NoValue] = attr.ib(default=_NoValue) - redeemed: bool = attr.ib(default=False, init=False) + _var: RunVar[T] + previous_value: T | type[_NoValue] = _NoValue + redeemed: bool = attrs.field(default=False, init=False) @classmethod def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: @@ -29,7 +28,7 @@ def _empty(cls, var: RunVar[T]) -> RunVarToken[T]: @final -@attr.s(eq=False, hash=False, slots=True, repr=False) +@attrs.define(eq=False, hash=False, repr=False) class RunVar(Generic[T]): """The run-local variant of a context variable. 
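The `_ki.py` comment above stresses writing the flag into the generator's or coroutine's own frame, rather than setting it in the wrapper and delegating. A sketch of that injection with a hypothetical flag; on CPython, an extra key written into `f_locals` stays visible inside the frame for its whole life:

```python
import functools
import sys

FLAG = "_demo_protected"  # hypothetical stand-in for the real marker key


def protect(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        coro = fn(*args, **kwargs)
        # Write into the coroutine's frame before it ever runs, so stack
        # inspection sees the flag at any point inside `fn`.
        coro.cr_frame.f_locals[FLAG] = True
        return coro

    return wrapper


@protect
async def demo() -> bool:
    return FLAG in sys._getframe().f_locals


coro = demo()
try:
    coro.send(None)
except StopIteration as exc:
    print(exc.value)  # True
```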
@@ -39,8 +38,8 @@ class RunVar(Generic[T]): """ - _name: str = attr.ib() - _default: T | type[_NoValue] = attr.ib(default=_NoValue) + _name: str + _default: T | type[_NoValue] = _NoValue def get(self, default: T | type[_NoValue] = _NoValue) -> T: """Gets the value of this :class:`RunVar` for the current run call.""" diff --git a/trio/_core/_mock_clock.py b/src/trio/_core/_mock_clock.py similarity index 97% rename from trio/_core/_mock_clock.py rename to src/trio/_core/_mock_clock.py index deb239c417..70c4e58a2d 100644 --- a/trio/_core/_mock_clock.py +++ b/src/trio/_core/_mock_clock.py @@ -79,9 +79,7 @@ def __init__(self, rate: float = 0.0, autojump_threshold: float = inf): self.autojump_threshold = autojump_threshold def __repr__(self) -> str: - return "<MockClock, time={!r}, rate={!r} @ {:#x}>".format( - self.current_time(), self._rate, id(self) - ) + return f"<MockClock, time={self.current_time()!r}, rate={self._rate!r} @ {id(self):#x}>" @property def rate(self) -> float: diff --git a/trio/_core/_parking_lot.py b/src/trio/_core/_parking_lot.py similarity index 97% rename from trio/_core/_parking_lot.py rename to src/trio/_core/_parking_lot.py index d9579a613d..340c62508a 100644 --- a/trio/_core/_parking_lot.py +++ b/src/trio/_core/_parking_lot.py @@ -73,19 +73,20 @@ import math from collections import OrderedDict -from collections.abc import Iterator from typing import TYPE_CHECKING -import attr +import attrs from .. import _core from .._util import final if TYPE_CHECKING: + from collections.abc import Iterator + from ._run import Task -@attr.s(frozen=True, slots=True) +@attrs.frozen class ParkingLotStatistics: """An object containing debugging information for a ParkingLot. @@ -96,11 +97,11 @@ class ParkingLotStatistics: """ - tasks_waiting: int = attr.ib() + tasks_waiting: int @final -@attr.s(eq=False, hash=False, slots=True) +@attrs.define(eq=False, hash=False) class ParkingLot: """A fair wait queue with cancellation and requeueing. @@ -116,7 +117,7 @@ class ParkingLot: # {task: None}, we just want a deque where we can quickly delete random # items - _parked: OrderedDict[Task, None] = attr.ib(factory=OrderedDict, init=False) + _parked: OrderedDict[Task, None] = attrs.field(factory=OrderedDict, init=False) def __len__(self) -> int: """Returns the number of parked tasks.""" diff --git a/trio/_core/_run.py b/src/trio/_core/_run.py similarity index 89% rename from trio/_core/_run.py rename to src/trio/_core/_run.py index e07d56d66e..5453c3602e 100644 --- a/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -10,45 +10,36 @@ import threading import warnings from collections import deque -from collections.abc import ( - Awaitable, - Callable, - Coroutine, - Generator, - Iterator, - Sequence, -) from contextlib import AbstractAsyncContextManager, contextmanager, suppress from contextvars import copy_context from heapq import heapify, heappop, heappush from math import inf from time import perf_counter -from types import TracebackType from typing import ( TYPE_CHECKING, Any, Final, NoReturn, Protocol, - TypeVar, cast, overload, ) -import attr +import attrs from outcome import Error, Outcome, Value, capture from sniffio import thread_local as sniffio_library from sortedcontainers import SortedDict from ..
import _core from .._abc import Clock, Instrument +from .._deprecate import warn_deprecated from .._util import NoPublicConstructor, coroutine_or_error, final from ._asyncgens import AsyncGenerators +from ._concat_tb import concat_tb from ._entry_queue import EntryQueue, TrioToken from ._exceptions import Cancelled, RunFinishedError, TrioInternalError from ._instrumentation import Instruments from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED, KIManager, enable_ki_protection -from ._multierror import MultiError, concat_tb from ._thread_cache import start_thread_soon from ._traps import ( Abort, @@ -62,23 +53,45 @@ if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup -from types import FrameType if TYPE_CHECKING: import contextvars + import types + from collections.abc import ( + Awaitable, + Callable, + Coroutine, + Generator, + Iterator, + Sequence, + ) + from types import TracebackType + + # for some strange reason Sphinx works with outcome.Outcome, but not Outcome, in + # start_guest_run. Same with types.FrameType in iter_await_frames + import outcome + from typing_extensions import Self, TypeVar, TypeVarTuple, Unpack + + PosArgT = TypeVarTuple("PosArgT") + StatusT = TypeVar("StatusT", default=None) + StatusT_contra = TypeVar("StatusT_contra", contravariant=True, default=None) +else: + from typing import TypeVar + + StatusT = TypeVar("StatusT") + StatusT_contra = TypeVar("StatusT_contra", contravariant=True) + +FnT = TypeVar("FnT", bound="Callable[..., Any]") +RetT = TypeVar("RetT") - from typing_extensions import Self DEADLINE_HEAP_MIN_PRUNE_THRESHOLD: Final = 1000 # Passed as a sentinel _NO_SEND: Final[Outcome[Any]] = cast("Outcome[Any]", object()) -FnT = TypeVar("FnT", bound="Callable[..., Any]") -StatusT = TypeVar("StatusT") -StatusT_co = TypeVar("StatusT_co", covariant=True) -StatusT_contra = TypeVar("StatusT_contra", contravariant=True) -RetT = TypeVar("RetT") +# Used to track if an exceptiongroup can be collapsed +NONSTRICT_EXCEPTIONGROUP_NOTE = 'This is a "loose" ExceptionGroup, and may be collapsed by Trio if it only contains one exception - typically after `Cancelled` has been stripped from it. Note this has consequences for exception handling, and strict_exception_groups=True is recommended.' @final @@ -93,14 +106,23 @@ def _public(fn: FnT) -> FnT: # When running under Hypothesis, we want examples to be reproducible and -# shrinkable. pytest-trio's Hypothesis integration monkeypatches this -# variable to True, and registers the Random instance _r for Hypothesis -# to manage for each test case, which together should make Trio's task +# shrinkable. We therefore register `_hypothesis_plugin_setup()` as a +# plugin, so that importing *Hypothesis* will make Trio's task # scheduling loop deterministic. We have a test for that, of course. +# Before Hypothesis supported entry-point plugins this integration was +# handled by pytest-trio, but we want it to work in e.g. unittest too. 
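# --- Illustrative sketch (not part of this patch) of what the entry-point
# plugin integration described in the comment above buys us: Hypothesis
# seeds the shared Random instance per test case, so the randomized batch
# order in Trio's scheduler replays deterministically. The test below is
# hypothetical; it uses only public trio APIs.
import trio

async def record_scheduling_order() -> list[str]:
    order: list[str] = []

    async def child(name: str) -> None:
        await trio.lowlevel.checkpoint()
        order.append(name)

    async with trio.open_nursery() as nursery:
        for name in ("a", "b", "c"):
            nursery.start_soon(child, name)
    return order

# Without the plugin, repeated trio.run(record_scheduling_order) calls may
# interleave the children differently from run to run; under Hypothesis the
# shuffle is seeded, so failures are reproducible and shrinkable.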
_ALLOW_DETERMINISTIC_SCHEDULING: Final = False _r = random.Random() +def _hypothesis_plugin_setup() -> None: + from hypothesis import register_random + + global _ALLOW_DETERMINISTIC_SCHEDULING + _ALLOW_DETERMINISTIC_SCHEDULING = True # type: ignore + register_random(_r) + + def _count_context_run_tb_frames() -> int: """Count implementation dependent traceback frames from Context.run() @@ -145,12 +167,12 @@ def function_with_unique_name_xyzzy() -> NoReturn: CONTEXT_RUN_TB_FRAMES: Final = _count_context_run_tb_frames() -@attr.s(frozen=True, slots=True) +@attrs.frozen class SystemClock(Clock): # Add a large random offset to our clock to ensure that if people # accidentally call time.perf_counter() directly or start comparing clocks # between different runs, then they'll notice the bug quickly: - offset: float = attr.ib(factory=lambda: _r.uniform(10000, 200000)) + offset: float = attrs.Factory(lambda: _r.uniform(10000, 200000)) def start_clock(self) -> None: pass @@ -191,7 +213,11 @@ def collapse_exception_group( modified = True exceptions[i] = new_exc - if len(exceptions) == 1 and isinstance(excgroup, MultiError) and excgroup.collapse: + if ( + len(exceptions) == 1 + and isinstance(excgroup, BaseExceptionGroup) + and NONSTRICT_EXCEPTIONGROUP_NOTE in getattr(excgroup, "__notes__", ()) + ): exceptions[0].__traceback__ = concat_tb( excgroup.__traceback__, exceptions[0].__traceback__ ) @@ -202,7 +228,7 @@ def collapse_exception_group( return excgroup -@attr.s(eq=False, slots=True) +@attrs.define(eq=False) class Deadlines: """A container of deadlined cancel scopes. @@ -212,9 +238,9 @@ class Deadlines: """ # Heap of (deadline, id(CancelScope), CancelScope) - _heap: list[tuple[float, int, CancelScope]] = attr.ib(factory=list) + _heap: list[tuple[float, int, CancelScope]] = attrs.Factory(list) # Count of active deadlines (those that haven't been changed) - _active: int = attr.ib(default=0) + _active: int = 0 def add(self, deadline: float, cancel_scope: CancelScope) -> None: heappush(self._heap, (deadline, id(cancel_scope), cancel_scope)) @@ -271,7 +297,7 @@ def expire(self, now: float) -> bool: return did_something -@attr.s(eq=False, slots=True) +@attrs.define(eq=False) class CancelStatus: """Tracks the cancellation status for a contiguous extent of code that will become cancelled, or not, as a unit. @@ -304,7 +330,7 @@ class CancelStatus: # Our associated cancel scope. Can be any object with attributes # `deadline`, `shield`, and `cancel_called`, but in current usage # is always a CancelScope object. Must not be None. - _scope: CancelScope = attr.ib() + _scope: CancelScope = attrs.field(alias="scope") # True iff the tasks in self._tasks should receive cancellations # when they checkpoint. Always True when scope.cancel_called is True; @@ -314,29 +340,29 @@ class CancelStatus: # effectively cancelled due to the cancel scope two levels out # becoming cancelled, but then the cancel scope one level out # becomes shielded so we're not effectively cancelled anymore. - effectively_cancelled: bool = attr.ib(default=False) + effectively_cancelled: bool = False # The CancelStatus whose cancellations can propagate to us; we # become effectively cancelled when they do, unless scope.shield # is True. May be None (for the outermost CancelStatus in a call # to trio.run(), briefly during TaskStatus.started(), or during # recovery from mis-nesting of cancel scopes). 
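# --- A short usage sketch (illustrative, public API only) of the shielding
# rule the CancelStatus comments above describe: a task is "effectively
# cancelled" by an outer scope unless an inner scope sets shield=True.
import trio

async def main() -> None:
    with trio.move_on_after(0.1):  # outer scope: deadline fires at 0.1s
        with trio.CancelScope(shield=True):
            # The outer cancellation cannot propagate past the shield,
            # so this sleep runs to completion.
            await trio.sleep(0.5)
        # Shield dropped: the pending outer cancellation is delivered at
        # the next checkpoint and absorbed by move_on_after.
        await trio.sleep_forever()

trio.run(main)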
- _parent: CancelStatus | None = attr.ib(default=None, repr=False) + _parent: CancelStatus | None = attrs.field(default=None, repr=False, alias="parent") # All of the CancelStatuses that have this CancelStatus as their parent. - _children: set[CancelStatus] = attr.ib(factory=set, init=False, repr=False) + _children: set[CancelStatus] = attrs.field(factory=set, init=False, repr=False) # Tasks whose cancellation state is currently tied directly to # the cancellation state of this CancelStatus object. Don't modify # this directly; instead, use Task._activate_cancel_status(). # Invariant: all(task._cancel_status is self for task in self._tasks) - _tasks: set[Task] = attr.ib(factory=set, init=False, repr=False) + _tasks: set[Task] = attrs.field(factory=set, init=False, repr=False) # Set to True on still-active cancel statuses that are children # of a cancel status that's been closed. This is used to permit # recovery from mis-nested cancel scopes (well, at least enough # recovery to show a useful traceback). - abandoned_by_misnesting: bool = attr.ib(default=False, init=False, repr=False) + abandoned_by_misnesting: bool = attrs.field(default=False, init=False, repr=False) def __attrs_post_init__(self) -> None: if self._parent is not None: @@ -469,7 +495,7 @@ def effective_deadline(self) -> float: @final -@attr.s(eq=False, repr=False, slots=True) +@attrs.define(eq=False, repr=False) class CancelScope: """A *cancellation scope*: the link between a unit of cancellable work and Trio's cancellation system. @@ -509,15 +535,15 @@ class CancelScope: has been entered yet, and changes take immediate effect. """ - _cancel_status: CancelStatus | None = attr.ib(default=None, init=False) - _has_been_entered: bool = attr.ib(default=False, init=False) - _registered_deadline: float = attr.ib(default=inf, init=False) - _cancel_called: bool = attr.ib(default=False, init=False) - cancelled_caught: bool = attr.ib(default=False, init=False) + _cancel_status: CancelStatus | None = attrs.field(default=None, init=False) + _has_been_entered: bool = attrs.field(default=False, init=False) + _registered_deadline: float = attrs.field(default=inf, init=False) + _cancel_called: bool = attrs.field(default=False, init=False) + cancelled_caught: bool = attrs.field(default=False, init=False) # Constructor arguments: - _deadline: float = attr.ib(default=inf, kw_only=True) - _shield: bool = attr.ib(default=False, kw_only=True) + _deadline: float = attrs.field(default=inf, kw_only=True, alias="deadline") + _shield: bool = attrs.field(default=False, kw_only=True, alias="shield") @enable_ki_protection def __enter__(self) -> Self: @@ -571,13 +597,8 @@ def _close(self, exc: BaseException | None) -> BaseException | None: # we just need to make sure we don't let the error # pass silently. 
new_exc = RuntimeError( - "Cancel scope stack corrupted: attempted to exit {!r} " - "in {!r} that's still within its child {!r}\n{}".format( - self, - scope_task, - scope_task._cancel_status._scope, - MISNESTING_ADVICE, - ) + f"Cancel scope stack corrupted: attempted to exit {self!r} " + f"in {scope_task!r} that's still within its child {scope_task._cancel_status._scope!r}\n{MISNESTING_ADVICE}" ) new_exc.__context__ = exc exc = new_exc @@ -605,6 +626,7 @@ def _close(self, exc: BaseException | None) -> BaseException | None: self._cancel_status = None return exc + @enable_ki_protection def __exit__( self, etype: type[BaseException] | None, @@ -615,10 +637,6 @@ def __exit__( # so __exit__() must be just _close() plus this logic for adapting # the exception-filtering result to the context manager API. - # This inlines the enable_ki_protection decorator so we can fix - # f_locals *locally* below to avoid reference cycles - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - # Tracebacks show the 'raise' line below out of context, so let's give # this variable a name that makes sense out of context. remaining_error_after_cancel_scope = self._close(exc) @@ -627,7 +645,7 @@ def __exit__( elif remaining_error_after_cancel_scope is exc: return False else: - # Copied verbatim from MultiErrorCatcher. Python doesn't + # Copied verbatim from the old MultiErrorCatcher. Python doesn't # allow us to encapsulate this __context__ fixup. old_context = remaining_error_after_cancel_scope.__context__ try: @@ -638,11 +656,8 @@ def __exit__( value.__context__ = old_context # delete references from locals to avoid creating cycles # see test_cancel_scope_exit_doesnt_create_cyclic_garbage + # Note: still relevant del remaining_error_after_cancel_scope, value, _, exc - # deep magic to remove refs via f_locals - locals() - # TODO: check if PEP558 changes the need for this call - # https://github.com/python/cpython/pull/3640 def __repr__(self) -> str: if self._cancel_status is not None: @@ -816,12 +831,10 @@ class TaskStatus(Protocol[StatusT_contra]): """ @overload - def started(self: TaskStatus[None]) -> None: - ... + def started(self: TaskStatus[None]) -> None: ... @overload - def started(self, value: StatusT_contra) -> None: - ... + def started(self, value: StatusT_contra) -> None: ... def started(self, value: StatusT_contra | None = None) -> None: """Tasks call this method to indicate that they have initialized. @@ -832,23 +845,21 @@ def started(self, value: StatusT_contra | None = None) -> None: # This code needs to be read alongside the code from Nursery.start to make # sense. -@attr.s(eq=False, hash=False, repr=False) +@attrs.define(eq=False, hash=False, repr=False, slots=False) class _TaskStatus(TaskStatus[StatusT]): - _old_nursery: Nursery = attr.ib() - _new_nursery: Nursery = attr.ib() + _old_nursery: Nursery + _new_nursery: Nursery # NoStatus is a sentinel. - _value: StatusT | type[_NoStatus] = attr.ib(default=_NoStatus) + _value: StatusT | type[_NoStatus] = _NoStatus def __repr__(self) -> str: return f"<Task status object at {id(self):#x}>" @overload - def started(self: _TaskStatus[None]) -> None: - ... + def started(self: _TaskStatus[None]) -> None: ... @overload - def started(self: _TaskStatus[StatusT], value: StatusT) -> None: - ... + def started(self: _TaskStatus[StatusT], value: StatusT) -> None: ...
def started(self, value: StatusT | None = None) -> None: if self._value is not _NoStatus: @@ -903,7 +914,7 @@ def started(self, value: StatusT | None = None) -> None: self._old_nursery._check_nursery_closed() -@attr.s +@attrs.define(slots=False) class NurseryManager: """Nursery context manager. @@ -914,7 +925,7 @@ class NurseryManager: """ - strict_exception_groups: bool = attr.ib(default=False) + strict_exception_groups: bool = True @enable_ki_protection async def __aenter__(self) -> Nursery: @@ -941,7 +952,7 @@ async def __aexit__( elif combined_error_from_nursery is exc: return False else: - # Copied verbatim from MultiErrorCatcher. Python doesn't + # Copied verbatim from the old MultiErrorCatcher. Python doesn't # allow us to encapsulate this __context__ fixup. old_context = combined_error_from_nursery.__context__ try: @@ -951,7 +962,7 @@ async def __aexit__( assert value is combined_error_from_nursery value.__context__ = old_context # delete references from locals to avoid creating cycles - # see test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage del _, combined_error_from_nursery, value, new_exc # make sure these raise errors in static analysis if called @@ -978,14 +989,32 @@ def open_nursery( new `Nursery`. It does not block on entry; on exit it blocks until all child tasks - have exited. + have exited. If no child tasks are running on exit, it will insert a + schedule point (but no cancellation point) - equivalent to + :func:`trio.lowlevel.cancel_shielded_checkpoint`. This means a nursery + is never the source of a cancellation exception, it only propagates it + from sub-tasks. Args: - strict_exception_groups (bool): If true, even a single raised exception will be - wrapped in an exception group. This will eventually become the default - behavior. If not specified, uses the value passed to :func:`run`. + strict_exception_groups (bool): Unless set to False, even a single raised exception + will be wrapped in an exception group. If not specified, uses the value passed + to :func:`run`, which defaults to true. Setting it to False will be deprecated + and ultimately removed in a future version of Trio. """ + # only warn if explicitly set to falsy, not if we get it from the global context. + if strict_exception_groups is not None and not strict_exception_groups: + warn_deprecated( + "open_nursery(strict_exception_groups=False)", + version="0.25.0", + issue=2929, + instead=( + "the default value of True and rewrite exception handlers to handle ExceptionGroups. 
" + "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" + ), + use_triodeprecationwarning=True, + ) + if strict_exception_groups is None: strict_exception_groups = GLOBAL_RUN_CONTEXT.runner.strict_exception_groups @@ -1072,7 +1101,7 @@ def _child_finished(self, task: Task, outcome: Outcome[Any]) -> None: async def _nested_child_finished( self, nested_child_exc: BaseException | None ) -> BaseException | None: - # Returns MultiError instance (or any exception if the nursery is in loose mode + # Returns ExceptionGroup instance (or any exception if the nursery is in loose mode # and there is just one contained exception) if there are pending exceptions if nested_child_exc is not None: self._add_exc(nested_child_exc) @@ -1087,6 +1116,7 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: exn = capture(raise_cancel).error if not isinstance(exn, Cancelled): self._add_exc(exn) + # see test_cancel_scope_exit_doesnt_create_cyclic_garbage del exn # prevent cyclic garbage creation return Abort.FAILED @@ -1103,29 +1133,26 @@ def aborted(raise_cancel: _core.RaiseCancelT) -> Abort: popped = self._parent_task._child_nurseries.pop() assert popped is self - - # don't unnecessarily wrap an exceptiongroup in another exceptiongroup - # see https://github.com/python-trio/trio/issues/2611 - if len(self._pending_excs) == 1 and isinstance( - self._pending_excs[0], BaseExceptionGroup - ): - return self._pending_excs[0] if self._pending_excs: try: - return MultiError( - self._pending_excs, _collapse=not self._strict_exception_groups + if not self._strict_exception_groups and len(self._pending_excs) == 1: + return self._pending_excs[0] + exception = BaseExceptionGroup( + "Exceptions from Trio nursery", self._pending_excs ) + if not self._strict_exception_groups: + exception.add_note(NONSTRICT_EXCEPTIONGROUP_NOTE) + return exception finally: # avoid a garbage cycle - # (see test_nursery_cancel_doesnt_create_cyclic_garbage) + # (see test_locals_destroyed_promptly_on_cancel) del self._pending_excs return None def start_soon( self, - # TODO: TypeVarTuple - async_fn: Callable[..., Awaitable[object]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], name: object = None, ) -> None: """Creates a child task, scheduling ``await async_fn(*args)``. @@ -1174,7 +1201,7 @@ async def start( async_fn: Callable[..., Awaitable[object]], *args: object, name: object = None, - ) -> StatusT: + ) -> Any: r"""Creates and initializes a child task. Like :meth:`start_soon`, but blocks until the new task has @@ -1218,21 +1245,36 @@ async def async_fn(arg1, arg2, *, task_status=trio.TASK_STATUS_IGNORED): raise RuntimeError("Nursery is closed to new arrivals") try: self._pending_starts += 1 - async with open_nursery() as old_nursery: - task_status: _TaskStatus[StatusT] = _TaskStatus(old_nursery, self) - thunk = functools.partial(async_fn, task_status=task_status) - task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( - thunk, args, old_nursery, name - ) - task._eventual_parent_nursery = self - # Wait for either TaskStatus.started or an exception to - # cancel this nursery: + # wrap internal nursery in try-except to unroll any exceptiongroups + # to avoid wrapping pre-started() exceptions in an extra ExceptionGroup. + # See #2611. 
+ try: + # set strict_exception_groups = True to make sure we always unwrap + # *this* nursery's exceptiongroup + async with open_nursery(strict_exception_groups=True) as old_nursery: + task_status: _TaskStatus[Any] = _TaskStatus(old_nursery, self) + thunk = functools.partial(async_fn, task_status=task_status) + task = GLOBAL_RUN_CONTEXT.runner.spawn_impl( + thunk, args, old_nursery, name + ) + task._eventual_parent_nursery = self + # Wait for either TaskStatus.started or an exception to + # cancel this nursery: + except BaseExceptionGroup as exc: + if len(exc.exceptions) == 1: + raise exc.exceptions[0] from None + raise TrioInternalError( + "Internal nursery should not have multiple tasks. This can be " + 'caused by the user managing to access the "old" nursery in ' + "`task_status` and spawning tasks in it." + ) from exc + # If we get here, then the child either got reparented or exited # normally. The complicated logic is all in TaskStatus.started(). # (Any exceptions propagate directly out of the above.) if task_status._value is _NoStatus: raise RuntimeError("child exited without calling task_status.started()") - return task_status._value # type: ignore[return-value] # Mypy doesn't narrow yet. + return task_status._value finally: self._pending_starts -= 1 self._check_nursery_closed() @@ -1247,14 +1289,14 @@ def __del__(self) -> None: @final -@attr.s(eq=False, hash=False, repr=False, slots=True) +@attrs.define(eq=False, hash=False, repr=False) class Task(metaclass=NoPublicConstructor): - _parent_nursery: Nursery | None = attr.ib() - coro: Coroutine[Any, Outcome[object], Any] = attr.ib() - _runner: Runner = attr.ib() - name: str = attr.ib() - context: contextvars.Context = attr.ib() - _counter: int = attr.ib(init=False, factory=itertools.count().__next__) + _parent_nursery: Nursery | None + coro: Coroutine[Any, Outcome[object], Any] + _runner: Runner + name: str + context: contextvars.Context + _counter: int = attrs.field(init=False, factory=itertools.count().__next__) # Invariant: # - for unscheduled tasks, _next_send_fn and _next_send are both None @@ -1267,20 +1309,20 @@ class Task(metaclass=NoPublicConstructor): # tracebacks with extraneous frames. # - for scheduled tasks, custom_sleep_data is None # Tasks start out unscheduled. - _next_send_fn: Callable[[Any], object] = attr.ib(default=None) - _next_send: Outcome[Any] | None | BaseException = attr.ib(default=None) - _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = attr.ib(default=None) - custom_sleep_data: Any = attr.ib(default=None) + _next_send_fn: Callable[[Any], object] | None = None + _next_send: Outcome[Any] | None | BaseException = None + _abort_func: Callable[[_core.RaiseCancelT], Abort] | None = None + custom_sleep_data: Any = None # For introspection and nursery.start() - _child_nurseries: list[Nursery] = attr.ib(factory=list) - _eventual_parent_nursery: Nursery | None = attr.ib(default=None) + _child_nurseries: list[Nursery] = attrs.Factory(list) + _eventual_parent_nursery: Nursery | None = None # these are counts of how many cancel/schedule points this task has # executed, for assert{_no,}_checkpoints # XX maybe these should be exposed as part of a statistics() method? 
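# --- Usage sketch for the start()/started() handshake implemented in the
# hunks above (illustrative; `serve` and the port value are hypothetical,
# the rest is public trio API):
import trio

async def serve(port: int, *, task_status=trio.TASK_STATUS_IGNORED) -> None:
    # ... set up listeners here, before reporting readiness ...
    task_status.started(port)   # reparents the task and unblocks start()
    await trio.sleep_forever()  # keep running under the new nursery

async def main() -> None:
    async with trio.open_nursery() as nursery:
        value = await nursery.start(serve, 8000)  # blocks until started()
        assert value == 8000
        nursery.cancel_scope.cancel()

trio.run(main)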
- _cancel_points: int = attr.ib(default=0) - _schedule_points: int = attr.ib(default=0) + _cancel_points: int = 0 + _schedule_points: int = 0 def __repr__(self) -> str: return f"<Task {self.name!r} at {id(self):#x}>" @@ -1317,7 +1359,7 @@ def child_nurseries(self) -> list[Nursery]: """ return list(self._child_nurseries) - def iter_await_frames(self) -> Iterator[tuple[FrameType, int]]: + def iter_await_frames(self) -> Iterator[tuple[types.FrameType, int]]: """Iterates recursively over the coroutine-like objects this task is waiting on, yielding the frame and line number at each frame. @@ -1372,7 +1414,7 @@ def print_stack_for_task(task): # The CancelStatus object that is currently active for this task. # Don't change this directly; instead, use _activate_cancel_status(). # This can be None, but only in the init task. - _cancel_status: CancelStatus = attr.ib(default=None, repr=False) + _cancel_status: CancelStatus = attrs.field(default=None, repr=False) def _activate_cancel_status(self, cancel_status: CancelStatus | None) -> None: if self._cancel_status is not None: @@ -1438,7 +1480,7 @@ class RunContext(threading.local): GLOBAL_RUN_CONTEXT: Final = RunContext() -@attr.frozen +@attrs.frozen class RunStatistics: """An object containing run-loop-level debugging information. @@ -1490,14 +1532,14 @@ class RunStatistics: # worker thread. -@attr.s(eq=False, hash=False, slots=True) +@attrs.define(eq=False, hash=False) class GuestState: - runner: Runner = attr.ib() - run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] = attr.ib() - run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] = attr.ib() - done_callback: Callable[[Outcome[Any]], object] = attr.ib() - unrolled_run_gen: Generator[float, EventResult, None] = attr.ib() - unrolled_run_next_send: Outcome[Any] = attr.ib(factory=lambda: Value(None)) + runner: Runner + run_sync_soon_threadsafe: Callable[[Callable[[], object]], object] + run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] + done_callback: Callable[[Outcome[Any]], object] + unrolled_run_gen: Generator[float, EventResult, None] + unrolled_run_next_send: Outcome[Any] = attrs.Factory(lambda: Value(None)) def guest_tick(self) -> None: prev_library, sniffio_library.name = sniffio_library.name, "trio" @@ -1540,38 +1582,38 @@ def in_main_thread() -> None: start_thread_soon(get_events, deliver) -@attr.s(eq=False, hash=False, slots=True) +@attrs.define(eq=False, hash=False) class Runner: - clock: Clock = attr.ib() - instruments: Instruments = attr.ib() - io_manager: TheIOManager = attr.ib() - ki_manager: KIManager = attr.ib() - strict_exception_groups: bool = attr.ib() + clock: Clock + instruments: Instruments + io_manager: TheIOManager + ki_manager: KIManager + strict_exception_groups: bool # Run-local values, see _local.py - _locals: dict[_core.RunVar[Any], Any] = attr.ib(factory=dict) + _locals: dict[_core.RunVar[Any], Any] = attrs.Factory(dict) - runq: deque[Task] = attr.ib(factory=deque) - tasks: set[Task] = attr.ib(factory=set) + runq: deque[Task] = attrs.Factory(deque) + tasks: set[Task] = attrs.Factory(set) - deadlines: Deadlines = attr.ib(factory=Deadlines) + deadlines: Deadlines = attrs.Factory(Deadlines) - init_task: Task | None = attr.ib(default=None) - system_nursery: Nursery | None = attr.ib(default=None) - system_context: contextvars.Context = attr.ib(kw_only=True) - main_task: Task | None = attr.ib(default=None) - main_task_outcome: Outcome[Any] | None = attr.ib(default=None) + init_task: Task | None = None + system_nursery: Nursery | None = None + system_context:
contextvars.Context = attrs.field(kw_only=True) + main_task: Task | None = None + main_task_outcome: Outcome[Any] | None = None - entry_queue: EntryQueue = attr.ib(factory=EntryQueue) - trio_token: TrioToken | None = attr.ib(default=None) - asyncgens: AsyncGenerators = attr.ib(factory=AsyncGenerators) + entry_queue: EntryQueue = attrs.Factory(EntryQueue) + trio_token: TrioToken | None = None + asyncgens: AsyncGenerators = attrs.Factory(AsyncGenerators) # If everything goes idle for this long, we call clock._autojump() - clock_autojump_threshold: float = attr.ib(default=inf) + clock_autojump_threshold: float = inf # Guest mode stuff - is_guest: bool = attr.ib(default=False) - guest_tick_scheduled: bool = attr.ib(default=False) + is_guest: bool = False + guest_tick_scheduled: bool = False def force_guest_tick_asap(self) -> None: if self.guest_tick_scheduled: @@ -1690,9 +1732,8 @@ def reschedule( # type: ignore[misc] def spawn_impl( self, - # TODO: TypeVarTuple - async_fn: Callable[..., Awaitable[object]], - args: tuple[object, ...], + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], nursery: Nursery | None, name: object, *, @@ -1721,7 +1762,7 @@ def spawn_impl( # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. ###### - # TODO: resolve the type: ignore when implementing TypeVarTuple + # TypeVarTuple passed into ParamSpec function confuses Mypy. coro = context.run(coroutine_or_error, async_fn, *args) # type: ignore[arg-type] if name is None: @@ -1734,12 +1775,14 @@ def spawn_impl( except AttributeError: name = repr(name) - if not hasattr(coro, "cr_frame"): + # very old Cython versions (<0.29.24) have the attribute, but with a value of None + if getattr(coro, "cr_frame", None) is None: # This async function is implemented in C or Cython async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: return await orig_coro coro = python_wrapper(coro) + assert coro.cr_frame is not None, "Coroutine frame should exist" coro.cr_frame.f_locals.setdefault(LOCALS_KEY_KI_PROTECTION_ENABLED, system_task) ###### @@ -1758,8 +1801,7 @@ async def python_wrapper(orig_coro: Awaitable[RetT]) -> RetT: self.instruments.call("task_spawned", task) # Special case: normally next_send should be an Outcome, but for the # very first send we have to send a literal unboxed None. - # TODO: remove [unused-ignore] when Outcome is typed - self.reschedule(task, None) # type: ignore[arg-type, unused-ignore] + self.reschedule(task, None) # type: ignore[arg-type] return task def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: @@ -1809,12 +1851,11 @@ def task_exited(self, task: Task, outcome: Outcome[Any]) -> None: # System tasks and init ################ - @_public # Type-ignore due to use of Any here.
- def spawn_system_task( # type: ignore[misc] + @_public + def spawn_system_task( self, - # TODO: TypeVarTuple - async_fn: Callable[..., Awaitable[object]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + *args: Unpack[PosArgT], name: object = None, context: contextvars.Context | None = None, ) -> Task: @@ -1879,10 +1920,9 @@ def spawn_system_task( # type: ignore[misc] ) async def init( - # TODO: TypeVarTuple self, - async_fn: Callable[..., Awaitable[object]], - args: tuple[object, ...], + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], ) -> None: # run_sync_soon task runs here: async with open_nursery() as run_sync_soon_nursery: @@ -1934,7 +1974,7 @@ def current_trio_token(self) -> TrioToken: # KI handling ################ - ki_pending: bool = attr.ib(default=False) + ki_pending: bool = False # deliver_ki is broke. Maybe move all the actual logic and state into # RunToken, and we'll only have one instance per runner? But then we can't @@ -1967,7 +2007,7 @@ def _deliver_ki_cb(self) -> None: # sortedcontainers doesn't have types, and is reportedly very hard to type: # https://github.com/grantjenks/python-sortedcontainers/issues/68 - waiting_for_idle: Any = attr.ib(factory=SortedDict) + waiting_for_idle: Any = attrs.Factory(SortedDict) @_public async def wait_all_tasks_blocked(self, cushion: float = 0.0) -> None: @@ -2145,12 +2185,12 @@ def setup_runner( def run( - async_fn: Callable[..., Awaitable[RetT]], - *args: object, + async_fn: Callable[[Unpack[PosArgT]], Awaitable[RetT]], + *args: Unpack[PosArgT], clock: Clock | None = None, instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, - strict_exception_groups: bool = False, + strict_exception_groups: bool = True, ) -> RetT: """Run a Trio-flavored async function, and return the result. @@ -2207,9 +2247,10 @@ def run( main thread (this is a Python limitation), or if you use :func:`open_signal_receiver` to catch SIGINT. - strict_exception_groups (bool): If true, nurseries will always wrap even a single - raised exception in an exception group. This can be overridden on the level of - individual nurseries. This will eventually become the default behavior. + strict_exception_groups (bool): Unless set to False, nurseries will always wrap + even a single raised exception in an exception group. This can be overridden + on the level of individual nurseries. Setting it to False will be deprecated + and ultimately removed in a future version of Trio. Returns: Whatever ``async_fn`` returns. @@ -2223,6 +2264,17 @@ def run( propagates it. """ + if strict_exception_groups is not None and not strict_exception_groups: + warn_deprecated( + "trio.run(..., strict_exception_groups=False)", + version="0.25.0", + issue=2929, + instead=( + "the default value of True and rewrite exception handlers to handle ExceptionGroups. 
" + "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" + ), + use_triodeprecationwarning=True, + ) __tracebackhide__ = True @@ -2260,14 +2312,15 @@ def start_guest_run( async_fn: Callable[..., Awaitable[RetT]], *args: object, run_sync_soon_threadsafe: Callable[[Callable[[], object]], object], - done_callback: Callable[[Outcome[RetT]], object], - run_sync_soon_not_threadsafe: Callable[[Callable[[], object]], object] - | None = None, + done_callback: Callable[[outcome.Outcome[RetT]], object], + run_sync_soon_not_threadsafe: ( + Callable[[Callable[[], object]], object] | None + ) = None, host_uses_signal_set_wakeup_fd: bool = False, clock: Clock | None = None, instruments: Sequence[Instrument] = (), restrict_keyboard_interrupt_to_checkpoints: bool = False, - strict_exception_groups: bool = False, + strict_exception_groups: bool = True, ) -> None: """Start a "guest" run of Trio on top of some other "host" event loop. @@ -2328,6 +2381,18 @@ def my_done_callback(run_outcome): For the meaning of other arguments, see `trio.run`. """ + if strict_exception_groups is not None and not strict_exception_groups: + warn_deprecated( + "trio.start_guest_run(..., strict_exception_groups=False)", + version="0.25.0", + issue=2929, + instead=( + "the default value of True and rewrite exception handlers to handle ExceptionGroups. " + "See https://trio.readthedocs.io/en/stable/reference-core.html#designing-for-multiple-errors" + ), + use_triodeprecationwarning=True, + ) + runner = setup_runner( clock, instruments, @@ -2408,11 +2473,11 @@ def my_done_callback(run_outcome): # straight through. def unrolled_run( runner: Runner, - async_fn: Callable[..., object], - args: tuple[object, ...], + async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]], + args: tuple[Unpack[PosArgT]], host_uses_signal_set_wakeup_fd: bool = False, ) -> Generator[float, EventResult, None]: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True + sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True __tracebackhide__ = True try: @@ -2546,9 +2611,11 @@ def unrolled_run( try: # We used to unwrap the Outcome object here and send/throw # its contents in directly, but it turns out that .throw() - # is buggy, at least before CPython 3.9: + # is buggy on CPython (all versions at time of writing): # https://bugs.python.org/issue29587 # https://bugs.python.org/issue29590 + # https://bugs.python.org/issue40694 + # https://github.com/python/cpython/issues/108668 # So now we send in the Outcome object and unwrap it on the # other side. msg = task.context.run(next_send_fn, next_send) @@ -2605,8 +2672,7 @@ def unrolled_run( # protocol of unwrapping whatever outcome gets sent in. # Instead, we'll arrange to throw `exc` in directly, # which works for at least asyncio and curio. 
- # TODO: remove [unused-ignore] when Outcome is typed - runner.reschedule(task, exc) # type: ignore[arg-type, unused-ignore] + runner.reschedule(task, exc) # type: ignore[arg-type] task._next_send_fn = task.coro.throw # prevent long-lived reference # TODO: develop test for this deletion @@ -2616,7 +2682,7 @@ def unrolled_run( runner.instruments.call("after_task_step", task) del GLOBAL_RUN_CONTEXT.task # prevent long-lived references - # TODO: develop test for these deletions + # TODO: develop test for this deletion del task, next_send, next_send_fn except GeneratorExit: diff --git a/trio/_core/_tests/__init__.py b/src/trio/_core/_tests/__init__.py similarity index 100% rename from trio/_core/_tests/__init__.py rename to src/trio/_core/_tests/__init__.py diff --git a/trio/_core/_tests/test_asyncgen.py b/src/trio/_core/_tests/test_asyncgen.py similarity index 88% rename from trio/_core/_tests/test_asyncgen.py rename to src/trio/_core/_tests/test_asyncgen.py index 071be8f2b3..aa64a49d70 100644 --- a/trio/_core/_tests/test_asyncgen.py +++ b/src/trio/_core/_tests/test_asyncgen.py @@ -3,14 +3,16 @@ import contextlib import sys import weakref -from collections.abc import AsyncGenerator from math import inf -from typing import NoReturn +from typing import TYPE_CHECKING, NoReturn import pytest from ... import _core -from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook +from .tutil import gc_collect_harder, restore_unraisablehook + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator def test_asyncgen_basics() -> None: @@ -48,11 +50,11 @@ async def async_main() -> None: await _core.wait_all_tasks_blocked() assert collected.pop() == "abandoned" - aiter = example("exhausted 1") + aiter_ = example("exhausted 1") try: - assert await aiter.asend(None) == 42 + assert await aiter_.asend(None) == 42 finally: - await aiter.aclose() + await aiter_.aclose() assert collected.pop() == "exhausted 1" # Also fine if you exhaust it at point of use @@ -63,12 +65,12 @@ async def async_main() -> None: gc_collect_harder() # No problems saving the geniter when using either of these patterns - aiter = example("exhausted 3") + aiter_ = example("exhausted 3") try: - saved.append(aiter) - assert await aiter.asend(None) == 42 + saved.append(aiter_) + assert await aiter_.asend(None) == 42 finally: - await aiter.aclose() + await aiter_.aclose() assert collected.pop() == "exhausted 3" # Also fine if you exhaust it at point of use @@ -78,12 +80,9 @@ async def async_main() -> None: assert collected.pop() == "exhausted 4" # Leave one referenced-but-unexhausted and make sure it gets cleaned up - if buggy_pypy_asyncgens: - collected.append("outlived run") - else: - saved.append(example("outlived run")) - assert await saved[-1].asend(None) == 42 - assert collected == [] + saved.append(example("outlived run")) + assert await saved[-1].asend(None) == 42 + assert collected == [] _core.run(async_main) assert collected.pop() == "outlived run" @@ -116,7 +115,6 @@ async def agen() -> AsyncGenerator[int, None]: assert "during finalization of async generator" in caplog.records[0].message -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") def test_firstiter_after_closing() -> None: saved = [] record = [] @@ -134,16 +132,15 @@ async def funky_agen() -> AsyncGenerator[int, None]: await funky_agen().asend(None) async def async_main() -> None: - aiter = funky_agen() - saved.append(aiter) - assert await aiter.asend(None) == 1 - assert await aiter.asend(None) == 2 + aiter_ = funky_agen() 
+ saved.append(aiter_) + assert await aiter_.asend(None) == 1 + assert await aiter_.asend(None) == 2 _core.run(async_main) assert record == ["cleanup 2", "cleanup 1"] -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") def test_interdependent_asyncgen_cleanup_order() -> None: saved: list[AsyncGenerator[int, None]] = [] record: list[int | str] = [] @@ -231,8 +228,7 @@ async def async_main() -> None: del saved[:] _core.run(async_main) if needs_retry: # pragma: no cover - if not buggy_pypy_asyncgens: - assert record == ["cleaned up"] + assert record == ["cleaned up"] else: assert record == ["final collection", "done", "cleaned up"] break @@ -243,7 +239,7 @@ async def async_main() -> None: ) -async def step_outside_async_context(aiter: AsyncGenerator[int, None]) -> None: +async def step_outside_async_context(aiter_: AsyncGenerator[int, None]) -> None: # abort_fns run outside of task context, at least if they're # triggered by a deadline expiry rather than a direct # cancellation. Thus, an asyncgen first iterated inside one @@ -260,7 +256,7 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: del abort_fn.aiter # type: ignore[attr-defined] return _core.Abort.SUCCEEDED - abort_fn.aiter = aiter # type: ignore[attr-defined] + abort_fn.aiter = aiter_ # type: ignore[attr-defined] async with _core.open_nursery() as nursery: nursery.start_soon(_core.wait_task_rescheduled, abort_fn) @@ -268,7 +264,6 @@ def abort_fn(_: _core.RaiseCancelT) -> _core.Abort: nursery.cancel_scope.deadline = _core.current_time() -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") async def test_fallback_when_no_hook_claims_it( capsys: pytest.CaptureFixture[str], ) -> None: @@ -299,7 +294,6 @@ async def awaits_after_yield() -> AsyncGenerator[int, None]: assert "awaited something during finalization" in capsys.readouterr().err -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") def test_delegation_to_existing_hooks() -> None: record = [] diff --git a/src/trio/_core/_tests/test_exceptiongroup_gc.py b/src/trio/_core/_tests/test_exceptiongroup_gc.py new file mode 100644 index 0000000000..8957a581a5 --- /dev/null +++ b/src/trio/_core/_tests/test_exceptiongroup_gc.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import gc +import sys +from traceback import extract_tb +from typing import TYPE_CHECKING, Callable, NoReturn + +import pytest + +from .._concat_tb import concat_tb + +if TYPE_CHECKING: + from types import TracebackType + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +def raiser1() -> NoReturn: + raiser1_2() + + +def raiser1_2() -> NoReturn: + raiser1_3() + + +def raiser1_3() -> NoReturn: + raise ValueError("raiser1_string") + + +def raiser2() -> NoReturn: + raiser2_2() + + +def raiser2_2() -> NoReturn: + raise KeyError("raiser2_string") + + +def get_exc(raiser: Callable[[], NoReturn]) -> Exception: + try: + raiser() + except Exception as exc: + return exc + raise AssertionError("raiser should always raise") # pragma: no cover + + +def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None: + return get_exc(raiser).__traceback__ + + +def test_concat_tb() -> None: + tb1 = get_tb(raiser1) + tb2 = get_tb(raiser2) + + # These return a list of (filename, lineno, fn name, text) tuples + # https://docs.python.org/3/library/traceback.html#traceback.extract_tb + entries1 = extract_tb(tb1) + entries2 = extract_tb(tb2) + + tb12 = concat_tb(tb1, tb2) + assert extract_tb(tb12) == entries1 + entries2 + + tb21 = 
concat_tb(tb2, tb1) + assert extract_tb(tb21) == entries2 + entries1 + + # Check degenerate cases + assert extract_tb(concat_tb(None, tb1)) == entries1 + assert extract_tb(concat_tb(tb1, None)) == entries1 + assert concat_tb(None, None) is None + + # Make sure the original tracebacks didn't get mutated by mistake + assert extract_tb(get_tb(raiser1)) == entries1 + assert extract_tb(get_tb(raiser2)) == entries2 + + +# Unclear if this can still fail, removing the `del` from _concat_tb.copy_tb does not seem +# to trigger it (on a platform where the `del` is executed) +@pytest.mark.skipif( + sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" +) +def test_ExceptionGroup_catch_doesnt_create_cyclic_garbage() -> None: + # https://github.com/python-trio/trio/pull/2063 + gc.collect() + old_flags = gc.get_debug() + + def make_multi() -> NoReturn: + raise ExceptionGroup("", [get_exc(raiser1), get_exc(raiser2)]) + + try: + gc.set_debug(gc.DEBUG_SAVEALL) + with pytest.raises(ExceptionGroup) as excinfo: + # covers ~~MultiErrorCatcher.__exit__ and~~ _concat_tb.copy_tb + # TODO: is the above comment true anymore? as this no longer uses MultiError.catch + raise make_multi() + for exc in excinfo.value.exceptions: + assert isinstance(exc, (ValueError, KeyError)) + gc.collect() + assert not gc.garbage + finally: + gc.set_debug(old_flags) + gc.garbage.clear() diff --git a/trio/_core/_tests/test_guest_mode.py b/src/trio/_core/_tests/test_guest_mode.py similarity index 95% rename from trio/_core/_tests/test_guest_mode.py rename to src/trio/_core/_tests/test_guest_mode.py index c1a9b11815..8972ec735a 100644 --- a/trio/_core/_tests/test_guest_mode.py +++ b/src/trio/_core/_tests/test_guest_mode.py @@ -25,19 +25,19 @@ import pytest from outcome import Outcome -from pytest import MonkeyPatch, WarningsRecorder import trio import trio.testing -from trio._channel import MemorySendChannel from trio.abc import Instrument from ..._util import signal_raise -from .tutil import buggy_pypy_asyncgens, gc_collect_harder, restore_unraisablehook +from .tutil import gc_collect_harder, restore_unraisablehook if TYPE_CHECKING: from typing_extensions import TypeAlias + from trio._channel import MemorySendChannel + T = TypeVar("T") InHost: TypeAlias = Callable[[object], None] @@ -111,7 +111,7 @@ def done_callback(outcome: Outcome[T]) -> None: def test_guest_trivial() -> None: async def trio_return(in_host: InHost) -> str: - await trio.sleep(0) + await trio.lowlevel.checkpoint() return "ok" assert trivial_guest_run(trio_return) == "ok" @@ -149,7 +149,7 @@ def test_guest_is_initialized_when_start_returns() -> None: async def trio_main(in_host: InHost) -> str: record.append("main task ran") - await trio.sleep(0) + await trio.lowlevel.checkpoint() assert trio.lowlevel.current_trio_token() is trio_token return "ok" @@ -164,23 +164,22 @@ def after_start() -> None: @trio.lowlevel.spawn_system_task async def early_task() -> None: record.append("system task ran") - await trio.sleep(0) + await trio.lowlevel.checkpoint() res = trivial_guest_run(trio_main, in_host_after_start=after_start) assert res == "ok" assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"} + class BadClock: + def start_clock(self) -> NoReturn: + raise ValueError("whoops") + + def after_start_never_runs() -> None: # pragma: no cover + pytest.fail("shouldn't get here") + # Errors during initialization (which can only be TrioInternalErrors) # are raised out of start_guest_run, not out of the done_callback with 
pytest.raises(trio.TrioInternalError): - - class BadClock: - def start_clock(self) -> NoReturn: - raise ValueError("whoops") - - def after_start_never_runs() -> None: # pragma: no cover - pytest.fail("shouldn't get here") - trivial_guest_run( trio_main, clock=BadClock(), in_host_after_start=after_start_never_runs ) @@ -397,7 +396,7 @@ def do_abandoned_guest_run() -> None: async def abandoned_main(in_host: InHost) -> None: in_host(lambda: 1 / 0) while True: - await trio.sleep(0) + await trio.lowlevel.checkpoint() with pytest.raises(ZeroDivisionError): trivial_guest_run(abandoned_main) @@ -473,7 +472,7 @@ async def trio_main() -> str: # Make sure we have at least one tick where we don't need to go into # the thread - await trio.sleep(0) + await trio.lowlevel.checkpoint() from_trio.put_nowait(0) @@ -526,7 +525,7 @@ async def aio_pingpong( def test_guest_mode_internal_errors( - monkeypatch: MonkeyPatch, recwarn: WarningsRecorder + monkeypatch: pytest.MonkeyPatch, recwarn: pytest.WarningsRecorder ) -> None: with monkeypatch.context() as m: @@ -541,7 +540,7 @@ async def crash_in_run_loop(in_host: InHost) -> None: async def crash_in_io(in_host: InHost) -> None: m.setattr("trio._core._run.TheIOManager.get_events", None) - await trio.sleep(0) + await trio.lowlevel.checkpoint() with pytest.raises(trio.TrioInternalError): trivial_guest_run(crash_in_io) @@ -623,7 +622,6 @@ async def trio_main(in_host: InHost) -> None: assert end - start < DURATION / 2 -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") @restore_unraisablehook() def test_guest_mode_asyncgens() -> None: import sniffio @@ -657,9 +655,6 @@ async def trio_main() -> None: # Ensure we don't pollute the thread-level context if run under # an asyncio without contextvars support (3.6) context = contextvars.copy_context() - if TYPE_CHECKING: - aiotrio_run(trio_main, host_uses_signal_set_wakeup_fd=True) - # this type error is a bug in typeshed or mypy, as it's equivalent to the above line - context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) # type: ignore[arg-type] + context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True) assert record == {("asyncio", "asyncio"), ("trio", "trio")} diff --git a/trio/_core/_tests/test_instrumentation.py b/src/trio/_core/_tests/test_instrumentation.py similarity index 96% rename from trio/_core/_tests/test_instrumentation.py rename to src/trio/_core/_tests/test_instrumentation.py index f743f2b3d4..32335ae7fa 100644 --- a/trio/_core/_tests/test_instrumentation.py +++ b/src/trio/_core/_tests/test_instrumentation.py @@ -1,18 +1,20 @@ from __future__ import annotations -from typing import Container, Iterable, NoReturn +from typing import TYPE_CHECKING, Container, Iterable, NoReturn -import attr +import attrs import pytest from ... 
import _abc, _core -from ...lowlevel import Task from .tutil import check_sequence_matches +if TYPE_CHECKING: + from ...lowlevel import Task -@attr.s(eq=False, hash=False) + +@attrs.define(eq=False, hash=False, slots=False) class TaskRecorder(_abc.Instrument): - record: list[tuple[str, Task | None]] = attr.ib(factory=list) + record: list[tuple[str, Task | None]] = attrs.Factory(list) def before_run(self) -> None: self.record.append(("before_run", None)) @@ -253,7 +255,7 @@ def after_run(self) -> NoReturn: raise ValueError("oops") async def main() -> None: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^oops$"): _core.add_instrument(EvilInstrument()) # Make sure the instrument is fully removed from the per-method lists diff --git a/trio/_core/_tests/test_io.py b/src/trio/_core/_tests/test_io.py similarity index 98% rename from trio/_core/_tests/test_io.py rename to src/trio/_core/_tests/test_io.py index 039dfbef01..acecc9d6c6 100644 --- a/trio/_core/_tests/test_io.py +++ b/src/trio/_core/_tests/test_io.py @@ -2,7 +2,6 @@ import random import socket as stdlib_socket -from collections.abc import Generator from contextlib import suppress from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar @@ -16,6 +15,8 @@ # Cross-platform tests for IO handling if TYPE_CHECKING: + from collections.abc import Generator + from typing_extensions import ParamSpec ArgsT = ParamSpec("ArgsT") @@ -318,7 +319,10 @@ async def test_wait_on_invalid_object() -> None: fileno = s.fileno() # We just closed the socket and don't do anything else in between, so # we can be confident that the fileno hasn't be reassigned. - with pytest.raises(OSError): + with pytest.raises( + OSError, + match=r"^\[\w+ \d+] (Bad file descriptor|An operation was attempted on something that is not a socket)$", + ): await wait(fileno) diff --git a/trio/_core/_tests/test_ki.py b/src/trio/_core/_tests/test_ki.py similarity index 97% rename from trio/_core/_tests/test_ki.py rename to src/trio/_core/_tests/test_ki.py index cd98bc9bca..e4241fc762 100644 --- a/trio/_core/_tests/test_ki.py +++ b/src/trio/_core/_tests/test_ki.py @@ -9,6 +9,8 @@ import outcome import pytest +from trio.testing import RaisesGroup + try: from async_generator import async_generator, yield_ except ImportError: # pragma: no cover @@ -293,7 +295,8 @@ async def check_unprotected_kill() -> None: nursery.start_soon(sleeper, "s2", record_set) nursery.start_soon(raiser, "r1", record_set) - with pytest.raises(KeyboardInterrupt): + # raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup + with RaisesGroup(KeyboardInterrupt): _core.run(check_unprotected_kill) assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"} @@ -309,7 +312,8 @@ async def check_protected_kill() -> None: nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set) # __aexit__ blocks, and then receives the KI - with pytest.raises(KeyboardInterrupt): + # raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup + with RaisesGroup(KeyboardInterrupt): _core.run(check_protected_kill) assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"} @@ -331,6 +335,7 @@ def kill_during_shutdown() -> None: token.run_sync_soon(kill_during_shutdown) + # no nurseries involved, so the KeyboardInterrupt isn't wrapped with pytest.raises(KeyboardInterrupt): _core.run(check_kill_during_shutdown) @@ -344,6 +349,7 @@ def before_run(self) -> None: async def main_1() -> None: await _core.checkpoint() + # no nurseries involved, so the 
KeyboardInterrupt isn't wrapped with pytest.raises(KeyboardInterrupt): _core.run(main_1, instruments=[InstrumentOfDeath()]) diff --git a/trio/_core/_tests/test_local.py b/src/trio/_core/_tests/test_local.py similarity index 91% rename from trio/_core/_tests/test_local.py rename to src/trio/_core/_tests/test_local.py index 5fdf54b13c..16f763814e 100644 --- a/trio/_core/_tests/test_local.py +++ b/src/trio/_core/_tests/test_local.py @@ -57,13 +57,13 @@ async def reset_check() -> None: t2.reset(token2) assert t2.get() == "dogfish" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^token has already been used$"): t2.reset(token2) token3 = t3.set("basculin") assert t3.get() == "basculin" - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^token is not for us$"): t1.reset(token3) run(reset_check) @@ -77,8 +77,8 @@ async def task1() -> None: t1.set("plaice") assert t1.get() == "plaice" - async def task2(tok: str) -> None: - t1.reset(token) + async def task2(tok: RunVarToken[str]) -> None: + t1.reset(tok) with pytest.raises(LookupError): t1.get() diff --git a/trio/_core/_tests/test_mock_clock.py b/src/trio/_core/_tests/test_mock_clock.py similarity index 97% rename from trio/_core/_tests/test_mock_clock.py rename to src/trio/_core/_tests/test_mock_clock.py index 1a0c8b3444..6b0f1ca76b 100644 --- a/trio/_core/_tests/test_mock_clock.py +++ b/src/trio/_core/_tests/test_mock_clock.py @@ -20,14 +20,14 @@ def test_mock_clock() -> None: assert c.current_time() == 0 c.jump(1.2) assert c.current_time() == 1.2 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^time can't go backwards$"): c.jump(-1) assert c.current_time() == 1.2 assert c.deadline_to_sleep_time(1.1) == 0 assert c.deadline_to_sleep_time(1.2) == 0 assert c.deadline_to_sleep_time(1.3) > 999999 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^rate must be >= 0$"): c.rate = -1 assert c.rate == 0 diff --git a/trio/_core/_tests/test_parking_lot.py b/src/trio/_core/_tests/test_parking_lot.py similarity index 98% rename from trio/_core/_tests/test_parking_lot.py rename to src/trio/_core/_tests/test_parking_lot.py index 40c55f1f2e..353c1ba45d 100644 --- a/trio/_core/_tests/test_parking_lot.py +++ b/src/trio/_core/_tests/test_parking_lot.py @@ -78,7 +78,9 @@ async def waiter(i: int, lot: ParkingLot) -> None: ) lot.unpark_all() - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match=r"^Cannot pop a non-integer number of tasks\.$" + ): lot.unpark(count=1.5) diff --git a/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py similarity index 80% rename from trio/_core/_tests/test_run.py rename to src/trio/_core/_tests/test_run.py index a9f663e7ef..ee823cb81a 100644 --- a/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -8,29 +8,26 @@ import time import types import weakref -from collections.abc import ( - AsyncGenerator, - AsyncIterator, - Awaitable, - Callable, - Generator, -) from contextlib import ExitStack, contextmanager, suppress from math import inf -from typing import Any, NoReturn, TypeVar, cast +from typing import TYPE_CHECKING, Any, NoReturn, TypeVar, cast import outcome import pytest import sniffio from ... 
import _core -from ..._core._multierror import MultiError, NonBaseMultiError from ..._threads import to_thread_run_sync from ..._timeouts import fail_after, sleep -from ...testing import Sequencer, assert_checkpoints, wait_all_tasks_blocked +from ...testing import ( + Matcher, + RaisesGroup, + Sequencer, + assert_checkpoints, + wait_all_tasks_blocked, +) from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD from .tutil import ( - buggy_pypy_asyncgens, check_sequence_matches, create_asyncio_future_in_new_loop, gc_collect_harder, @@ -39,6 +36,15 @@ slow, ) +if TYPE_CHECKING: + from collections.abc import ( + AsyncGenerator, + AsyncIterator, + Awaitable, + Callable, + Generator, + ) + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -70,7 +76,7 @@ async def trivial(x: T) -> T: with pytest.raises(TypeError): # Missing an argument - _core.run(trivial) + _core.run(trivial) # type: ignore[arg-type] with pytest.raises(TypeError): # Not an async function @@ -87,7 +93,7 @@ def test_initial_task_error() -> None: async def main(x: object) -> NoReturn: raise ValueError(x) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match="^17$") as excinfo: _core.run(main, 17) assert excinfo.value.args == (17,) @@ -105,8 +111,8 @@ async def main() -> None: # pragma: no cover async def test_nursery_warn_use_async_with() -> None: + on = _core.open_nursery() with pytest.raises(RuntimeError) as excinfo: - on = _core.open_nursery() with on: # type: ignore pass # pragma: no cover excinfo.match( @@ -121,24 +127,21 @@ async def test_nursery_warn_use_async_with() -> None: async def test_nursery_main_block_error_basic() -> None: exc = ValueError("whoops") - with pytest.raises(ValueError) as excinfo: + with RaisesGroup(Matcher(check=lambda e: e is exc)): async with _core.open_nursery(): raise exc - assert excinfo.value is exc async def test_child_crash_basic() -> None: - exc = ValueError("uh oh") + my_exc = ValueError("uh oh") async def erroring() -> NoReturn: - raise exc + raise my_exc - try: + with RaisesGroup(Matcher(check=lambda e: e is my_exc)): # nursery.__aexit__ propagates exception from child back to parent async with _core.open_nursery() as nursery: nursery.start_soon(erroring) - except ValueError as e: - assert e is exc async def test_basic_interleave() -> None: @@ -176,16 +179,15 @@ async def main() -> None: nursery.start_soon(looper) nursery.start_soon(crasher) - with pytest.raises(ValueError) as excinfo: + with RaisesGroup(Matcher(ValueError, "^argh$")): _core.run(main) assert looper_record == ["cancelled"] - assert excinfo.value.args == ("argh",) def test_main_and_task_both_crash() -> None: - # If main crashes and there's also a task crash, then we get both in a - # MultiError + # If main crashes and there's also a task crash, then we get both in an + # ExceptionGroup async def crasher() -> NoReturn: raise ValueError @@ -194,13 +196,8 @@ async def main() -> NoReturn: nursery.start_soon(crasher) raise KeyError - with pytest.raises(MultiError) as excinfo: + with RaisesGroup(ValueError, KeyError): _core.run(main) - print(excinfo.value) - assert {type(exc) for exc in excinfo.value.exceptions} == { - ValueError, - KeyError, - } def test_two_child_crashes() -> None: @@ -212,19 +209,15 @@ async def main() -> None: nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) - with pytest.raises(MultiError) as excinfo: + with RaisesGroup(ValueError, KeyError): _core.run(main) - assert {type(exc) for exc in 
excinfo.value.exceptions} == { - ValueError, - KeyError, - } async def test_child_crash_wakes_parent() -> None: async def crasher() -> NoReturn: - raise ValueError + raise ValueError("this is a crash") - with pytest.raises(ValueError): + with RaisesGroup(Matcher(ValueError, "^this is a crash$")): async with _core.open_nursery() as nursery: nursery.start_soon(crasher) await sleep_forever() @@ -242,7 +235,7 @@ async def child1() -> None: print("child1 woke") assert x == 0 print("child1 rescheduling t2") - _core.reschedule(not_none(t2), outcome.Error(ValueError())) + _core.reschedule(not_none(t2), outcome.Error(ValueError("error message"))) print("child1 exit") async def child2() -> None: @@ -251,7 +244,7 @@ async def child2() -> None: t2 = _core.current_task() _core.reschedule(not_none(t1), outcome.Value(0)) print("child2 sleep") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^error message$"): await sleep_forever() print("child2 successful exit") @@ -266,7 +259,7 @@ async def test_current_time() -> None: t1 = _core.current_time() # Windows clock is pretty low-resolution -- appveyor tests fail unless we # sleep for a bit here. - time.sleep(time.get_clock_info("perf_counter").resolution) # noqa: ASYNC101 + time.sleep(time.get_clock_info("perf_counter").resolution) # noqa: ASYNC251 t2 = _core.current_time() assert t1 < t2 @@ -427,13 +420,22 @@ async def test_cancel_edge_cases() -> None: await sleep_forever() -async def test_cancel_scope_multierror_filtering() -> None: +async def test_cancel_scope_exceptiongroup_filtering() -> None: async def crasher() -> NoReturn: raise KeyError - try: + # This is outside the outer scope, so all the Cancelled + # exceptions should have been absorbed, leaving just a regular + # KeyError from crasher(), wrapped in an ExceptionGroup + with RaisesGroup(KeyError): with _core.CancelScope() as outer: - try: + # Since the outer scope became cancelled before the + # nursery block exited, all cancellations inside the + # nursery block continue propagating to reach the + # outer scope. + with RaisesGroup( + _core.Cancelled, _core.Cancelled, _core.Cancelled, KeyError + ) as excinfo: async with _core.open_nursery() as nursery: # Two children that get cancelled by the nursery scope nursery.start_soon(sleep_forever) # t1 @@ -447,27 +449,9 @@ async def crasher() -> NoReturn: # And one that raises a different error nursery.start_soon(crasher) # t4 # and then our __aexit__ also receives an outer Cancelled - except MultiError as multi_exc: - # Since the outer scope became cancelled before the - # nursery block exited, all cancellations inside the - # nursery block continue propagating to reach the - # outer scope. 
- assert len(multi_exc.exceptions) == 4 - summary: dict[type, int] = {} - for exc in multi_exc.exceptions: - summary.setdefault(type(exc), 0) - summary[type(exc)] += 1 - assert summary == {_core.Cancelled: 3, KeyError: 1} - raise - except AssertionError: # pragma: no cover - raise - except BaseException as exc: - # This is outside the outer scope, so all the Cancelled - # exceptions should have been absorbed, leaving just a regular - # KeyError from crasher() - assert type(exc) is KeyError - else: # pragma: no cover - raise AssertionError() + # reraise the exception caught by RaisesGroup for the + # CancelScope to handle + raise excinfo.value async def test_precancelled_task() -> None: @@ -664,7 +648,8 @@ async def test_unshield_while_cancel_propagating() -> None: await _core.checkpoint() finally: inner.shield = True - assert outer.cancelled_caught and not inner.cancelled_caught + assert outer.cancelled_caught + assert not inner.cancelled_caught async def test_cancel_unbound() -> None: @@ -783,17 +768,26 @@ async def task2() -> None: await wait_all_tasks_blocked() nursery.cancel_scope.__exit__(None, None, None) finally: - with pytest.raises(RuntimeError) as exc_info: + with pytest.raises( + RuntimeError, match="which had already been exited" + ) as exc_info: await nursery_mgr.__aexit__(*sys.exc_info()) - assert "which had already been exited" in str(exc_info.value) - assert type(exc_info.value.__context__) is NonBaseMultiError - assert len(exc_info.value.__context__.exceptions) == 3 - cancelled_in_context = False - for exc in exc_info.value.__context__.exceptions: - assert isinstance(exc, RuntimeError) - assert "closed before the task exited" in str(exc) - cancelled_in_context |= isinstance(exc.__context__, _core.Cancelled) - assert cancelled_in_context # for the sleep_forever + + def no_context(exc: RuntimeError) -> bool: + return exc.__context__ is None + + msg = "closed before the task exited" + group = RaisesGroup( + Matcher(RuntimeError, match=msg, check=no_context), + Matcher(RuntimeError, match=msg, check=no_context), + # sleep_forever + Matcher( + RuntimeError, + match=msg, + check=lambda x: isinstance(x.__context__, _core.Cancelled), + ), + ) + assert group.matches(exc_info.value.__context__) # Trying to exit a cancel scope from an unrelated task raises an error # without affecting any state @@ -928,7 +922,7 @@ async def main() -> None: _core.run(main) -def test_system_task_crash_MultiError() -> None: +def test_system_task_crash_ExceptionGroup() -> None: async def crasher1() -> NoReturn: raise KeyError @@ -944,19 +938,21 @@ async def main() -> None: _core.spawn_system_task(system_task) await sleep_forever() + # TrioInternalError is not wrapped with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - me = excinfo.value.__cause__ - assert isinstance(me, MultiError) - assert len(me.exceptions) == 2 - for exc in me.exceptions: - assert isinstance(exc, (KeyError, ValueError)) + # the first exceptiongroup is from the first nursery opened in Runner.init() + # the second exceptiongroup is from the second nursery opened in Runner.init() + # the third exceptiongroup is from the nursery defined in `system_task` above + assert RaisesGroup(RaisesGroup(RaisesGroup(KeyError, ValueError))).matches( + excinfo.value.__cause__ + ) def test_system_task_crash_plus_Cancelled() -> None: # Set up a situation where a system task crashes with a - # MultiError([Cancelled, ValueError]) + # ExceptionGroup([Cancelled, ValueError]) async def crasher() -> None: try: await sleep_forever() except: @@
-977,7 +973,11 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert type(excinfo.value.__cause__) is ValueError + + # See explanation for triple-wrap in test_system_task_crash_ExceptionGroup + assert RaisesGroup(RaisesGroup(RaisesGroup(ValueError))).matches( + excinfo.value.__cause__ + ) def test_system_task_crash_KeyboardInterrupt() -> None: @@ -990,7 +990,8 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert isinstance(excinfo.value.__cause__, KeyboardInterrupt) + # "Only" double-wrapped since ki() doesn't create an exceptiongroup + assert RaisesGroup(RaisesGroup(KeyboardInterrupt)).matches(excinfo.value.__cause__) # This used to fail because checkpoint was a yield followed by an immediate @@ -1005,8 +1006,8 @@ async def main() -> None: async def test_yield_briefly_checks_for_timeout(mock_clock: _core.MockClock) -> None: with _core.CancelScope(deadline=_core.current_time() + 5): await _core.checkpoint() + mock_clock.jump(10) with pytest.raises(_core.Cancelled): - mock_clock.jump(10) await _core.checkpoint() @@ -1021,11 +1022,11 @@ async def test_exc_info() -> None: seq = Sequencer() async def child1() -> None: - with pytest.raises(ValueError) as excinfo: + async with seq(0): + pass # we don't yield until seq(2) below + record.append("child1 raise") + with pytest.raises(ValueError, match="^child1$") as excinfo: try: - async with seq(0): - pass # we don't yield until seq(2) below - record.append("child1 raise") raise ValueError("child1") except ValueError: record.append("child1 sleep") @@ -1038,12 +1039,12 @@ async def child1() -> None: record.append("child1 success") async def child2() -> None: + async with seq(1): + pass # we don't yield until seq(3) below + assert "child1 sleep" in record + record.append("child2 wake") + assert sys.exc_info() == (None, None, None) with pytest.raises(KeyError) as excinfo: - async with seq(1): - pass # we don't yield until seq(3) below - assert "child1 sleep" in record - record.append("child2 wake") - assert sys.exc_info() == (None, None, None) try: raise KeyError("child2") except KeyError: @@ -1072,12 +1073,19 @@ async def child2() -> None: ] -# Before CPython 3.9, using .throw() to raise an exception inside a -# coroutine/generator causes the original exc_info state to be lost, so things -# like re-raising and exception chaining are broken. +# On all CPython versions (at time of writing), using .throw() to raise an +# exception inside a coroutine/generator can cause the original `exc_info` state +# to be lost, so things like re-raising and exception chaining are broken unless +# Trio implements its workaround. At time of writing, CPython main (3.13-dev) +# and every CPython release (excluding releases for old Python versions not +# supported by Trio) is affected (albeit in differing ways). # -# https://bugs.python.org/issue29587 -async def test_exc_info_after_yield_error() -> None: +# If the `ValueError()` gets sent in via `throw` and is suppressed, then CPython +# loses track of the original `exc_info`: +# https://bugs.python.org/issue25612 (Example 1) +# https://bugs.python.org/issue29587 (Example 2) +# This is fixed in CPython >= 3.7. 
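The CPython quirk described in the comment above is easy to reproduce with a plain generator. Here is a minimal standalone sketch (illustrative only, not part of the patch) of the same suppressed-throw pattern that the test below drives through a Trio task:

    from contextlib import suppress

    def child():
        try:
            raise KeyError
        except KeyError:
            with suppress(ValueError):
                yield  # stands in for the test's `await sleep_forever()`
            raise  # the bare re-raise must restore the original KeyError

    gen = child()
    next(gen)  # run up to the yield inside the except block
    try:
        gen.throw(ValueError())  # like rescheduling with outcome.Error(ValueError())
    except KeyError as exc:
        assert exc.__context__ is None  # holds on CPython >= 3.7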
+async def test_exc_info_after_throw_suppressed() -> None: child_task: _core.Task | None = None async def child() -> None: @@ -1086,21 +1094,26 @@ async def child() -> None: try: raise KeyError - except Exception: - with suppress(Exception): + except KeyError: + with suppress(ValueError): await sleep_forever() raise - with pytest.raises(KeyError): + with RaisesGroup(Matcher(KeyError, check=lambda e: e.__context__ is None)): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() _core.reschedule(not_none(child_task), outcome.Error(ValueError())) -# Similar to previous test -- if the ValueError() gets sent in via 'throw', -# then Python's normal implicit chaining stuff is broken. -async def test_exception_chaining_after_yield_error() -> None: +# Similar to previous test -- if the `ValueError()` gets sent in via 'throw' and +# propagates out, then CPython doesn't set its `__context__` so normal implicit +# exception chaining is broken: +# https://bugs.python.org/issue25612 (Example 2) +# https://bugs.python.org/issue25683 +# https://bugs.python.org/issue29587 (Example 1) +# This is fixed in CPython >= 3.9. +async def test_exception_chaining_after_throw() -> None: child_task: _core.Task | None = None async def child() -> None: @@ -1109,31 +1122,98 @@ async def child() -> None: try: raise KeyError - except Exception: + except KeyError: await sleep_forever() - with pytest.raises(ValueError) as excinfo: + with RaisesGroup( + Matcher(ValueError, "error text", lambda e: isinstance(e.__context__, KeyError)) + ): + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + await wait_all_tasks_blocked() + _core.reschedule( + not_none(child_task), outcome.Error(ValueError("error text")) + ) + + +# Similar to previous tests -- if the `ValueError()` gets sent into an inner +# `await` via 'throw' and is suppressed there, then CPython loses track of +# `exc_info` in the inner coroutine: +# https://github.com/python/cpython/issues/108668 +# This is unfixed in CPython at time of writing. +async def test_exc_info_after_throw_to_inner_suppressed() -> None: + child_task: _core.Task | None = None + + async def child() -> None: + nonlocal child_task + child_task = _core.current_task() + + try: + raise KeyError + except KeyError as exc: + await inner(exc) + raise + + async def inner(exc: BaseException) -> None: + with suppress(ValueError): + await sleep_forever() + assert not_none(sys.exc_info()[1]) is exc + + with RaisesGroup(Matcher(KeyError, check=lambda e: e.__context__ is None)): async with _core.open_nursery() as nursery: nursery.start_soon(child) await wait_all_tasks_blocked() _core.reschedule(not_none(child_task), outcome.Error(ValueError())) - assert isinstance(excinfo.value.__context__, KeyError) + +# Similar to previous tests -- if the `ValueError()` gets sent into an inner +# `await` via `throw` and propagates out, then CPython incorrectly sets its +# `__context__` so normal implicit exception chaining is broken: +# https://bugs.python.org/issue40694 +# This is unfixed in CPython at time of writing. 
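Again purely as illustration (not part of the patch), the nested-generator analogue of this bug is sketched below; on an affected CPython the chain printed after a bare `.throw()` can differ from the ValueError -> IndexError -> KeyError chain that the following test expects Trio's workaround to restore:

    def inner():
        try:
            raise IndexError
        except IndexError:
            yield  # suspended here when the ValueError is thrown in

    def child():
        try:
            raise KeyError
        except KeyError:
            yield from inner()

    gen = child()
    next(gen)
    try:
        gen.throw(ValueError("Unique Text"))
    except ValueError as exc:
        # walk and print the implicit-chaining links defensively
        chain, e = [], exc
        while e is not None:
            chain.append(type(e).__name__)
            e = e.__context__
        print(" -> ".join(chain))  # correct chaining: ValueError -> IndexError -> KeyError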
+async def test_exception_chaining_after_throw_to_inner() -> None: + child_task: _core.Task | None = None + + async def child() -> None: + nonlocal child_task + child_task = _core.current_task() + + try: + raise KeyError + except KeyError: + await inner() + + async def inner() -> None: + try: + raise IndexError + except IndexError: + await sleep_forever() + + with RaisesGroup( + Matcher( + ValueError, + "^Unique Text$", + lambda e: isinstance(e.__context__, IndexError) + and isinstance(e.__context__.__context__, KeyError), + ) + ): + async with _core.open_nursery() as nursery: + nursery.start_soon(child) + await wait_all_tasks_blocked() + _core.reschedule( + not_none(child_task), outcome.Error(ValueError("Unique Text")) + ) -@pytest.mark.skipif( - sys.version_info < (3, 6, 2), reason="https://bugs.python.org/issue29600" -) async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: async def crasher() -> NoReturn: raise KeyError - with pytest.raises(MultiError) as excinfo: + # the ExceptionGroup should not have the KeyError or ValueError as context + with RaisesGroup(ValueError, KeyError, check=lambda x: x.__context__ is None): async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise ValueError - # the MultiError should not have the KeyError or ValueError as context - assert excinfo.value.__context__ is None def test_TrioToken_identity() -> None: @@ -1234,9 +1314,9 @@ async def main() -> None: # After main exits but before finally cleaning up, callback processed # normally token.run_sync_soon(lambda: record.append("sync-cb")) - raise ValueError + raise ValueError("error text") - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^error text$"): _core.run(main) assert record == ["sync-cb"] @@ -1258,8 +1338,9 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - - assert type(excinfo.value.__cause__) is KeyError + # the first exceptiongroup is from the first nursery opened in Runner.init() + # the second exceptiongroup is from the second nursery opened in Runner.init() + assert RaisesGroup(RaisesGroup(KeyError)).matches(excinfo.value.__cause__) assert record == {"2nd run_sync_soon ran", "cancelled!"} @@ -1351,7 +1432,6 @@ def cb(i: int) -> None: assert counter[0] == COUNT -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="PyPy 7.2 is buggy") def test_TrioToken_run_sync_soon_late_crash() -> None: # Crash after system nursery is closed -- easiest way to do that is # from an async generator finalizer. 
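Stepping back from the individual hunks: the recurring pattern in this file is that `pytest.raises(SomeError)` around nursery-using code becomes `RaisesGroup(...)`, because nurseries now always wrap child exceptions in an ExceptionGroup. A minimal standalone usage sketch (assuming a trio version that ships `RaisesGroup` and `Matcher` in `trio.testing`, as this patch does):

    import trio
    from trio.testing import Matcher, RaisesGroup

    async def main() -> None:
        async def crasher() -> None:
            raise ValueError("this is a crash")

        # The child's ValueError surfaces wrapped in an ExceptionGroup, so a
        # plain pytest.raises(ValueError) around the nursery would not match.
        with RaisesGroup(Matcher(ValueError, "^this is a crash$")):
            async with trio.open_nursery() as nursery:
                nursery.start_soon(crasher)

    trio.run(main)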
@@ -1374,22 +1454,23 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - assert type(excinfo.value.__cause__) is KeyError + assert RaisesGroup(KeyError).matches(excinfo.value.__cause__) assert record == ["main exiting", "2nd ran"] async def test_slow_abort_basic() -> None: with _core.CancelScope() as scope: scope.cancel() - with pytest.raises(_core.Cancelled): - task = _core.current_task() - token = _core.current_trio_token() - def slow_abort(raise_cancel: _core.RaiseCancelT) -> _core.Abort: - result = outcome.capture(raise_cancel) - token.run_sync_soon(_core.reschedule, task, result) - return _core.Abort.FAILED + task = _core.current_task() + token = _core.current_trio_token() + + def slow_abort(raise_cancel: _core.RaiseCancelT) -> _core.Abort: + result = outcome.capture(raise_cancel) + token.run_sync_soon(_core.reschedule, task, result) + return _core.Abort.FAILED + with pytest.raises(_core.Cancelled): await _core.wait_task_rescheduled(slow_abort) @@ -1406,8 +1487,8 @@ def slow_abort(raise_cancel: _core.RaiseCancelT) -> _core.Abort: token.run_sync_soon(_core.reschedule, task, result) return _core.Abort.FAILED + record.append("sleeping") with pytest.raises(_core.Cancelled): - record.append("sleeping") await _core.wait_task_rescheduled(slow_abort) record.append("cancelled") # blocking again, this time it's okay, because we're shielded @@ -1480,6 +1561,7 @@ async def child2() -> None: assert tasks["child2"].child_nurseries == [] async def child1( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: me = tasks["child1"] = _core.current_task() @@ -1583,21 +1665,34 @@ async def main() -> None: _core.run(main) - for bad_call in bad_call_run, bad_call_spawn: - - async def f() -> None: # pragma: no cover - pass - - with pytest.raises(TypeError, match="expecting an async function"): - bad_call(f()) # type: ignore[arg-type] - - async def async_gen(arg: T) -> AsyncGenerator[T, None]: # pragma: no cover - yield arg + async def f() -> None: # pragma: no cover + pass - with pytest.raises( - TypeError, match="expected an async function but got an async generator" - ): - bad_call(async_gen, 0) # type: ignore + async def async_gen(arg: T) -> AsyncGenerator[T, None]: # pragma: no cover + yield arg + + # If/when RaisesGroup/Matcher is added to pytest in some form this test can be + # rewritten to use a loop again, and avoid specifying the exceptions twice in + # different ways + with pytest.raises( + TypeError, + match="^Trio was expecting an async function, but instead it got a coroutine object <.*>", + ): + bad_call_run(f()) # type: ignore[arg-type] + with pytest.raises( + TypeError, match="expected an async function but got an async generator" + ): + bad_call_run(async_gen, 0) # type: ignore + + # bad_call_spawn calls the function inside a nursery, so the exception will be + # wrapped in an exceptiongroup + with RaisesGroup(Matcher(TypeError, "expecting an async function")): + bad_call_spawn(f()) # type: ignore[arg-type] + + with RaisesGroup( + Matcher(TypeError, "expected an async function but got an async generator") + ): + bad_call_spawn(async_gen, 0) # type: ignore def test_calling_asyncio_function_gives_nice_error() -> None: @@ -1607,10 +1702,9 @@ async def child_xyzzy() -> None: async def misguided() -> None: await child_xyzzy() - with pytest.raises(TypeError) as excinfo: + with pytest.raises(TypeError, match="asyncio") as excinfo: _core.run(misguided) - assert "asyncio" in str(excinfo.value) # The traceback should point 
to the location of the foreign await assert any( # pragma: no branch entry.name == "child_xyzzy" for entry in excinfo.traceback @@ -1619,11 +1713,10 @@ async def misguided() -> None: async def test_asyncio_function_inside_nursery_does_not_explode() -> None: # Regression test for https://github.com/python-trio/trio/issues/552 - with pytest.raises(TypeError) as excinfo: + with RaisesGroup(Matcher(TypeError, "asyncio")): async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) await create_asyncio_future_in_new_loop() - assert "asyncio" in str(excinfo.value) async def test_trivial_yields() -> None: @@ -1660,7 +1753,7 @@ async def noop_with_no_checkpoint() -> None: with _core.CancelScope() as cancel_scope: cancel_scope.cancel() - with pytest.raises(KeyError): + with RaisesGroup(KeyError): async with _core.open_nursery(): raise KeyError @@ -1702,6 +1795,7 @@ async def sleep_then_start( # calling started twice async def double_started( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: task_status.started() @@ -1713,6 +1807,7 @@ async def double_started( # child crashes before calling started -> error comes out of .start() async def raise_keyerror( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: raise KeyError("oops") @@ -1723,6 +1818,7 @@ async def raise_keyerror( # child exiting cleanly before calling started -> triggers a RuntimeError async def nothing( + *, task_status: _core.TaskStatus[None] = _core.TASK_STATUS_IGNORED, ) -> None: return @@ -1736,6 +1832,7 @@ async def nothing( # nothing -- the child keeps executing under start(). The value it passed # is ignored; start() raises Cancelled. async def just_started( + *, task_status: _core.TaskStatus[str] = _core.TASK_STATUS_IGNORED, ) -> None: task_status.started("hi") @@ -1781,7 +1878,8 @@ async def raise_keyerror_after_started( t0 = _core.current_time() with pytest.raises(RuntimeError): await closed_nursery.start(sleep_then_start, 7) - assert _core.current_time() == t0 + # sub-second delays can be caused by unrelated multitasking by an OS + assert int(_core.current_time()) == int(t0) async def test_task_nursery_stack() -> None: @@ -1789,7 +1887,7 @@ async def test_task_nursery_stack() -> None: assert task._child_nurseries == [] async with _core.open_nursery() as nursery1: assert task._child_nurseries == [nursery1] - with pytest.raises(KeyError): + with RaisesGroup(KeyError): async with _core.open_nursery() as nursery2: assert task._child_nurseries == [nursery1, nursery2] raise KeyError @@ -1882,7 +1980,7 @@ async def start_sleep_then_crash(nursery: _core.Nursery) -> None: async def test_nursery_explicit_exception() -> None: - with pytest.raises(KeyError): + with RaisesGroup(KeyError): async with _core.open_nursery(): raise KeyError() @@ -1891,12 +1989,10 @@ async def test_nursery_stop_iteration() -> None: async def fail() -> NoReturn: raise ValueError - try: + with RaisesGroup(StopIteration, ValueError): async with _core.open_nursery() as nursery: nursery.start_soon(fail) raise StopIteration - except MultiError as e: - assert tuple(map(type, e.exceptions)) == (StopIteration, ValueError) async def test_nursery_stop_async_iteration() -> None: @@ -1918,7 +2014,7 @@ def __init__(self, *largs: it) -> None: self.nexts = [obj.__anext__ for obj in largs] async def _accumulate( - self, f: Callable[[], Awaitable[int]], items: list[int | None], i: int + self, f: Callable[[], Awaitable[int]], items: list[int], i: int ) -> None: items[i] = await f() @@ -1929,15 +2025,26 
@@ async def __anext__(self) -> list[int]: nexts = self.nexts items: list[int] = [-1] * len(nexts) - async with _core.open_nursery() as nursery: - for i, f in enumerate(nexts): - nursery.start_soon(self._accumulate, f, items, i) + try: + async with _core.open_nursery() as nursery: + for i, f in enumerate(nexts): + nursery.start_soon(self._accumulate, f, items, i) + except ExceptionGroup as e: + # With strict_exception_groups enabled, users now need to unwrap + # StopAsyncIteration and re-raise it. + # This would be relatively clean on python3.11+ with except*. + # We could also use RaisesGroup, but that's primarily meant as + # test infra, not as a runtime tool. + if len(e.exceptions) == 1 and isinstance( + e.exceptions[0], StopAsyncIteration + ): + raise e.exceptions[0] from None + else: # pragma: no cover + raise AssertionError("unknown error in _accumulate") from e return items - result: list[list[int]] = [] - async for vals in async_zip(it(4), it(2)): - result.append(vals) + result: list[list[int]] = [vals async for vals in async_zip(it(4), it(2))] assert result == [[0, 0], [1, 1]] @@ -1945,25 +2052,26 @@ async def test_traceback_frame_removal() -> None: async def my_child_task() -> NoReturn: raise KeyError() - try: + def check_traceback(exc: KeyError) -> bool: + # The top frame in the exception traceback should be inside the child + # task, not trio/contextvars internals. And there's only one frame + # inside the child task, so this will also detect if our frame-removal + # is too eager. + tb = exc.__traceback__ + assert tb is not None + return tb.tb_frame.f_code is my_child_task.__code__ + + expected_exception = Matcher(KeyError, check=check_traceback) + + with RaisesGroup(expected_exception, expected_exception): # Trick: For now cancel/nursery scopes still leave a bunch of tb gunk - # behind. But if there's a MultiError, they leave it on the MultiError, + # behind. But if there's an ExceptionGroup, they leave it on the group, # which lets us get a clean look at the KeyError itself. Someday I - # guess this will always be a MultiError (#611), but for now we can + # guess this will always be an ExceptionGroup (#611), but for now we can # force it by raising two exceptions. async with _core.open_nursery() as nursery: nursery.start_soon(my_child_task) nursery.start_soon(my_child_task) - except MultiError as exc: - first_exc = exc.exceptions[0] - assert isinstance(first_exc, KeyError) - # The top frame in the exception traceback should be inside the child - # task, not trio/contextvars internals. And there's only one frame - # inside the child task, so this will also detect if our frame-removal - # is too eager. 
- tb = first_exc.__traceback__ - assert tb is not None - assert tb.tb_frame.f_code is my_child_task.__code__ def test_contextvar_support() -> None: @@ -2155,7 +2263,7 @@ async def detachable_coroutine( # Check the exception paths too task = None pdco_outcome = None - with pytest.raises(KeyError): + with RaisesGroup(KeyError): async with _core.open_nursery() as nursery: nursery.start_soon(detachable_coroutine, outcome.Error(KeyError()), "uh oh") throw_in = ValueError() @@ -2297,6 +2405,9 @@ async def test_cancel_scope_deadline_duplicates() -> None: await sleep(0.01) +# I don't know if this one can fail anymore, the `del` next to the comment that used to +# refer to this only seems to break test_cancel_scope_exit_doesnt_create_cyclic_garbage +# We're keeping it for now to cover Outcome and potential future refactoring @pytest.mark.skipif( sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" ) @@ -2310,7 +2421,7 @@ async def do_a_cancel() -> None: await sleep_forever() async def crasher() -> NoReturn: - raise ValueError + raise ValueError("this is a crash") old_flags = gc.get_debug() try: @@ -2324,7 +2435,7 @@ async def crasher() -> NoReturn: # (See https://github.com/python-trio/trio/pull/1864) await do_a_cancel() - with pytest.raises(ValueError): + with RaisesGroup(Matcher(ValueError, "^this is a crash$")): async with _core.open_nursery() as nursery: # cover NurseryManager.__aexit__ nursery.start_soon(crasher) @@ -2344,11 +2455,13 @@ async def test_cancel_scope_exit_doesnt_create_cyclic_garbage() -> None: gc.collect() async def crasher() -> NoReturn: - raise ValueError + raise ValueError("this is a crash") old_flags = gc.get_debug() try: - with pytest.raises(ValueError), _core.CancelScope() as outer: + with RaisesGroup( + Matcher(ValueError, "^this is a crash$") + ), _core.CancelScope() as outer: async with _core.open_nursery() as nursery: gc.collect() gc.set_debug(gc.DEBUG_SAVEALL) @@ -2357,7 +2470,7 @@ async def crasher() -> NoReturn: outer.cancel() # And one that raises a different error nursery.start_soon(crasher) - # so that outer filters a Cancelled from the MultiError and + # so that outer filters a Cancelled from the ExceptionGroup and # covers CancelScope.__exit__ (and NurseryManager.__aexit__) # (See https://github.com/python-trio/trio/pull/2063) @@ -2428,97 +2541,84 @@ async def task() -> None: assert destroyed -def test_run_strict_exception_groups() -> None: - """ - Test that nurseries respect the global context setting of strict_exception_groups. - """ - - async def main() -> NoReturn: - async with _core.open_nursery(): - raise Exception("foo") - - with pytest.raises(MultiError) as exc: - _core.run(main, strict_exception_groups=True) - - assert len(exc.value.exceptions) == 1 - assert type(exc.value.exceptions[0]) is Exception - assert exc.value.exceptions[0].args == ("foo",) - - -def test_run_strict_exception_groups_nursery_override() -> None: - """ - Test that a nursery can override the global context setting of - strict_exception_groups. 
- """ - - async def main() -> NoReturn: - async with _core.open_nursery(strict_exception_groups=False): - raise Exception("foo") - - with pytest.raises(Exception, match="foo"): - _core.run(main, strict_exception_groups=True) - +def _create_kwargs(strictness: bool | None) -> dict[str, bool]: + """Turn a bool|None into a kwarg dict that can be passed to `run` or `open_nursery`""" -async def test_nursery_strict_exception_groups() -> None: - """Test that strict exception groups can be enabled on a per-nursery basis.""" - with pytest.raises(MultiError) as exc: - async with _core.open_nursery(strict_exception_groups=True): - raise Exception("foo") + if strictness is None: + return {} + return {"strict_exception_groups": strictness} - assert len(exc.value.exceptions) == 1 - assert type(exc.value.exceptions[0]) is Exception - assert exc.value.exceptions[0].args == ("foo",) - -async def test_nursery_collapse_strict() -> None: +@pytest.mark.filterwarnings( + "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning" +) +@pytest.mark.parametrize("run_strict", [True, False, None]) +@pytest.mark.parametrize("open_nursery_strict", [True, False, None]) +@pytest.mark.parametrize("multiple_exceptions", [True, False]) +def test_setting_strict_exception_groups( + run_strict: bool | None, open_nursery_strict: bool | None, multiple_exceptions: bool +) -> None: """ - Test that a single exception from a nested nursery with strict semantics doesn't get - collapsed when CancelledErrors are stripped from it. + Test default values and that nurseries can both inherit and override the global context + setting of strict_exception_groups. """ async def raise_error() -> NoReturn: raise RuntimeError("test error") - with pytest.raises(MultiError) as exc: - async with _core.open_nursery() as nursery: - nursery.start_soon(sleep_forever) + async def main() -> None: + """Open a nursery, and raise one or two errors inside""" + async with _core.open_nursery(**_create_kwargs(open_nursery_strict)) as nursery: nursery.start_soon(raise_error) - async with _core.open_nursery(strict_exception_groups=True) as nursery2: - nursery2.start_soon(sleep_forever) - nursery2.start_soon(raise_error) - nursery.cancel_scope.cancel() - - exceptions = exc.value.exceptions - assert len(exceptions) == 2 - assert isinstance(exceptions[0], RuntimeError) - assert isinstance(exceptions[1], MultiError) - assert len(exceptions[1].exceptions) == 1 - assert isinstance(exceptions[1].exceptions[0], RuntimeError) - - -async def test_nursery_collapse_loose() -> None: + if multiple_exceptions: + nursery.start_soon(raise_error) + + def run_main() -> None: + # mypy doesn't like kwarg magic + _core.run(main, **_create_kwargs(run_strict)) # type: ignore[arg-type] + + matcher = Matcher(RuntimeError, "^test error$") + + if multiple_exceptions: + with RaisesGroup(matcher, matcher): + run_main() + elif open_nursery_strict or ( + open_nursery_strict is None and run_strict is not False + ): + with RaisesGroup(matcher): + run_main() + else: + with pytest.raises(RuntimeError, match="^test error$"): + run_main() + + +@pytest.mark.filterwarnings( + "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning" +) +@pytest.mark.parametrize("strict", [True, False, None]) +async def test_nursery_collapse(strict: bool | None) -> None: """ - Test that a single exception from a nested nursery with loose semantics gets - collapsed when CancelledErrors are stripped from it. 
+ Test that a single exception from a nested nursery gets collapsed correctly + depending on strict_exception_groups value when CancelledErrors are stripped from it. """ async def raise_error() -> NoReturn: raise RuntimeError("test error") - with pytest.raises(MultiError) as exc: + # mypy requires explicit type for conditional expression + maybe_wrapped_runtime_error: type[RuntimeError] | RaisesGroup[RuntimeError] = ( + RuntimeError if strict is False else RaisesGroup(RuntimeError) + ) + + with RaisesGroup(RuntimeError, maybe_wrapped_runtime_error): async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) nursery.start_soon(raise_error) - async with _core.open_nursery() as nursery2: + async with _core.open_nursery(**_create_kwargs(strict)) as nursery2: nursery2.start_soon(sleep_forever) nursery2.start_soon(raise_error) nursery.cancel_scope.cancel() - exceptions = exc.value.exceptions - assert len(exceptions) == 2 - assert isinstance(exceptions[0], RuntimeError) - assert isinstance(exceptions[1], RuntimeError) - async def test_cancel_scope_no_cancellederror() -> None: """ @@ -2526,7 +2626,7 @@ async def test_cancel_scope_no_cancellederror() -> None: a Cancelled exception, it will NOT set the ``cancelled_caught`` flag. """ - with pytest.raises(ExceptionGroup): + with RaisesGroup(RuntimeError, RuntimeError, match="test"): with _core.CancelScope() as scope: scope.cancel() raise ExceptionGroup("test", [RuntimeError(), RuntimeError()]) @@ -2534,60 +2634,92 @@ async def test_cancel_scope_no_cancellederror() -> None: assert not scope.cancelled_caught -"""These tests are for fixing https://github.com/python-trio/trio/issues/2611 -where exceptions raised before `task_status.started()` got wrapped twice. -""" - - -async def raise_before(*, task_status: _core.TaskStatus[None]) -> None: - raise ValueError - - -async def raise_after_started(*, task_status: _core.TaskStatus[None]) -> None: - task_status.started() - raise ValueError - - -async def raise_custom_exception_group_before( - *, task_status: _core.TaskStatus[None] +@pytest.mark.filterwarnings( + "ignore:.*strict_exception_groups=False:trio.TrioDeprecationWarning" +) +@pytest.mark.parametrize("run_strict", [False, True]) +@pytest.mark.parametrize("start_raiser_strict", [False, True, None]) +@pytest.mark.parametrize("raise_after_started", [False, True]) +@pytest.mark.parametrize("raise_custom_exc_grp", [False, True]) +def test_trio_run_strict_before_started( + run_strict: bool, + start_raiser_strict: bool | None, + raise_after_started: bool, + raise_custom_exc_grp: bool, ) -> None: - raise ExceptionGroup("my group", [ValueError()]) - + """ + Regression tests for #2611, where exceptions raised before + `TaskStatus.started()` caused `Nursery.start()` to wrap them in an + ExceptionGroup when using `run(..., strict_exception_groups=True)`. -def _check_exception(exc: pytest.ExceptionInfo[BaseException]) -> None: - assert isinstance(exc.value, BaseExceptionGroup) - assert len(exc.value.exceptions) == 1 - assert isinstance(exc.value.exceptions[0], ValueError) + Regression tests for #2844, where #2611 was initially fixed in a way that + had unintended side effects. 
+ """ + raiser_exc: ValueError | ExceptionGroup[ValueError] + if raise_custom_exc_grp: + raiser_exc = ExceptionGroup("my group", [ValueError()]) + else: + raiser_exc = ValueError() -async def _start_raiser( - raiser: Callable[[], Awaitable[None]], strict: bool | None = None -) -> None: - async with _core.open_nursery(strict_exception_groups=strict) as nursery: - await nursery.start(raiser) + async def raiser(*, task_status: _core.TaskStatus[None]) -> None: + if raise_after_started: + task_status.started() + raise raiser_exc + async def start_raiser() -> None: + try: + async with _core.open_nursery( + strict_exception_groups=start_raiser_strict + ) as nursery: + await nursery.start(raiser) + except BaseExceptionGroup as exc_group: + if start_raiser_strict: + # Iff the code using the nursery *forced* it to be strict + # (overriding the runner setting) then it may replace the bland + # exception group raised by trio with a more specific one (subtype, + # different message, etc.). + raise BaseExceptionGroup( + "start_raiser nursery custom message", exc_group.exceptions + ) from None + raise -@pytest.mark.parametrize("strict", [False, True]) -@pytest.mark.parametrize("raiser", [raise_before, raise_after_started]) -async def test_strict_before_started( - strict: bool, raiser: Callable[[], Awaitable[None]] -) -> None: - with pytest.raises(BaseExceptionGroup if strict else ValueError) as exc: - await _start_raiser(raiser, strict) - if strict: - _check_exception(exc) + with pytest.raises(BaseException) as exc_info: # noqa: PT011 # no `match` + _core.run(start_raiser, strict_exception_groups=run_strict) + + if start_raiser_strict or (run_strict and start_raiser_strict is None): + # start_raiser's nursery was strict. + assert isinstance(exc_info.value, BaseExceptionGroup) + if start_raiser_strict: + # start_raiser didn't unknowingly inherit its nursery strictness + # from `run`---it explicitly chose for its nursery to be strict. + assert exc_info.value.message == "start_raiser nursery custom message" + assert len(exc_info.value.exceptions) == 1 + should_be_raiser_exc = exc_info.value.exceptions[0] + else: + # start_raiser's nursery was not strict. + should_be_raiser_exc = exc_info.value + if isinstance(raiser_exc, ValueError): + assert should_be_raiser_exc is raiser_exc + else: + # Check attributes, not identity, because should_be_raiser_exc may be a + # copy of raiser_exc rather than raiser_exc by identity. 
+ assert type(should_be_raiser_exc) is type(raiser_exc) + assert should_be_raiser_exc.message == raiser_exc.message + assert should_be_raiser_exc.exceptions == raiser_exc.exceptions + + +async def test_internal_error_old_nursery_multiple_tasks() -> None: + async def error_func() -> None: + raise ValueError + async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> None: + old_nursery = _core.current_task().parent_nursery + assert old_nursery is not None + old_nursery.start_soon(error_func) + old_nursery.start_soon(error_func) -# it was only when run from `trio.run` that the double wrapping happened -@pytest.mark.parametrize("strict", [False, True]) -@pytest.mark.parametrize( - "raiser", [raise_before, raise_after_started, raise_custom_exception_group_before] -) -def test_trio_run_strict_before_started( - strict: bool, raiser: Callable[[], Awaitable[None]] -) -> None: - expect_group = strict or raiser is raise_custom_exception_group_before - with pytest.raises(BaseExceptionGroup if expect_group else ValueError) as exc: - _core.run(_start_raiser, raiser, strict_exception_groups=strict) - if expect_group: - _check_exception(exc) + async with _core.open_nursery() as nursery: + with pytest.raises(_core.TrioInternalError) as excinfo: + await nursery.start(spawn_tasks_in_old_nursery) + assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__) diff --git a/trio/_core/_tests/test_thread_cache.py b/src/trio/_core/_tests/test_thread_cache.py similarity index 95% rename from trio/_core/_tests/test_thread_cache.py rename to src/trio/_core/_tests/test_thread_cache.py index 77fdf46664..ee301d17fd 100644 --- a/trio/_core/_tests/test_thread_cache.py +++ b/src/trio/_core/_tests/test_thread_cache.py @@ -4,16 +4,17 @@ import time from contextlib import contextmanager from queue import Queue -from typing import Iterator, NoReturn +from typing import TYPE_CHECKING, Iterator, NoReturn import pytest -from outcome import Outcome -from pytest import MonkeyPatch from .. import _thread_cache from .._thread_cache import ThreadCache, start_thread_soon from .tutil import gc_collect_harder, slow +if TYPE_CHECKING: + from outcome import Outcome + def test_thread_cache_basics() -> None: q: Queue[Outcome[object]] = Queue() @@ -90,7 +91,7 @@ def deliver(n: int, _: object) -> None: @slow -def test_idle_threads_exit(monkeypatch: MonkeyPatch) -> None: +def test_idle_threads_exit(monkeypatch: pytest.MonkeyPatch) -> None: # Temporarily set the idle timeout to something tiny, to speed up the # test. (But non-zero, so that the worker loop will at least yield the # CPU.) @@ -117,7 +118,9 @@ def _join_started_threads() -> Iterator[None]: assert not thread.is_alive() -def test_race_between_idle_exit_and_job_assignment(monkeypatch: MonkeyPatch) -> None: +def test_race_between_idle_exit_and_job_assignment( + monkeypatch: pytest.MonkeyPatch, +) -> None: # This is a lock where the first few times you try to acquire it with a # timeout, it waits until the lock is available and then pretends to time # out. 
Using this in our thread cache implementation causes the following diff --git a/trio/_core/_tests/test_tutil.py b/src/trio/_core/_tests/test_tutil.py similarity index 100% rename from trio/_core/_tests/test_tutil.py rename to src/trio/_core/_tests/test_tutil.py diff --git a/trio/_core/_tests/test_unbounded_queue.py b/src/trio/_core/_tests/test_unbounded_queue.py similarity index 100% rename from trio/_core/_tests/test_unbounded_queue.py rename to src/trio/_core/_tests/test_unbounded_queue.py diff --git a/trio/_core/_tests/test_windows.py b/src/trio/_core/_tests/test_windows.py similarity index 94% rename from trio/_core/_tests/test_windows.py rename to src/trio/_core/_tests/test_windows.py index 2588c2530f..486a405590 100644 --- a/trio/_core/_tests/test_windows.py +++ b/src/trio/_core/_tests/test_windows.py @@ -3,9 +3,7 @@ import os import sys import tempfile -from collections.abc import Generator from contextlib import contextmanager -from io import BufferedWriter from typing import TYPE_CHECKING from unittest.mock import create_autospec @@ -23,6 +21,10 @@ from ...testing import wait_all_tasks_blocked from .tutil import gc_collect_harder, restore_unraisablehook, slow +if TYPE_CHECKING: + from collections.abc import Generator + from io import BufferedWriter + if on_windows: from .._windows_cffi import ( INVALID_HANDLE_VALUE, @@ -40,18 +42,20 @@ def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None: # Returning none = no error, should not happen. mock.return_value = None - with pytest.raises(RuntimeError, match="No error set"): + with pytest.raises(RuntimeError, match=r"^No error set\?$"): raise_winerror() mock.assert_called_once_with() mock.reset_mock() - with pytest.raises(RuntimeError, match="No error set"): + with pytest.raises(RuntimeError, match=r"^No error set\?$"): raise_winerror(38) mock.assert_called_once_with(38) mock.reset_mock() mock.return_value = (12, "test error") - with pytest.raises(OSError) as exc: + with pytest.raises( + OSError, match=r"^\[WinError 12\] test error: 'file_1' -> 'file_2'$" + ) as exc: raise_winerror(filename="file_1", filename2="file_2") mock.assert_called_once_with() mock.reset_mock() @@ -61,7 +65,9 @@ def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None: assert exc.value.filename2 == "file_2" # With an explicit number passed in, it overrides what getwinerror() returns. - with pytest.raises(OSError) as exc: + with pytest.raises( + OSError, match=r"^\[WinError 18\] test error: 'a/file' -> 'b/file'$" + ) as exc: raise_winerror(18, filename="a/file", filename2="b/file") mock.assert_called_once_with(18) mock.reset_mock() @@ -110,7 +116,7 @@ async def test_readinto_overlapped() -> None: with tempfile.TemporaryDirectory() as tdir: tfile = os.path.join(tdir, "numbers.txt") - with open( # noqa: ASYNC101 # This is a test, synchronous is ok + with open( # noqa: ASYNC230 # This is a test, synchronous is ok tfile, "wb" ) as fp: fp.write(data) @@ -218,7 +224,7 @@ async def test_too_late_to_cancel() -> None: # Note: not trio.sleep! We're making sure the OS level # ReadFile completes, before Trio has a chance to execute # another checkpoint and notice it completed. 
- time.sleep(1) # noqa: ASYNC101 + time.sleep(1) # noqa: ASYNC251 nursery.cancel_scope.cancel() assert target[:6] == b"test1\n" diff --git a/trio/_core/_tests/tutil.py b/src/trio/_core/_tests/tutil.py similarity index 87% rename from trio/_core/_tests/tutil.py rename to src/trio/_core/_tests/tutil.py index 6ed9b5fe14..81370ed76e 100644 --- a/trio/_core/_tests/tutil.py +++ b/src/trio/_core/_tests/tutil.py @@ -7,7 +7,6 @@ import socket as stdlib_socket import sys import warnings -from collections.abc import Generator, Iterable, Sequence from contextlib import closing, contextmanager from typing import TYPE_CHECKING, TypeVar @@ -16,22 +15,13 @@ # See trio/_tests/conftest.py for the other half of this from trio._tests.pytest_plugin import RUN_SLOW +if TYPE_CHECKING: + from collections.abc import Generator, Iterable, Sequence + slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests") T = TypeVar("T") -# PyPy 7.2 was released with a bug that just never called the async -# generator 'firstiter' hook at all. This impacts tests of end-of-run -# finalization (nothing gets added to runner.asyncgens) and tests of -# "foreign" async generator behavior (since the firstiter hook is what -# marks the asyncgen as foreign), but most tests of GC-mediated -# finalization still work. -buggy_pypy_asyncgens = ( - not TYPE_CHECKING - and sys.implementation.name == "pypy" - and sys.pypy_version_info < (7, 3) -) - try: s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0) except OSError: # pragma: no cover @@ -86,7 +76,7 @@ def ignore_coroutine_never_awaited_warnings() -> Generator[None, None, None]: def _noop(*args: object, **kwargs: object) -> None: - pass + pass # pragma: no cover @contextmanager diff --git a/src/trio/_core/_tests/type_tests/nursery_start.py b/src/trio/_core/_tests/type_tests/nursery_start.py new file mode 100644 index 0000000000..77667590b9 --- /dev/null +++ b/src/trio/_core/_tests/type_tests/nursery_start.py @@ -0,0 +1,77 @@ +"""Test variadic generic typing for Nursery.start[_soon]().""" + +from typing import Awaitable, Callable + +from trio import TASK_STATUS_IGNORED, Nursery, TaskStatus + + +async def task_0() -> None: ... + + +async def task_1a(value: int) -> None: ... + + +async def task_1b(value: str) -> None: ... + + +async def task_2a(a: int, b: str) -> None: ... + + +async def task_2b(a: str, b: int) -> None: ... + + +async def task_2c(a: str, b: int, optional: bool = False) -> None: ... + + +async def task_requires_kw(a: int, *, b: bool) -> None: ... + + +async def task_startable_1( + a: str, + *, + task_status: TaskStatus[bool] = TASK_STATUS_IGNORED, +) -> None: ... + + +async def task_startable_2( + a: str, + b: float, + *, + task_status: TaskStatus[bool] = TASK_STATUS_IGNORED, +) -> None: ... + + +async def task_requires_start(*, task_status: TaskStatus[str]) -> None: + """Check a function requiring start() to be used.""" + + +async def task_pos_or_kw(value: str, task_status: TaskStatus[int]) -> None: + """Check a function which doesn't use the *-syntax works.""" + ... 
+ + +def check_start_soon(nursery: Nursery) -> None: + """start_soon() functionality.""" + nursery.start_soon(task_0) + nursery.start_soon(task_1a) # type: ignore + nursery.start_soon(task_2b) # type: ignore + + nursery.start_soon(task_0, 45) # type: ignore + nursery.start_soon(task_1a, 32) + nursery.start_soon(task_1b, 32) # type: ignore + nursery.start_soon(task_1a, "abc") # type: ignore + nursery.start_soon(task_1b, "abc") + + nursery.start_soon(task_2b, "abc") # type: ignore + nursery.start_soon(task_2a, 38, "46") + nursery.start_soon(task_2c, "abc", 12, True) + + nursery.start_soon(task_2c, "abc", 12) + task_2c_cast: Callable[[str, int], Awaitable[object]] = ( + task_2c # The assignment makes it work. + ) + nursery.start_soon(task_2c_cast, "abc", 12) + + nursery.start_soon(task_requires_kw, 12, True) # type: ignore + # Tasks following the start() API can be made to work. + nursery.start_soon(task_startable_1, "cdf") diff --git a/src/trio/_core/_tests/type_tests/run.py b/src/trio/_core/_tests/type_tests/run.py new file mode 100644 index 0000000000..c121ce6c7a --- /dev/null +++ b/src/trio/_core/_tests/type_tests/run.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import Sequence, overload + +import trio +from typing_extensions import assert_type + + +async def sleep_sort(values: Sequence[float]) -> list[float]: + return [1] + + +async def has_optional(arg: int | None = None) -> int: + return 5 + + +@overload +async def foo_overloaded(arg: int) -> str: ... + + +@overload +async def foo_overloaded(arg: str) -> int: ... + + +async def foo_overloaded(arg: int | str) -> int | str: + if isinstance(arg, str): + return 5 + return "hello" + + +v = trio.run( + sleep_sort, (1, 3, 5, 2, 4), clock=trio.testing.MockClock(autojump_threshold=0) +) +assert_type(v, "list[float]") +trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type] +trio.run(sleep_sort) # type: ignore[arg-type] + +r = trio.run(has_optional) +assert_type(r, int) +r = trio.run(has_optional, 5) +trio.run(has_optional, 7, 8) # type: ignore[arg-type] +trio.run(has_optional, "hello") # type: ignore[arg-type] + + +assert_type(trio.run(foo_overloaded, 5), str) +assert_type(trio.run(foo_overloaded, ""), int) diff --git a/trio/_core/_thread_cache.py b/src/trio/_core/_thread_cache.py similarity index 94% rename from trio/_core/_thread_cache.py rename to src/trio/_core/_thread_cache.py index 60279c87e2..d338ec1ee7 100644 --- a/trio/_core/_thread_cache.py +++ b/src/trio/_core/_thread_cache.py @@ -41,10 +41,15 @@ def darwin_namefunc( setname(_to_os_thread_name(name)) # find the pthread library - # this will fail on windows + # this will fail on windows and musl libpthread_path = ctypes.util.find_library("pthread") if not libpthread_path: - return None + # musl includes pthread functions directly in libc.so + # (but note that find_library("c") does not work on musl, + # see: https://github.com/python/cpython/issues/65821) + # so try that library instead + # if it doesn't exist, CDLL() will fail below + libpthread_path = "libc.so" # Sometimes windows can find the path, but gives a permission error when # accessing it. Catching a wider exception in case of more esoteric errors. 
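For context on the `_thread_cache.py` hunk above: on glibc, `ctypes.util.find_library("pthread")` resolves the separate pthread shared object, while on musl that lookup returns None even though the pthread symbols live directly in libc.so. A condensed sketch of the resulting strategy (illustrative, mirroring rather than copying the patched code):

    import ctypes
    import ctypes.util

    # glibc: find_library("pthread") succeeds; musl: it returns None, but the
    # pthread functions are in libc.so, so try that and let CDLL() fail if
    # it does not exist either (e.g. on Windows).
    path = ctypes.util.find_library("pthread") or "libc.so"
    try:
        libpthread = ctypes.CDLL(path)
    except OSError:
        libpthread = None
    setname = getattr(libpthread, "pthread_setname_np", None) if libpthread else None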
@@ -118,11 +123,14 @@ def darwin_namefunc( class WorkerThread(Generic[RetT]): def __init__(self, thread_cache: ThreadCache) -> None: - self._job: tuple[ - Callable[[], RetT], - Callable[[outcome.Outcome[RetT]], object], - str | None, - ] | None = None + self._job: ( + tuple[ + Callable[[], RetT], + Callable[[outcome.Outcome[RetT]], object], + str | None, + ] + | None + ) = None self._thread_cache = thread_cache # This Lock is used in an unconventional way. # diff --git a/trio/_core/_traps.py b/src/trio/_core/_traps.py similarity index 98% rename from trio/_core/_traps.py rename to src/trio/_core/_traps.py index f08fd69090..85e6b57306 100644 --- a/trio/_core/_traps.py +++ b/src/trio/_core/_traps.py @@ -1,11 +1,12 @@ """These are the only functions that ever yield back to the task runner.""" + from __future__ import annotations import enum import types from typing import TYPE_CHECKING, Any, Callable, NoReturn -import attr +import attrs import outcome from . import _run @@ -66,9 +67,9 @@ class Abort(enum.Enum): # Not exported in the trio._core namespace, but imported directly by _run. -@attr.s(frozen=True) +@attrs.frozen(slots=False) class WaitTaskRescheduled: - abort_func: Callable[[RaiseCancelT], Abort] = attr.ib() + abort_func: Callable[[RaiseCancelT], Abort] RaiseCancelT: TypeAlias = Callable[[], NoReturn] @@ -179,9 +180,9 @@ def abort(inner_raise_cancel): # Not exported in the trio._core namespace, but imported directly by _run. -@attr.s(frozen=True) +@attrs.frozen(slots=False) class PermanentlyDetachCoroutineObject: - final_outcome: outcome.Outcome[Any] = attr.ib() + final_outcome: outcome.Outcome[Any] async def permanently_detach_coroutine_object( diff --git a/trio/_core/_unbounded_queue.py b/src/trio/_core/_unbounded_queue.py similarity index 97% rename from trio/_core/_unbounded_queue.py rename to src/trio/_core/_unbounded_queue.py index 7c5c536676..b9ebe484d7 100644 --- a/trio/_core/_unbounded_queue.py +++ b/src/trio/_core/_unbounded_queue.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar -import attr +import attrs from .. import _core from .._deprecate import deprecated @@ -14,7 +14,7 @@ from typing_extensions import Self -@attr.s(slots=True, frozen=True) +@attrs.frozen class UnboundedQueueStatistics: """An object containing debugging information. 
@@ -26,8 +26,8 @@ class UnboundedQueueStatistics: """ - qsize: int = attr.ib() - tasks_waiting: int = attr.ib() + qsize: int + tasks_waiting: int @final @@ -66,6 +66,7 @@ class UnboundedQueue(Generic[T]): issue=497, thing="trio.lowlevel.UnboundedQueue", instead="trio.open_memory_channel(math.inf)", + use_triodeprecationwarning=True, ) def __init__(self) -> None: self._lot = _core.ParkingLot() diff --git a/trio/_core/_wakeup_socketpair.py b/src/trio/_core/_wakeup_socketpair.py similarity index 94% rename from trio/_core/_wakeup_socketpair.py rename to src/trio/_core/_wakeup_socketpair.py index aff28a1bd8..fb821a23e7 100644 --- a/trio/_core/_wakeup_socketpair.py +++ b/src/trio/_core/_wakeup_socketpair.py @@ -11,6 +11,10 @@ class WakeupSocketpair: def __init__(self) -> None: + # explicitly typed to please `pyright --verifytypes` without `--ignoreexternal` + self.wakeup_sock: socket.socket + self.write_sock: socket.socket + self.wakeup_sock, self.write_sock = socket.socketpair() self.wakeup_sock.setblocking(False) self.write_sock.setblocking(False) diff --git a/trio/_core/_windows_cffi.py b/src/trio/_core/_windows_cffi.py similarity index 95% rename from trio/_core/_windows_cffi.py rename to src/trio/_core/_windows_cffi.py index 72bfa64d68..244ea773c5 100644 --- a/trio/_core/_windows_cffi.py +++ b/src/trio/_core/_windows_cffi.py @@ -241,8 +241,7 @@ def CreateIoCompletionPort( CompletionKey: int, NumberOfConcurrentThreads: int, /, - ) -> Handle: - ... + ) -> Handle: ... def CreateEventA( self, @@ -251,13 +250,11 @@ def CreateEventA( bInitialState: bool, lpName: AlwaysNull, /, - ) -> Handle: - ... + ) -> Handle: ... def SetFileCompletionNotificationModes( self, handle: Handle, flags: CompletionModes, / - ) -> int: - ... + ) -> int: ... def PostQueuedCompletionStatus( self, @@ -266,16 +263,14 @@ def PostQueuedCompletionStatus( dwCompletionKey: int, lpOverlapped: CData | AlwaysNull, /, - ) -> bool: - ... + ) -> bool: ... def CancelIoEx( self, hFile: Handle, lpOverlapped: CData | AlwaysNull, /, - ) -> bool: - ... + ) -> bool: ... def WriteFile( self, @@ -286,8 +281,7 @@ def WriteFile( lpNumberOfBytesWritten: AlwaysNull, lpOverlapped: _Overlapped, /, - ) -> bool: - ... + ) -> bool: ... def ReadFile( self, @@ -298,8 +292,7 @@ def ReadFile( lpNumberOfBytesRead: AlwaysNull, lpOverlapped: _Overlapped, /, - ) -> bool: - ... + ) -> bool: ... def GetQueuedCompletionStatusEx( self, @@ -310,8 +303,7 @@ def GetQueuedCompletionStatusEx( dwMilliseconds: int, fAlertable: bool | int, /, - ) -> CData: - ... + ) -> CData: ... def CreateFileW( self, @@ -323,11 +315,9 @@ def CreateFileW( dwFlagsAndAttributes: FileFlags, hTemplateFile: AlwaysNull, /, - ) -> Handle: - ... + ) -> Handle: ... - def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: - ... + def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: ... def WaitForMultipleObjects( self, @@ -336,14 +326,11 @@ def WaitForMultipleObjects( bWaitAll: bool, dwMilliseconds: int, /, - ) -> ErrorCodes: - ... + ) -> ErrorCodes: ... - def SetEvent(self, handle: Handle, /) -> None: - ... + def SetEvent(self, handle: Handle, /) -> None: ... - def CloseHandle(self, handle: Handle, /) -> bool: - ... + def CloseHandle(self, handle: Handle, /) -> bool: ... def DeviceIoControl( self, @@ -358,22 +345,19 @@ def DeviceIoControl( lpBytesReturned: AlwaysNull, lpOverlapped: CData, /, - ) -> bool: - ... + ) -> bool: ... 
class _Nt(Protocol): """Statically typed version of the dtdll.dll functions we use.""" - def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: - ... + def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: ... class _Ws2(Protocol): """Statically typed version of the ws2_32.dll functions we use.""" - def WSAGetLastError(self) -> int: - ... + def WSAGetLastError(self) -> int: ... def WSAIoctl( self, @@ -388,8 +372,7 @@ def WSAIoctl( # actually LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine: AlwaysNull, /, - ) -> int: - ... + ) -> int: ... class _DummyStruct(Protocol): diff --git a/trio/_deprecate.py b/src/trio/_deprecate.py similarity index 84% rename from trio/_deprecate.py rename to src/trio/_deprecate.py index 0a9553b854..51c51f7378 100644 --- a/trio/_deprecate.py +++ b/src/trio/_deprecate.py @@ -2,14 +2,15 @@ import sys import warnings -from collections.abc import Callable from functools import wraps from types import ModuleType from typing import TYPE_CHECKING, ClassVar, TypeVar -import attr +import attrs if TYPE_CHECKING: + from collections.abc import Callable + from typing_extensions import ParamSpec ArgsT = ParamSpec("ArgsT") @@ -57,6 +58,7 @@ def warn_deprecated( issue: int | None, instead: object, stacklevel: int = 2, + use_triodeprecationwarning: bool = False, ) -> None: stacklevel += 1 msg = f"{_stringify(thing)} is deprecated since Trio {version}" @@ -66,20 +68,35 @@ def warn_deprecated( msg += f"; use {_stringify(instead)} instead" if issue is not None: msg += f" ({_url_for_issue(issue)})" - warnings.warn(TrioDeprecationWarning(msg), stacklevel=stacklevel) + if use_triodeprecationwarning: + warning_class: type[Warning] = TrioDeprecationWarning + else: + warning_class = DeprecationWarning + warnings.warn(warning_class(msg), stacklevel=stacklevel) # @deprecated("0.2.0", issue=..., instead=...) # def ... 
def deprecated( - version: str, *, thing: object = None, issue: int | None, instead: object + version: str, + *, + thing: object = None, + issue: int | None, + instead: object, + use_triodeprecationwarning: bool = False, ) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]: def do_wrap(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]: nonlocal thing @wraps(fn) def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: - warn_deprecated(thing, version, instead=instead, issue=issue) + warn_deprecated( + thing, + version, + instead=instead, + issue=issue, + use_triodeprecationwarning=use_triodeprecationwarning, + ) return fn(*args, **kwargs) # If our __module__ or __qualname__ get modified, we want to pick up @@ -96,9 +113,7 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: if instead is not None: doc += f" Use {_stringify(instead)} instead.\n" if issue is not None: - doc += " For details, see `issue #{} <{}>`__.\n".format( - issue, _url_for_issue(issue) - ) + doc += f" For details, see `issue #{issue} <{_url_for_issue(issue)}>`__.\n" doc += "\n" wrapper.__doc__ = doc @@ -125,14 +140,14 @@ def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: return wrapper -@attr.s(frozen=True) +@attrs.frozen(slots=False) class DeprecatedAttribute: _not_set: ClassVar[object] = object() - value: object = attr.ib() - version: str = attr.ib() - issue: int | None = attr.ib() - instead: object = attr.ib(default=_not_set) + value: object + version: str + issue: int | None + instead: object = _not_set class _ModuleWithDeprecations(ModuleType): diff --git a/trio/_dtls.py b/src/trio/_dtls.py similarity index 97% rename from trio/_dtls.py rename to src/trio/_dtls.py index beb7058b46..31f7817e1c 100644 --- a/trio/_dtls.py +++ b/src/trio/_dtls.py @@ -30,7 +30,7 @@ ) from weakref import ReferenceType, WeakValueDictionary -import attr +import attrs import trio @@ -40,12 +40,13 @@ from types import TracebackType # See DTLSEndpoint.__init__ for why this is imported here - from OpenSSL import SSL - from OpenSSL.SSL import Context - from typing_extensions import Self, TypeAlias + from OpenSSL import SSL # noqa: TCH004 + from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack from trio.socket import SocketType + PosArgsT = TypeVarTuple("PosArgsT") + MAX_UDP_PACKET_SIZE = 65527 @@ -159,12 +160,12 @@ def to_hex(data: bytes) -> str: # pragma: no cover return data.hex() -@attr.frozen +@attrs.frozen class Record: content_type: int - version: bytes = attr.ib(repr=to_hex) + version: bytes = attrs.field(repr=to_hex) epoch_seqno: int - payload: bytes = attr.ib(repr=to_hex) + payload: bytes = attrs.field(repr=to_hex) def records_untrusted(packet: bytes) -> Iterator[Record]: @@ -204,14 +205,14 @@ def encode_record(record: Record) -> bytes: HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s") -@attr.frozen +@attrs.frozen class HandshakeFragment: msg_type: int msg_len: int msg_seq: int frag_offset: int frag_len: int - frag: bytes = attr.ib(repr=to_hex) + frag: bytes = attrs.field(repr=to_hex) def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment: @@ -324,21 +325,21 @@ def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]: raise BadPacket("bad ClientHello") from exc -@attr.frozen +@attrs.frozen class HandshakeMessage: - record_version: bytes = attr.ib(repr=to_hex) + record_version: bytes = attrs.field(repr=to_hex) msg_type: HandshakeType msg_seq: int - body: bytearray = attr.ib(repr=to_hex) + body: bytearray = attrs.field(repr=to_hex) # 
ChangeCipherSpec is part of the handshake, but it's not a "handshake # message" and can't be fragmented the same way. Sigh. -@attr.frozen +@attrs.frozen class PseudoHandshakeMessage: - record_version: bytes = attr.ib(repr=to_hex) + record_version: bytes = attrs.field(repr=to_hex) content_type: int - payload: bytes = attr.ib(repr=to_hex) + payload: bytes = attrs.field(repr=to_hex) # The final record in a handshake is Finished, which is encrypted, can't be fragmented @@ -346,7 +347,7 @@ class PseudoHandshakeMessage: # just pass it through unchanged. (Fortunately, the payload is only a single hash value, # so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough # that it never requires fragmenting to fit into a UDP packet. -@attr.frozen +@attrs.frozen class OpaqueHandshakeMessage: record: Record @@ -667,7 +668,12 @@ def _read_loop(read_fn: Callable[[int], bytes]) -> bytes: async def handle_client_hello_untrusted( endpoint: DTLSEndpoint, address: Any, packet: bytes ) -> None: - if endpoint._listening_context is None: + # it's trivial to write a simple function that directly calls this to + # get code coverage, but it should maybe: + # 1. be removed + # 2. be asserted + # 3. Write a complicated test case where this happens "organically" + if endpoint._listening_context is None: # pragma: no cover return try: @@ -703,7 +709,7 @@ async def handle_client_hello_untrusted( try: stream._ssl.bio_write(packet) stream._ssl.DTLSv1_listen() - except SSL.Error: + except SSL.Error: # pragma: no cover # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello # after all. return @@ -797,7 +803,7 @@ async def dtls_receive_loop( raise -@attr.frozen +@attrs.frozen class DTLSChannelStatistics: """Currently this has only one attribute: @@ -830,7 +836,12 @@ class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor): """ - def __init__(self, endpoint: DTLSEndpoint, peer_address: Any, ctx: Context): + def __init__( + self, + endpoint: DTLSEndpoint, + peer_address: Any, + ctx: SSL.Context, + ) -> None: self.endpoint = endpoint self.peer_address = peer_address self._packets_dropped_in_trio = 0 @@ -1176,7 +1187,12 @@ class DTLSEndpoint: """ - def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): + def __init__( + self, + socket: SocketType, + *, + incoming_packets_buffer: int = 10, + ) -> None: # We do this lazily on first construction, so only people who actually use DTLS # have to install PyOpenSSL. global SSL @@ -1197,7 +1213,7 @@ def __init__(self, socket: SocketType, *, incoming_packets_buffer: int = 10): # old connection. # {remote address: DTLSChannel} self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary() - self._listening_context: Context | None = None + self._listening_context: SSL.Context | None = None self._listening_key: bytes | None = None self._incoming_connections_q = _Queue[DTLSChannel](float("inf")) self._send_lock = trio.Lock() @@ -1258,14 +1274,11 @@ def _check_closed(self) -> None: if self._closed: raise trio.ClosedResourceError - # async_fn cannot be typed with ParamSpec, since we don't accept - # kwargs. Can be typed with TypeVarTuple once it's fully supported - # in mypy. 
async def serve( self, - ssl_context: Context, - async_fn: Callable[..., Awaitable[object]], - *args: Any, + ssl_context: SSL.Context, + async_fn: Callable[[DTLSChannel, Unpack[PosArgsT]], Awaitable[object]], + *args: Unpack[PosArgsT], task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED, ) -> None: """Listen for incoming connections, and spawn a handler for each using an @@ -1294,6 +1307,7 @@ async def handler(dtls_channel): incoming connections. async_fn: The handler function that will be invoked for each incoming connection. + *args: Additional arguments to pass to the handler function. """ self._check_closed() @@ -1324,7 +1338,11 @@ async def handler_wrapper(stream: DTLSChannel) -> None: finally: self._listening_context = None - def connect(self, address: tuple[str, int], ssl_context: Context) -> DTLSChannel: + def connect( + self, + address: tuple[str, int], + ssl_context: SSL.Context, + ) -> DTLSChannel: """Initiate an outgoing DTLS connection. Notice that this is a synchronous method. That's because it doesn't actually diff --git a/trio/_file_io.py b/src/trio/_file_io.py similarity index 97% rename from trio/_file_io.py rename to src/trio/_file_io.py index 958740e15e..ef867243f0 100644 --- a/trio/_file_io.py +++ b/src/trio/_file_io.py @@ -184,7 +184,7 @@ class _CanReadLine(Protocol[AnyStr_co]): def readline(self, size: int = ..., /) -> AnyStr_co: ... class _CanReadLines(Protocol[AnyStr]): - def readlines(self, hint: int = ...) -> list[AnyStr]: ... + def readlines(self, hint: int = ..., /) -> list[AnyStr]: ... class _CanSeek(Protocol): def seek(self, target: int, whence: int = 0, /) -> int: ... @@ -355,8 +355,7 @@ async def open_file( newline: str | None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[io.TextIOWrapper]: - ... +) -> AsyncIOWrapper[io.TextIOWrapper]: ... @overload @@ -369,8 +368,7 @@ async def open_file( newline: None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[io.FileIO]: - ... +) -> AsyncIOWrapper[io.FileIO]: ... @overload @@ -383,8 +381,7 @@ async def open_file( newline: None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[io.BufferedRandom]: - ... +) -> AsyncIOWrapper[io.BufferedRandom]: ... @overload @@ -397,8 +394,7 @@ async def open_file( newline: None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[io.BufferedWriter]: - ... +) -> AsyncIOWrapper[io.BufferedWriter]: ... @overload @@ -411,8 +407,7 @@ async def open_file( newline: None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[io.BufferedReader]: - ... +) -> AsyncIOWrapper[io.BufferedReader]: ... @overload @@ -425,8 +420,7 @@ async def open_file( newline: None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[BinaryIO]: - ... +) -> AsyncIOWrapper[BinaryIO]: ... @overload @@ -439,8 +433,7 @@ async def open_file( # type: ignore[misc] # Any usage matches builtins.open(). newline: str | None = None, closefd: bool = True, opener: _Opener | None = None, -) -> AsyncIOWrapper[IO[Any]]: - ... +) -> AsyncIOWrapper[IO[Any]]: ... 
async def open_file( diff --git a/trio/_highlevel_generic.py b/src/trio/_highlevel_generic.py similarity index 97% rename from trio/_highlevel_generic.py rename to src/trio/_highlevel_generic.py index 9c7878f2b9..88a86318a3 100644 --- a/trio/_highlevel_generic.py +++ b/src/trio/_highlevel_generic.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Generic, TypeVar -import attr +import attrs import trio from trio._util import final @@ -53,7 +53,7 @@ def _is_halfclosable(stream: SendStream) -> TypeGuard[HalfCloseableStream]: @final -@attr.s(eq=False, hash=False) +@attrs.define(eq=False, hash=False, slots=False) class StapledStream( HalfCloseableStream, Generic[SendStreamT, ReceiveStreamT], @@ -92,8 +92,8 @@ class StapledStream( """ - send_stream: SendStreamT = attr.ib() - receive_stream: ReceiveStreamT = attr.ib() + send_stream: SendStreamT + receive_stream: ReceiveStreamT async def send_all(self, data: bytes | bytearray | memoryview) -> None: """Calls ``self.send_stream.send_all``.""" diff --git a/trio/_highlevel_open_tcp_listeners.py b/src/trio/_highlevel_open_tcp_listeners.py similarity index 95% rename from trio/_highlevel_open_tcp_listeners.py rename to src/trio/_highlevel_open_tcp_listeners.py index 33cb7e1a64..80555be33e 100644 --- a/trio/_highlevel_open_tcp_listeners.py +++ b/src/trio/_highlevel_open_tcp_listeners.py @@ -1,15 +1,16 @@ from __future__ import annotations import errno -import math import sys -from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING import trio from trio import TaskStatus from . import socket as tsocket -from ._deprecate import warn_deprecated + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup @@ -46,14 +47,6 @@ def _compute_backlog(backlog: int | None) -> int: # Many systems (Linux, BSDs, ...) store the backlog in a uint16 and are # missing overflow protection, so we apply our own overflow protection. 
# https://github.com/golang/go/issues/5030 - if backlog == math.inf: - backlog = None - warn_deprecated( - thing="math.inf as a backlog", - version="0.23.0", - instead="None", - issue=2842, - ) if not isinstance(backlog, int) and backlog is not None: raise TypeError(f"backlog must be an int or None, not {backlog!r}") if backlog is None: @@ -125,9 +118,9 @@ async def open_tcp_listeners( listeners = [] unsupported_address_families = [] try: - for family, type, proto, _, sockaddr in addresses: + for family, type_, proto, _, sockaddr in addresses: try: - sock = tsocket.socket(family, type, proto) + sock = tsocket.socket(family, type_, proto) except OSError as ex: if ex.errno == errno.EAFNOSUPPORT: # If a system only supports IPv4, or only IPv6, it diff --git a/trio/_highlevel_open_tcp_stream.py b/src/trio/_highlevel_open_tcp_stream.py similarity index 98% rename from trio/_highlevel_open_tcp_stream.py rename to src/trio/_highlevel_open_tcp_stream.py index 121724446a..d5c83da7c0 100644 --- a/trio/_highlevel_open_tcp_stream.py +++ b/src/trio/_highlevel_open_tcp_stream.py @@ -1,17 +1,18 @@ from __future__ import annotations import sys -from collections.abc import Generator from contextlib import contextmanager, suppress -from socket import AddressFamily, SocketKind from typing import TYPE_CHECKING, Any import trio -from trio._core._multierror import MultiError from trio.socket import SOCK_STREAM, SocketType, getaddrinfo, socket +if TYPE_CHECKING: + from collections.abc import Generator + from socket import AddressFamily, SocketKind + if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup # Implementation of RFC 6555 "Happy eyeballs" @@ -128,7 +129,7 @@ def close_all() -> Generator[set[SocketType], None, None]: if len(errs) == 1: raise errs[0] elif errs: - raise MultiError(errs) + raise BaseExceptionGroup("", errs) def reorder_for_rfc_6555_section_5_4( diff --git a/trio/_highlevel_open_unix_stream.py b/src/trio/_highlevel_open_unix_stream.py similarity index 91% rename from trio/_highlevel_open_unix_stream.py rename to src/trio/_highlevel_open_unix_stream.py index c05b8f3fc8..d419574369 100644 --- a/trio/_highlevel_open_unix_stream.py +++ b/src/trio/_highlevel_open_unix_stream.py @@ -1,17 +1,18 @@ from __future__ import annotations import os -from collections.abc import Generator from contextlib import contextmanager -from typing import Protocol, TypeVar +from typing import TYPE_CHECKING, Protocol, TypeVar import trio from trio.socket import SOCK_STREAM, socket +if TYPE_CHECKING: + from collections.abc import Generator + class Closable(Protocol): - def close(self) -> None: - ... + def close(self) -> None: ... 
CloseT = TypeVar("CloseT", bound=Closable) diff --git a/trio/_highlevel_serve_listeners.py b/src/trio/_highlevel_serve_listeners.py similarity index 100% rename from trio/_highlevel_serve_listeners.py rename to src/trio/_highlevel_serve_listeners.py diff --git a/trio/_highlevel_socket.py b/src/trio/_highlevel_socket.py similarity index 98% rename from trio/_highlevel_socket.py rename to src/trio/_highlevel_socket.py index fe20ff527d..901e22f345 100644 --- a/trio/_highlevel_socket.py +++ b/src/trio/_highlevel_socket.py @@ -2,7 +2,6 @@ from __future__ import annotations import errno -from collections.abc import Generator from contextlib import contextmanager, suppress from typing import TYPE_CHECKING, overload @@ -13,6 +12,8 @@ from .abc import HalfCloseableStream, Listener if TYPE_CHECKING: + from collections.abc import Generator + from typing_extensions import Buffer from ._socket import SocketType @@ -146,12 +147,10 @@ async def aclose(self) -> None: # __aenter__, __aexit__ inherited from HalfCloseableStream are OK @overload - def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: - ... + def setsockopt(self, level: int, option: int, value: int | Buffer) -> None: ... @overload - def setsockopt(self, level: int, option: int, value: None, length: int) -> None: - ... + def setsockopt(self, level: int, option: int, value: None, length: int) -> None: ... def setsockopt( self, @@ -178,12 +177,10 @@ def setsockopt( return self.socket.setsockopt(level, option, value, length) @overload - def getsockopt(self, level: int, option: int) -> int: - ... + def getsockopt(self, level: int, option: int) -> int: ... @overload - def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: - ... + def getsockopt(self, level: int, option: int, buffersize: int) -> bytes: ... def getsockopt(self, level: int, option: int, buffersize: int = 0) -> int | bytes: """Check the current value of an option on the underlying socket. 
diff --git a/trio/_highlevel_ssl_helpers.py b/src/trio/_highlevel_ssl_helpers.py similarity index 97% rename from trio/_highlevel_ssl_helpers.py rename to src/trio/_highlevel_ssl_helpers.py index 3187b3ca00..03562c9edb 100644 --- a/trio/_highlevel_ssl_helpers.py +++ b/src/trio/_highlevel_ssl_helpers.py @@ -1,16 +1,19 @@ from __future__ import annotations import ssl -from collections.abc import Awaitable, Callable -from typing import NoReturn, TypeVar +from typing import TYPE_CHECKING, NoReturn, TypeVar import trio from ._highlevel_open_tcp_stream import DEFAULT_DELAY -from ._highlevel_socket import SocketStream T = TypeVar("T") +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from ._highlevel_socket import SocketStream + # It might have been nice to take a ssl_protocols= argument here to set up # NPN/ALPN, but to do this we have to mutate the context object, which is OK diff --git a/src/trio/_path.py b/src/trio/_path.py new file mode 100644 index 0000000000..b9b5749c35 --- /dev/null +++ b/src/trio/_path.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import os +import pathlib +import sys +from functools import partial, update_wrapper +from inspect import cleandoc +from typing import IO, TYPE_CHECKING, Any, BinaryIO, ClassVar, TypeVar, overload + +from trio._file_io import AsyncIOWrapper, wrap_file +from trio._util import final +from trio.to_thread import run_sync + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Iterable + from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper + + from _typeshed import ( + OpenBinaryMode, + OpenBinaryModeReading, + OpenBinaryModeUpdating, + OpenBinaryModeWriting, + OpenTextMode, + ) + from typing_extensions import Concatenate, Literal, ParamSpec, Self + + P = ParamSpec("P") + + PathT = TypeVar("PathT", bound="Path") + T = TypeVar("T") + + +def _wraps_async( + wrapped: Callable[..., Any] +) -> Callable[[Callable[P, T]], Callable[P, Awaitable[T]]]: + def decorator(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]: + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + return await run_sync(partial(fn, *args, **kwargs)) + + update_wrapper(wrapper, wrapped) + if wrapped.__doc__: + wrapper.__doc__ = ( + f"Like :meth:`~{wrapped.__module__}.{wrapped.__qualname__}`, but async.\n" + f"\n" + f"{cleandoc(wrapped.__doc__)}\n" + ) + return wrapper + + return decorator + + +def _wrap_method( + fn: Callable[Concatenate[pathlib.Path, P], T], +) -> Callable[Concatenate[Path, P], Awaitable[T]]: + @_wraps_async(fn) + def wrapper(self: Path, /, *args: P.args, **kwargs: P.kwargs) -> T: + return fn(self._wrapped_cls(self), *args, **kwargs) + + return wrapper + + +def _wrap_method_path( + fn: Callable[Concatenate[pathlib.Path, P], pathlib.Path], +) -> Callable[Concatenate[PathT, P], Awaitable[PathT]]: + @_wraps_async(fn) + def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> PathT: + return self.__class__(fn(self._wrapped_cls(self), *args, **kwargs)) + + return wrapper + + +def _wrap_method_path_iterable( + fn: Callable[Concatenate[pathlib.Path, P], Iterable[pathlib.Path]], +) -> Callable[Concatenate[PathT, P], Awaitable[Iterable[PathT]]]: + @_wraps_async(fn) + def wrapper(self: PathT, /, *args: P.args, **kwargs: P.kwargs) -> Iterable[PathT]: + return map(self.__class__, [*fn(self._wrapped_cls(self), *args, **kwargs)]) + + if wrapper.__doc__: + wrapper.__doc__ += ( + f"\n" + f"This is an async method that returns a synchronous iterator, so you\n" + f"use it like:\n" + 
f"\n" + f".. code:: python\n" + f"\n" + f" for subpath in await mypath.{fn.__name__}():\n" + f" ...\n" + f"\n" + f".. note::\n" + f"\n" + f" The iterator is loaded into memory immediately during the initial\n" + f" call (see `issue #501\n" + f" `__ for discussion).\n" + ) + return wrapper + + +class Path(pathlib.PurePath): + """An async :class:`pathlib.Path` that executes blocking methods in :meth:`trio.to_thread.run_sync`. + + Instantiating :class:`Path` returns a concrete platform-specific subclass, one of :class:`PosixPath` or + :class:`WindowsPath`. + """ + + __slots__ = () + + _wrapped_cls: ClassVar[type[pathlib.Path]] + + def __new__(cls, *args: str | os.PathLike[str]) -> Self: + if cls is Path: + cls = WindowsPath if os.name == "nt" else PosixPath # type: ignore[assignment] + return super().__new__(cls, *args) + + @classmethod + @_wraps_async(pathlib.Path.cwd) + def cwd(cls) -> Self: + return cls(pathlib.Path.cwd()) + + @classmethod + @_wraps_async(pathlib.Path.home) + def home(cls) -> Self: + return cls(pathlib.Path.home()) + + @overload + async def open( + self, + mode: OpenTextMode = "r", + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> AsyncIOWrapper[TextIOWrapper]: ... + + @overload + async def open( + self, + mode: OpenBinaryMode, + buffering: Literal[0], + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[FileIO]: ... + + @overload + async def open( + self, + mode: OpenBinaryModeUpdating, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BufferedRandom]: ... + + @overload + async def open( + self, + mode: OpenBinaryModeWriting, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BufferedWriter]: ... + + @overload + async def open( + self, + mode: OpenBinaryModeReading, + buffering: Literal[-1, 1] = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BufferedReader]: ... + + @overload + async def open( + self, + mode: OpenBinaryMode, + buffering: int = -1, + encoding: None = None, + errors: None = None, + newline: None = None, + ) -> AsyncIOWrapper[BinaryIO]: ... + + @overload + async def open( # type: ignore[misc] # Any usage matches builtins.open(). + self, + mode: str, + buffering: int = -1, + encoding: str | None = None, + errors: str | None = None, + newline: str | None = None, + ) -> AsyncIOWrapper[IO[Any]]: ... + + @_wraps_async(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch. 
+ def open(self, *args: Any, **kwargs: Any) -> AsyncIOWrapper[IO[Any]]: + return wrap_file(self._wrapped_cls(self).open(*args, **kwargs)) + + def __repr__(self) -> str: + return f"trio.Path({str(self)!r})" + + stat = _wrap_method(pathlib.Path.stat) + chmod = _wrap_method(pathlib.Path.chmod) + exists = _wrap_method(pathlib.Path.exists) + glob = _wrap_method_path_iterable(pathlib.Path.glob) + rglob = _wrap_method_path_iterable(pathlib.Path.rglob) + is_dir = _wrap_method(pathlib.Path.is_dir) + is_file = _wrap_method(pathlib.Path.is_file) + is_symlink = _wrap_method(pathlib.Path.is_symlink) + is_socket = _wrap_method(pathlib.Path.is_socket) + is_fifo = _wrap_method(pathlib.Path.is_fifo) + is_block_device = _wrap_method(pathlib.Path.is_block_device) + is_char_device = _wrap_method(pathlib.Path.is_char_device) + if sys.version_info >= (3, 12): + is_junction = _wrap_method(pathlib.Path.is_junction) + iterdir = _wrap_method_path_iterable(pathlib.Path.iterdir) + lchmod = _wrap_method(pathlib.Path.lchmod) + lstat = _wrap_method(pathlib.Path.lstat) + mkdir = _wrap_method(pathlib.Path.mkdir) + if sys.platform != "win32": + owner = _wrap_method(pathlib.Path.owner) + group = _wrap_method(pathlib.Path.group) + if sys.platform != "win32" or sys.version_info >= (3, 12): + is_mount = _wrap_method(pathlib.Path.is_mount) + if sys.version_info >= (3, 9): + readlink = _wrap_method_path(pathlib.Path.readlink) + rename = _wrap_method_path(pathlib.Path.rename) + replace = _wrap_method_path(pathlib.Path.replace) + resolve = _wrap_method_path(pathlib.Path.resolve) + rmdir = _wrap_method(pathlib.Path.rmdir) + symlink_to = _wrap_method(pathlib.Path.symlink_to) + if sys.version_info >= (3, 10): + hardlink_to = _wrap_method(pathlib.Path.hardlink_to) + touch = _wrap_method(pathlib.Path.touch) + unlink = _wrap_method(pathlib.Path.unlink) + absolute = _wrap_method_path(pathlib.Path.absolute) + expanduser = _wrap_method_path(pathlib.Path.expanduser) + read_bytes = _wrap_method(pathlib.Path.read_bytes) + read_text = _wrap_method(pathlib.Path.read_text) + samefile = _wrap_method(pathlib.Path.samefile) + write_bytes = _wrap_method(pathlib.Path.write_bytes) + write_text = _wrap_method(pathlib.Path.write_text) + if sys.version_info < (3, 12): + link_to = _wrap_method(pathlib.Path.link_to) + if sys.version_info >= (3, 13): + full_match = _wrap_method(pathlib.Path.full_match) + + +@final +class PosixPath(Path, pathlib.PurePosixPath): + """An async :class:`pathlib.PosixPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`.""" + + __slots__ = () + + _wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.PosixPath + + +@final +class WindowsPath(Path, pathlib.PureWindowsPath): + """An async :class:`pathlib.WindowsPath` that executes blocking methods in :meth:`trio.to_thread.run_sync`.""" + + __slots__ = () + + _wrapped_cls: ClassVar[type[pathlib.Path]] = pathlib.WindowsPath diff --git a/src/trio/_repl.py b/src/trio/_repl.py new file mode 100644 index 0000000000..73f050140e --- /dev/null +++ b/src/trio/_repl.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +import ast +import contextlib +import inspect +import sys +import types +import warnings +from code import InteractiveConsole + +import outcome + +import trio +import trio.lowlevel +from trio._util import final + + +@final +class TrioInteractiveConsole(InteractiveConsole): + # code.InteractiveInterpreter defines locals as Mapping[str, Any] + # but when we pass this to FunctionType it expects a dict. 
So + # we make the type more specific on our subclass + locals: dict[str, object] + + def __init__(self, repl_locals: dict[str, object] | None = None): + super().__init__(locals=repl_locals) + self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT + + def runcode(self, code: types.CodeType) -> None: + func = types.FunctionType(code, self.locals) + if inspect.iscoroutinefunction(func): + result = trio.from_thread.run(outcome.acapture, func) + else: + result = trio.from_thread.run_sync(outcome.capture, func) + if isinstance(result, outcome.Error): + # If it is SystemExit, quit the repl. Otherwise, print the traceback. + # If there is a SystemExit inside a BaseExceptionGroup, it probably isn't + # the user trying to quit the repl, but rather an error in the code. So, we + # don't try to inspect groups for SystemExit. Instead, we just print and + # return to the REPL. + if isinstance(result.error, SystemExit): + raise result.error + else: + # Inline our own version of self.showtraceback that can use + # outcome.Error.error directly to print clean tracebacks. + # This also means overriding self.showtraceback does nothing. + sys.last_type, sys.last_value = type(result.error), result.error + sys.last_traceback = result.error.__traceback__ + # see https://docs.python.org/3/library/sys.html#sys.last_exc + if sys.version_info >= (3, 12): + sys.last_exc = result.error + + # We always use sys.excepthook, unlike other implementations. + # This means that overriding self.write also does nothing to tbs. + sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback) + + +async def run_repl(console: TrioInteractiveConsole) -> None: + banner = ( + f"trio REPL {sys.version} on {sys.platform}\n" + f'Use "await" directly instead of "trio.run()".\n' + f'Type "help", "copyright", "credits" or "license" ' + f"for more information.\n" + f'{getattr(sys, "ps1", ">>> ")}import trio' + ) + try: + await trio.to_thread.run_sync(console.interact, banner) + finally: + warnings.filterwarnings( + "ignore", + message=r"^coroutine .* was never awaited$", + category=RuntimeWarning, + ) + + +def main(original_locals: dict[str, object]) -> None: + with contextlib.suppress(ImportError): + import readline # noqa: F401 + + repl_locals: dict[str, object] = {"trio": trio} + for key in { + "__name__", + "__package__", + "__loader__", + "__spec__", + "__builtins__", + "__file__", + }: + repl_locals[key] = original_locals[key] + + console = TrioInteractiveConsole(repl_locals) + trio.run(run_repl, console) diff --git a/trio/_signals.py b/src/trio/_signals.py similarity index 98% rename from trio/_signals.py rename to src/trio/_signals.py index 283c3a44a8..f4d912808f 100644 --- a/trio/_signals.py +++ b/src/trio/_signals.py @@ -2,9 +2,7 @@ import signal from collections import OrderedDict -from collections.abc import AsyncIterator, Callable, Generator, Iterable from contextlib import contextmanager -from types import FrameType from typing import TYPE_CHECKING import trio @@ -12,6 +10,9 @@ from ._util import ConflictDetector, is_main_thread, signal_raise if TYPE_CHECKING: + from collections.abc import AsyncIterator, Callable, Generator, Iterable + from types import FrameType + from typing_extensions import Self # Discussion of signal handling strategies: diff --git a/trio/_socket.py b/src/trio/_socket.py similarity index 93% rename from trio/_socket.py rename to src/trio/_socket.py index efef2ff5ed..0a3bd1cba1 100644 --- a/trio/_socket.py +++ b/src/trio/_socket.py @@ -240,7 +240,7 @@ def numeric_only_failure(exc: BaseException) 
-> bool: type, proto, flags, - cancellable=True, + abandon_on_cancel=True, ) @@ -261,7 +261,7 @@ async def getnameinfo( return await hr.getnameinfo(sockaddr, flags) else: return await trio.to_thread.run_sync( - _stdlib_socket.getnameinfo, sockaddr, flags, cancellable=True + _stdlib_socket.getnameinfo, sockaddr, flags, abandon_on_cancel=True ) @@ -272,7 +272,7 @@ async def getprotobyname(name: str) -> int: """ return await trio.to_thread.run_sync( - _stdlib_socket.getprotobyname, name, cancellable=True + _stdlib_socket.getprotobyname, name, abandon_on_cancel=True ) @@ -300,8 +300,8 @@ def fromfd( proto: int = 0, ) -> SocketType: """Like :func:`socket.fromfd`, but returns a Trio socket object.""" - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) - return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type, proto)) + family, type_, proto = _sniff_sockopts_for_fileno(family, type, proto, index(fd)) + return from_stdlib_socket(_stdlib_socket.fromfd(fd, family, type_, proto)) if sys.platform == "win32" or ( @@ -310,6 +310,7 @@ def fromfd( @_wraps(_stdlib_socket.fromshare, assigned=(), updated=()) def fromshare(info: bytes) -> SocketType: + """Like :func:`socket.fromshare`, but returns a Trio socket object.""" return from_stdlib_socket(_stdlib_socket.fromshare(info)) @@ -355,14 +356,16 @@ def socket( if sf is not None: return sf.socket(family, type, proto) else: - family, type, proto = _sniff_sockopts_for_fileno(family, type, proto, fileno) + family, type, proto = _sniff_sockopts_for_fileno( # noqa: A001 + family, type, proto, fileno + ) stdlib_socket = _stdlib_socket.socket(family, type, proto, fileno) return from_stdlib_socket(stdlib_socket) def _sniff_sockopts_for_fileno( family: AddressFamily | int, - type: SocketKind | int, + type_: SocketKind | int, proto: int, fileno: int | None, ) -> tuple[AddressFamily | int, SocketKind | int, int]: @@ -371,23 +374,23 @@ def _sniff_sockopts_for_fileno( # This object might have the wrong metadata, but it lets us easily call getsockopt # and then we'll throw it away and construct a new one with the correct metadata. if sys.platform != "linux": - return family, type, proto - from socket import ( # type: ignore[attr-defined] + return family, type_, proto + from socket import ( # type: ignore[attr-defined,unused-ignore] SO_DOMAIN, SO_PROTOCOL, SO_TYPE, SOL_SOCKET, ) - sockobj = _stdlib_socket.socket(family, type, proto, fileno=fileno) + sockobj = _stdlib_socket.socket(family, type_, proto, fileno=fileno) try: family = sockobj.getsockopt(SOL_SOCKET, SO_DOMAIN) proto = sockobj.getsockopt(SOL_SOCKET, SO_PROTOCOL) - type = sockobj.getsockopt(SOL_SOCKET, SO_TYPE) + type_ = sockobj.getsockopt(SOL_SOCKET, SO_TYPE) finally: # Unwrap it again, so that sockobj.__del__ doesn't try to close our socket sockobj.detach() - return family, type, proto + return family, type_, proto ################################################################ @@ -447,7 +450,7 @@ async def wrapper(self: _SocketType, *args: P.args, **kwargs: P.kwargs) -> T: # @overload likely works, but is extremely verbose. 
# NOTE: this function does not always checkpoint async def _resolve_address_nocp( - type: int, + type_: int, family: AddressFamily, proto: int, *, @@ -501,7 +504,7 @@ async def _resolve_address_nocp( # flags |= AI_ADDRCONFIG if family == _stdlib_socket.AF_INET6 and not ipv6_v6only: flags |= _stdlib_socket.AI_V4MAPPED - gai_res = await getaddrinfo(host, port, family, type, proto, flags) + gai_res = await getaddrinfo(host, port, family, type_, proto, flags) # AFAICT from the spec it's not possible for getaddrinfo to return an # empty list. assert len(gai_res) >= 1 @@ -525,7 +528,7 @@ def __init__(self) -> None: # make sure this __init__ works with multiple inheritance super().__init__() # and only raises error if it's directly constructed - if type(self) == SocketType: + if type(self) is SocketType: raise TypeError( "SocketType is an abstract class; use trio.socket.socket if you " "want to construct a socket object" @@ -544,12 +547,10 @@ def getsockname(self) -> AddressFormat: raise NotImplementedError @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... + def getsockopt(self, /, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: - ... + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( self, /, level: int, optname: int, buflen: int | None = None @@ -557,12 +558,12 @@ def getsockopt( raise NotImplementedError @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... @overload - def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: - ... + def setsockopt( + self, /, level: int, optname: int, value: None, optlen: int + ) -> None: ... def setsockopt( self, @@ -615,6 +616,7 @@ def proto(self) -> int: @property def did_shutdown_SHUT_WR(self) -> bool: + """Return True if the socket has been shut down with the SHUT_WR flag""" raise NotImplementedError def __repr__(self) -> str: @@ -633,9 +635,11 @@ def shutdown(self, flag: int) -> None: raise NotImplementedError def is_readable(self) -> bool: + """Return True if the socket is readable. This is checked with `select.select` on Windows, otherwise `select.poll`.""" raise NotImplementedError async def wait_writable(self) -> None: + """Convenience method that calls trio.lowlevel.wait_writable for the object.""" raise NotImplementedError async def accept(self) -> tuple[SocketType, AddressFormat]: @@ -695,8 +699,7 @@ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: @overload async def sendto( self, __data: Buffer, __address: tuple[object, ...] | str | Buffer - ) -> int: - ... + ) -> int: ... @overload async def sendto( @@ -704,8 +707,7 @@ async def sendto( __data: Buffer, __flags: int, __address: tuple[object, ...] | str | Buffer, - ) -> int: - ... + ) -> int: ... 
async def sendto(self, *args: Any) -> int: raise NotImplementedError @@ -725,6 +727,21 @@ async def sendmsg( raise NotImplementedError +# copy docstrings from socket.SocketType / socket.socket +for name, obj in SocketType.__dict__.items(): + # skip dunders and already defined docstrings + if name.startswith("__") or obj.__doc__: + continue + # try both socket.socket and socket.SocketType + for stdlib_type in _stdlib_socket.socket, _stdlib_socket.SocketType: + stdlib_obj = getattr(stdlib_type, name, None) + if stdlib_obj and stdlib_obj.__doc__: + break + else: + continue + obj.__doc__ = stdlib_obj.__doc__ + + class _SocketType(SocketType): def __init__(self, sock: _stdlib_socket.socket): if type(sock) is not _stdlib_socket.socket: @@ -755,12 +772,10 @@ def getsockname(self) -> AddressFormat: return self._sock.getsockname() @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... + def getsockopt(self, /, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: - ... + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( self, /, level: int, optname: int, buflen: int | None = None @@ -770,12 +785,12 @@ def getsockopt( return self._sock.getsockopt(level, optname, buflen) @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... @overload - def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: - ... + def setsockopt( + self, /, level: int, optname: int, value: None, optlen: int + ) -> None: ... def setsockopt( self, @@ -1054,13 +1069,12 @@ async def connect(self, address: AddressFormat) -> None: # complain about AmbiguousType if TYPE_CHECKING: - def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: - ... + def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: ... # _make_simple_sock_method_wrapper is typed, so this checks that the above is correct # this requires that we refrain from using `/` to specify pos-only # args, or mypy thinks the signature differs from typeshed. - recv = _make_simple_sock_method_wrapper( # noqa: F811 + recv = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recv, _core.wait_readable ) @@ -1072,10 +1086,9 @@ def recv(__self, __buflen: int, __flags: int = 0) -> Awaitable[bytes]: def recv_into( __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[int]: - ... + ) -> Awaitable[int]: ... - recv_into = _make_simple_sock_method_wrapper( # noqa: F811 + recv_into = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recv_into, _core.wait_readable ) @@ -1087,10 +1100,9 @@ def recv_into( # return type of socket.socket.recvfrom in typeshed is tuple[bytes, Any] def recvfrom( __self, __bufsize: int, __flags: int = 0 - ) -> Awaitable[tuple[bytes, AddressFormat]]: - ... + ) -> Awaitable[tuple[bytes, AddressFormat]]: ... - recvfrom = _make_simple_sock_method_wrapper( # noqa: F811 + recvfrom = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvfrom, _core.wait_readable ) @@ -1102,10 +1114,9 @@ def recvfrom( # return type of socket.socket.recvfrom_into in typeshed is tuple[bytes, Any] def recvfrom_into( __self, buffer: Buffer, nbytes: int = 0, flags: int = 0 - ) -> Awaitable[tuple[int, AddressFormat]]: - ... + ) -> Awaitable[tuple[int, AddressFormat]]: ... 
- recvfrom_into = _make_simple_sock_method_wrapper( # noqa: F811 + recvfrom_into = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvfrom_into, _core.wait_readable ) @@ -1120,8 +1131,7 @@ def recvfrom_into( def recvmsg( __self, __bufsize: int, __ancbufsize: int = 0, __flags: int = 0 - ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: - ... + ) -> Awaitable[tuple[bytes, list[tuple[int, int, bytes]], int, Any]]: ... recvmsg = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvmsg, _core.wait_readable, maybe_avail=True @@ -1141,8 +1151,7 @@ def recvmsg_into( __buffers: Iterable[Buffer], __ancbufsize: int = 0, __flags: int = 0, - ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: - ... + ) -> Awaitable[tuple[int, list[tuple[int, int, bytes]], int, Any]]: ... recvmsg_into = _make_simple_sock_method_wrapper( _stdlib_socket.socket.recvmsg_into, _core.wait_readable, maybe_avail=True @@ -1154,10 +1163,9 @@ def recvmsg_into( if TYPE_CHECKING: - def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: - ... + def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: ... - send = _make_simple_sock_method_wrapper( # noqa: F811 + send = _make_simple_sock_method_wrapper( _stdlib_socket.socket.send, _core.wait_writable ) @@ -1168,14 +1176,12 @@ def send(__self, __bytes: Buffer, __flags: int = 0) -> Awaitable[int]: @overload async def sendto( self, __data: Buffer, __address: tuple[object, ...] | str | Buffer - ) -> int: - ... + ) -> int: ... @overload async def sendto( self, __data: Buffer, __flags: int, __address: tuple[object, ...] | str | Buffer - ) -> int: - ... + ) -> int: ... @_wraps(_stdlib_socket.socket.sendto, assigned=(), updated=()) # type: ignore[misc] async def sendto(self, *args: Any) -> int: diff --git a/trio/_ssl.py b/src/trio/_ssl.py similarity index 99% rename from trio/_ssl.py rename to src/trio/_ssl.py index 2c67da44a9..5bc37cf7dc 100644 --- a/trio/_ssl.py +++ b/src/trio/_ssl.py @@ -3,9 +3,8 @@ import contextlib import operator as _operator import ssl as _stdlib_ssl -from collections.abc import Awaitable, Callable from enum import Enum as _Enum -from typing import Any, ClassVar, Final as TFinal, Generic, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Final as TFinal, Generic, TypeVar import trio @@ -14,6 +13,9 @@ from ._util import ConflictDetector, final from .abc import Listener, Stream +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + # General theory of operation: # # We implement an API that closely mirrors the stdlib ssl module's blocking diff --git a/trio/_subprocess.py b/src/trio/_subprocess.py similarity index 69% rename from trio/_subprocess.py rename to src/trio/_subprocess.py index eaa9bb7ce4..553e3d4885 100644 --- a/trio/_subprocess.py +++ b/src/trio/_subprocess.py @@ -2,21 +2,16 @@ import contextlib import os -import signal import subprocess import sys import warnings -from collections.abc import Awaitable, Callable, Mapping, Sequence from contextlib import ExitStack from functools import partial -from io import TextIOWrapper from typing import TYPE_CHECKING, Final, Literal, Protocol, Union, overload import trio -from ._abc import AsyncResource, ReceiveStream, SendStream from ._core import ClosedResourceError, TaskStatus -from ._deprecate import deprecated from ._highlevel_generic import StapledStream from ._subprocess_platform import ( create_pipe_from_child_output, @@ -27,11 +22,22 @@ from ._util import NoPublicConstructor, final if TYPE_CHECKING: - from 
typing_extensions import Self, TypeAlias + import signal + from collections.abc import Awaitable, Callable, Mapping, Sequence + from io import TextIOWrapper + from typing_extensions import TypeAlias -# Only subscriptable in 3.9+ -StrOrBytesPath: TypeAlias = Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"] + from ._abc import ReceiveStream, SendStream + + +# Sphinx cannot parse the stringified version +if sys.version_info >= (3, 9): + StrOrBytesPath: TypeAlias = Union[str, bytes, os.PathLike[str], os.PathLike[bytes]] +else: + StrOrBytesPath: TypeAlias = Union[ + str, bytes, "os.PathLike[str]", "os.PathLike[bytes]" + ] # Linux-specific, but has complex lifetime management stuff so we hard-code it @@ -39,8 +45,7 @@ can_try_pidfd_open: bool if TYPE_CHECKING: - def pidfd_open(fd: int, flags: int) -> int: - ... + def pidfd_open(fd: int, flags: int) -> int: ... from ._subprocess_platform import ClosableReceiveStream, ClosableSendStream @@ -50,8 +55,14 @@ def pidfd_open(fd: int, flags: int) -> int: from os import pidfd_open except ImportError: if sys.platform == "linux": - # This workaround is only needed on 3.8 and pypy - assert sys.version_info < (3, 9) or sys.implementation.name != "cpython" + # this workaround is needed on: + # - CPython <= 3.8 + # - non-CPython (maybe?) + # - Anaconda's interpreter (as it is built to assume an older + # than current linux kernel) + # + # The last point implies that other custom builds might not work; + # therefore, no assertion should be here. import ctypes _cdll_for_pidfd_open = ctypes.CDLL(None, use_errno=True) @@ -84,12 +95,11 @@ def pidfd_open(fd: int, flags: int) -> int: class HasFileno(Protocol): """Represents any file-like object that has a file descriptor.""" - def fileno(self) -> int: - ... + def fileno(self) -> int: ... @final -class Process(AsyncResource, metaclass=NoPublicConstructor): +class Process(metaclass=NoPublicConstructor): r"""A child process. Like :class:`subprocess.Popen`, but async. This class has no public constructor. The most common way to get a @@ -131,6 +141,7 @@ class Process(AsyncResource, metaclass=NoPublicConstructor): available; otherwise this will be None. """ + # We're always in binary mode. universal_newlines: Final = False encoding: Final = None @@ -163,7 +174,7 @@ def __init__( if can_try_pidfd_open: try: fd: int = pidfd_open(self._proc.pid, 0) - except OSError: + except OSError: # pragma: no cover # Well, we tried, but it didn't work (probably because we're # running on an older kernel, or in an older sandbox, that # hasn't been updated to support pidfd_open). We'll fall back @@ -211,40 +222,6 @@ def returncode(self) -> int | None: self._close_pidfd() return result - @deprecated( - "0.20.0", - thing="using trio.Process as an async context manager", - issue=1104, - instead="run_process or nursery.start(run_process, ...)", - ) - async def __aenter__(self) -> Self: - return self - - @deprecated( - "0.20.0", issue=1104, instead="run_process or nursery.start(run_process, ...)" - ) - async def aclose(self) -> None: - """Close any pipes we have to the process (both input and output) - and wait for it to exit. - - If cancelled, kills the process and waits for it to finish - exiting before propagating the cancellation. 
- """ - with trio.CancelScope(shield=True): - if self.stdin is not None: - await self.stdin.aclose() - if self.stdout is not None: - await self.stdout.aclose() - if self.stderr is not None: - await self.stderr.aclose() - try: - await self.wait() - finally: - if self._proc.returncode is None: - self.kill() - with trio.CancelScope(shield=True): - await self.wait() - def _close_pidfd(self) -> None: if self._pidfd is not None: trio.lowlevel.notify_closing(self._pidfd.fileno()) @@ -479,6 +456,7 @@ async def _posix_deliver_cancel(p: Process) -> None: # Use a private name, so we can declare platform-specific stubs below. +# This is also the signature read by Sphinx async def _run_process( command: StrOrBytesPath | Sequence[StrOrBytesPath], *, @@ -673,6 +651,10 @@ async def my_deliver_cancel(process): and the process exits with a nonzero exit status OSError: if an error is encountered starting or communicating with the process + ExceptionGroup: if exceptions occur in ``deliver_cancel``, + or when exceptions occur when communicating with the subprocess. + If strict_exception_groups is set to false in the global context, + which is deprecated, then single exceptions will be collapsed. .. note:: The child process runs in the same process group as the parent Trio process, so a Ctrl+C will be delivered simultaneously to both @@ -702,13 +684,13 @@ async def my_deliver_cancel(process): "since that's the only way to access the pipe" ) if isinstance(stdin, (bytes, bytearray, memoryview)): - input = stdin + input_ = stdin options["stdin"] = subprocess.PIPE else: # stdin should be something acceptable to Process # (None, DEVNULL, a file descriptor, etc) and Process # will raise if it's not - input = None + input_ = None options["stdin"] = stdin if capture_stdout: @@ -733,8 +715,8 @@ async def my_deliver_cancel(process): async def feed_input(stream: SendStream) -> None: async with stream: try: - assert input is not None - await stream.send_all(input) + assert input_ is not None + await stream.send_all(input_) except trio.BrokenResourceError: pass @@ -744,21 +726,26 @@ async def read_output( ) -> None: async with stream: async for chunk in stream: - chunks.append(chunk) + chunks.append(chunk) # noqa: PERF401 + # Opening the process does not need to be inside the nursery, so we put it outside + # so any exceptions get directly seen by users. + # options needs a complex TypedDict. The overload error only occurs on Unix. + proc = await open_process(command, **options) # type: ignore[arg-type, call-overload, unused-ignore] async with trio.open_nursery() as nursery: - # options needs a complex TypedDict. The overload error only occurs on Unix. - proc = await open_process(command, **options) # type: ignore[arg-type, call-overload, unused-ignore] try: - if input is not None: + if input_ is not None: + assert proc.stdin is not None nursery.start_soon(feed_input, proc.stdin) proc.stdin = None proc.stdio = None if capture_stdout: + assert proc.stdout is not None nursery.start_soon(read_output, proc.stdout, stdout_chunks) proc.stdout = None proc.stdio = None if capture_stderr: + assert proc.stderr is not None nursery.start_soon(read_output, proc.stderr, stderr_chunks) proc.stderr = None task_status.started(proc) @@ -815,6 +802,56 @@ async def open_process( startupinfo: subprocess.STARTUPINFO | None = None, creationflags: int = 0, ) -> trio.Process: + r"""Execute a child program in a new process. 
+ + After construction, you can interact with the child process by writing data to its + `~trio.Process.stdin` stream (a `~trio.abc.SendStream`), reading data from its + `~trio.Process.stdout` and/or `~trio.Process.stderr` streams (both + `~trio.abc.ReceiveStream`\s), sending it signals using `~trio.Process.terminate`, + `~trio.Process.kill`, or `~trio.Process.send_signal`, and waiting for it to exit + using `~trio.Process.wait`. See `trio.Process` for details. + + Each standard stream is only available if you specify that a pipe should be created + for it. For example, if you pass ``stdin=subprocess.PIPE``, you can write to the + `~trio.Process.stdin` stream, else `~trio.Process.stdin` will be ``None``. + + Unlike `trio.run_process`, this function doesn't do any kind of automatic + management of the child process. It's up to you to implement whatever semantics you + want. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules `. + stdin: Specifies what the child process's standard input + stream should connect to: output written by the parent + (``subprocess.PIPE``), nothing (``subprocess.DEVNULL``), + or an open file (pass a file descriptor or something whose + ``fileno`` method returns one). If ``stdin`` is unspecified, + the child process will have the same standard input stream + as its parent. + stdout: Like ``stdin``, but for the child process's standard output + stream. + stderr: Like ``stdin``, but for the child process's standard error + stream. An additional value ``subprocess.STDOUT`` is supported, + which causes the child's standard output and standard error + messages to be intermixed on a single standard output stream, + attached to whatever the ``stdout`` option says to attach it to. + **options: Other :ref:`general subprocess options ` + are also accepted. + + Returns: + A new `trio.Process` object. + + Raises: + OSError: if the process spawning fails, for example because the + specified command could not be found. + + """ ... async def run_process( @@ -835,10 +872,203 @@ async def run_process( startupinfo: subprocess.STARTUPINFO | None = None, creationflags: int = 0, ) -> subprocess.CompletedProcess[bytes]: + """Run ``command`` in a subprocess and wait for it to complete. + + This function can be called in two different ways. + + One option is a direct call, like:: + + completed_process_info = await trio.run_process(...) + + In this case, it returns a :class:`subprocess.CompletedProcess` instance + describing the results. Use this if you want to treat a process like a + function call. + + The other option is to run it as a task using `Nursery.start` – the enhanced version + of `~Nursery.start_soon` that lets a task pass back a value during startup:: + + process = await nursery.start(trio.run_process, ...) + + In this case, `~Nursery.start` returns a `Process` object that you can use + to interact with the process while it's running. Use this if you want to + treat a process like a background task. 
+ + Either way, `run_process` makes sure that the process has exited before + returning, handles cancellation, optionally checks for errors, and + provides some convenient shorthands for dealing with the child's + input/output. + + **Input:** `run_process` supports all the same ``stdin=`` arguments as + `subprocess.Popen`. In addition, if you simply want to pass in some fixed + data, you can pass a plain `bytes` object, and `run_process` will take + care of setting up a pipe, feeding in the data you gave, and then sending + end-of-file. The default is ``b""``, which means that the child will receive + an empty stdin. If you want the child to instead read from the parent's + stdin, use ``stdin=None``. + + **Output:** By default, any output produced by the subprocess is + passed through to the standard output and error streams of the + parent Trio process. + + When calling `run_process` directly, you can capture the subprocess's output by + passing ``capture_stdout=True`` to capture the subprocess's standard output, and/or + ``capture_stderr=True`` to capture its standard error. Captured data is collected up + by Trio into an in-memory buffer, and then provided as the + :attr:`~subprocess.CompletedProcess.stdout` and/or + :attr:`~subprocess.CompletedProcess.stderr` attributes of the returned + :class:`~subprocess.CompletedProcess` object. The value for any stream that was not + captured will be ``None``. + + If you want to capture both stdout and stderr while keeping them + separate, pass ``capture_stdout=True, capture_stderr=True``. + + If you want to capture both stdout and stderr but mixed together + in the order they were printed, use: ``capture_stdout=True, stderr=subprocess.STDOUT``. + This directs the child's stderr into its stdout, so the combined + output will be available in the `~subprocess.CompletedProcess.stdout` + attribute. + + If you're using ``await nursery.start(trio.run_process, ...)`` and want to capture + the subprocess's output for further processing, then use ``stdout=subprocess.PIPE`` + and then make sure to read the data out of the `Process.stdout` stream. If you want + to capture stderr separately, use ``stderr=subprocess.PIPE``. If you want to capture + both, but mixed together in the correct order, use ``stdout=subprocess.PIPE, + stderr=subprocess.STDOUT``. + + **Error checking:** If the subprocess exits with a nonzero status + code, indicating failure, :func:`run_process` raises a + :exc:`subprocess.CalledProcessError` exception rather than + returning normally. The captured outputs are still available as + the :attr:`~subprocess.CalledProcessError.stdout` and + :attr:`~subprocess.CalledProcessError.stderr` attributes of that + exception. To disable this behavior, so that :func:`run_process` + returns normally even if the subprocess exits abnormally, pass ``check=False``. + + Note that this can make the ``capture_stdout`` and ``capture_stderr`` + arguments useful even when starting `run_process` as a task: if you only + care about the output if the process fails, then you can enable capturing + and then read the output off of the `~subprocess.CalledProcessError`. + + **Cancellation:** If cancelled, `run_process` sends a termination + request to the subprocess, then waits for it to fully exit. The + ``deliver_cancel`` argument lets you control how the process is terminated. + + .. note:: `run_process` is intentionally similar to the standard library + `subprocess.run`, but some of the defaults are different. 
Specifically, we + default to: + + - ``check=True``, because `"errors should never pass silently / unless + explicitly silenced" <https://www.python.org/dev/peps/pep-0020/>`__. + + - ``stdin=b""``, because it produces less-confusing results if a subprocess + unexpectedly tries to read from stdin. + + To get the `subprocess.run` semantics, use ``check=False, stdin=None``. + + Args: + command (list or str): The command to run. Typically this is a + sequence of strings such as ``['ls', '-l', 'directory with spaces']``, + where the first element names the executable to invoke and the other + elements specify its arguments. With ``shell=True`` in the + ``**options``, or on Windows, ``command`` may alternatively + be a string, which will be parsed following platform-dependent + :ref:`quoting rules <subprocess-quoting>`. + + stdin (:obj:`bytes`, subprocess.PIPE, file descriptor, or None): The + bytes to provide to the subprocess on its standard input stream, or + ``None`` if the subprocess's standard input should come from the + same place as the parent Trio process's standard input. As is the + case with the :mod:`subprocess` module, you can also pass a file + descriptor or an object with a ``fileno()`` method, in which case + the subprocess's standard input will come from that file. + + When starting `run_process` as a background task, you can also use + ``stdin=subprocess.PIPE``, in which case `Process.stdin` will be a + `~trio.abc.SendStream` that you can use to send data to the child. + + capture_stdout (bool): If true, capture the bytes that the subprocess + writes to its standard output stream and return them in the + `~subprocess.CompletedProcess.stdout` attribute of the returned + `subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + + capture_stderr (bool): If true, capture the bytes that the subprocess + writes to its standard error stream and return them in the + `~subprocess.CompletedProcess.stderr` attribute of the returned + `~subprocess.CompletedProcess` or `subprocess.CalledProcessError`. + + check (bool): If false, don't validate that the subprocess exits + successfully. You should be sure to check the + ``returncode`` attribute of the returned object if you pass + ``check=False``, so that errors don't pass silently. + + deliver_cancel (async function or None): If `run_process` is cancelled, + then it needs to kill the child process. There are multiple ways to + do this, so we let you customize it. + + If you pass None (the default), then the behavior depends on the + platform: + + - On Windows, Trio calls ``TerminateProcess``, which should kill the + process immediately. + + - On Unix-likes, the default behavior is to send a ``SIGTERM``, wait + 5 seconds, and send a ``SIGKILL``. + + Alternatively, you can customize this behavior by passing in an + arbitrary async function, which will be called with the `Process` + object as an argument. For example, the default Unix behavior could + be implemented like this:: + + async def my_deliver_cancel(process): + process.send_signal(signal.SIGTERM) + await trio.sleep(5) + process.send_signal(signal.SIGKILL) + + When the process actually exits, the ``deliver_cancel`` function + will automatically be cancelled – so if the process exits after + ``SIGTERM``, then we'll never reach the ``SIGKILL``. + + In any case, `run_process` will always wait for the child process to + exit before raising `Cancelled`. + + **options: :func:`run_process` also accepts any :ref:`general subprocess + options <subprocess-options>` and passes them on to the + :class:`~trio.Process` constructor.
This includes the + ``stdout`` and ``stderr`` options, which provide additional + redirection possibilities such as ``stderr=subprocess.STDOUT``, + ``stdout=subprocess.DEVNULL``, or file descriptors. + + Returns: + + When called normally – a `subprocess.CompletedProcess` instance + describing the return code and outputs. + + When called via `Nursery.start` – a `trio.Process` instance. + + Raises: + UnicodeError: if ``stdin`` is specified as a Unicode string, rather + than bytes + ValueError: if multiple redirections are specified for the same + stream, e.g., both ``capture_stdout=True`` and + ``stdout=subprocess.DEVNULL`` + subprocess.CalledProcessError: if ``check=False`` is not passed + and the process exits with a nonzero exit status + OSError: if an error is encountered starting or communicating with + the process + + .. note:: The child process runs in the same process group as the parent + Trio process, so a Ctrl+C will be delivered simultaneously to both + parent and child. If you don't want this behavior, consult your + platform's documentation for starting child processes in a different + process group. + + """ ... else: # Unix - + # pyright doesn't give any error about these missing docstrings as they're + # overloads. But might still be a problem for other static analyzers / docstring + # readers (?) @overload # type: ignore[no-overload-impl] async def open_process( command: StrOrBytesPath, @@ -854,8 +1084,7 @@ async def open_process( restore_signals: bool = True, start_new_session: bool = False, pass_fds: Sequence[int] = (), - ) -> trio.Process: - ... + ) -> trio.Process: ... @overload async def open_process( @@ -872,8 +1101,7 @@ async def open_process( restore_signals: bool = True, start_new_session: bool = False, pass_fds: Sequence[int] = (), - ) -> trio.Process: - ... + ) -> trio.Process: ... @overload # type: ignore[no-overload-impl] async def run_process( @@ -895,8 +1123,7 @@ async def run_process( restore_signals: bool = True, start_new_session: bool = False, pass_fds: Sequence[int] = (), - ) -> subprocess.CompletedProcess[bytes]: - ... + ) -> subprocess.CompletedProcess[bytes]: ... @overload async def run_process( @@ -918,8 +1145,7 @@ async def run_process( restore_signals: bool = True, start_new_session: bool = False, pass_fds: Sequence[int] = (), - ) -> subprocess.CompletedProcess[bytes]: - ... + ) -> subprocess.CompletedProcess[bytes]: ... else: # At runtime, use the actual implementations. diff --git a/trio/_subprocess_platform/__init__.py b/src/trio/_subprocess_platform/__init__.py similarity index 84% rename from trio/_subprocess_platform/__init__.py rename to src/trio/_subprocess_platform/__init__.py index 793d8d7f23..d74cd462a0 100644 --- a/trio/_subprocess_platform/__init__.py +++ b/src/trio/_subprocess_platform/__init__.py @@ -1,32 +1,31 @@ # Platform-specific subprocess bits'n'pieces. +from __future__ import annotations import os import sys -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING import trio from .. import _core, _subprocess -from .._abc import ReceiveStream, SendStream +from .._abc import ReceiveStream, SendStream # noqa: TCH001 -_wait_child_exiting_error: Optional[ImportError] = None -_create_child_pipe_error: Optional[ImportError] = None +_wait_child_exiting_error: ImportError | None = None +_create_child_pipe_error: ImportError | None = None if TYPE_CHECKING: # internal types for the pipe representations used in type checking only class ClosableSendStream(SendStream): - def close(self) -> None: - ... 
+ def close(self) -> None: ... class ClosableReceiveStream(ReceiveStream): - def close(self) -> None: - ... + def close(self) -> None: ... # Fallback versions of the functions provided -- implementations # per OS are imported atop these at the bottom of the module. -async def wait_child_exiting(process: "_subprocess.Process") -> None: +async def wait_child_exiting(process: _subprocess.Process) -> None: """Block until the child process managed by ``process`` is exiting. It is invalid to call this function if the process has already @@ -41,7 +40,7 @@ async def wait_child_exiting(process: "_subprocess.Process") -> None: raise NotImplementedError from _wait_child_exiting_error # pragma: no cover -def create_pipe_to_child_stdin() -> Tuple["ClosableSendStream", int]: +def create_pipe_to_child_stdin() -> tuple[ClosableSendStream, int]: """Create a new pipe suitable for sending data from this process to the standard input of a child we're about to spawn. @@ -54,7 +53,7 @@ def create_pipe_to_child_stdin() -> Tuple["ClosableSendStream", int]: raise NotImplementedError from _create_child_pipe_error # pragma: no cover -def create_pipe_from_child_output() -> Tuple["ClosableReceiveStream", int]: +def create_pipe_from_child_output() -> tuple[ClosableReceiveStream, int]: """Create a new pipe suitable for receiving data into this process from the standard output or error stream of a child we're about to spawn. @@ -70,7 +69,7 @@ def create_pipe_from_child_output() -> Tuple["ClosableReceiveStream", int]: try: if sys.platform == "win32": - from .windows import wait_child_exiting # noqa: F811 + from .windows import wait_child_exiting elif sys.platform != "linux" and (TYPE_CHECKING or hasattr(_core, "wait_kevent")): from .kqueue import wait_child_exiting else: @@ -86,11 +85,11 @@ def create_pipe_from_child_output() -> Tuple["ClosableReceiveStream", int]: elif os.name == "posix": - def create_pipe_to_child_stdin(): # noqa: F811 + def create_pipe_to_child_stdin(): rfd, wfd = os.pipe() return trio.lowlevel.FdStream(wfd), rfd - def create_pipe_from_child_output(): # noqa: F811 + def create_pipe_from_child_output(): rfd, wfd = os.pipe() return trio.lowlevel.FdStream(rfd), wfd diff --git a/trio/_subprocess_platform/kqueue.py b/src/trio/_subprocess_platform/kqueue.py similarity index 100% rename from trio/_subprocess_platform/kqueue.py rename to src/trio/_subprocess_platform/kqueue.py diff --git a/trio/_subprocess_platform/waitid.py b/src/trio/_subprocess_platform/waitid.py similarity index 95% rename from trio/_subprocess_platform/waitid.py rename to src/trio/_subprocess_platform/waitid.py index 756741218f..44c8261074 100644 --- a/trio/_subprocess_platform/waitid.py +++ b/src/trio/_subprocess_platform/waitid.py @@ -72,14 +72,14 @@ async def _waitid_system_task(pid: int, event: Event) -> None: """Spawn a thread that waits for ``pid`` to exit, then wake any tasks that were waiting on it. """ - # cancellable=True: if this task is cancelled, then we abandon the + # abandon_on_cancel=True: if this task is cancelled, then we abandon the # thread to keep running waitpid in the background. Since this is # always run as a system task, this will only happen if the whole # call to trio.run is shutting down. 
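The renamed flag is also the public spelling on `trio.to_thread.run_sync`; here is a minimal sketch of the abandon-on-cancel pattern the comment above describes, with ``blocking_wait`` as a hypothetical stand-in for the ``waitpid`` call::

    import time

    import trio

    def blocking_wait() -> None:
        time.sleep(60)  # uninterruptible blocking work, like waitpid()

    async def watcher() -> None:
        # On cancellation, run_sync raises Cancelled right away and the
        # thread is abandoned to finish on its own in the background.
        await trio.to_thread.run_sync(blocking_wait, abandon_on_cancel=True)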
try: await to_thread_run_sync( - sync_wait_reapable, pid, cancellable=True, limiter=waitid_limiter + sync_wait_reapable, pid, abandon_on_cancel=True, limiter=waitid_limiter ) except OSError: # If waitid fails, waitpid will fail too, so it still makes diff --git a/trio/_subprocess_platform/windows.py b/src/trio/_subprocess_platform/windows.py similarity index 76% rename from trio/_subprocess_platform/windows.py rename to src/trio/_subprocess_platform/windows.py index 1634e74fa7..81fb960e4b 100644 --- a/trio/_subprocess_platform/windows.py +++ b/src/trio/_subprocess_platform/windows.py @@ -1,6 +1,10 @@ -from .. import _subprocess +from typing import TYPE_CHECKING + from .._wait_for_object import WaitForSingleObject +if TYPE_CHECKING: + from .. import _subprocess + async def wait_child_exiting(process: "_subprocess.Process") -> None: # _handle is not in Popen stubs, though it is present on Windows. diff --git a/trio/_sync.py b/src/trio/_sync.py similarity index 96% rename from trio/_sync.py rename to src/trio/_sync.py index 951ff892ea..6e62eceeff 100644 --- a/trio/_sync.py +++ b/src/trio/_sync.py @@ -3,7 +3,7 @@ import math from typing import TYPE_CHECKING, Protocol -import attr +import attrs import trio @@ -18,7 +18,7 @@ from ._core._parking_lot import ParkingLotStatistics -@attr.s(frozen=True, slots=True) +@attrs.frozen class EventStatistics: """An object containing debugging information. @@ -29,11 +29,11 @@ class EventStatistics: """ - tasks_waiting: int = attr.ib() + tasks_waiting: int @final -@attr.s(repr=False, eq=False, hash=False, slots=True) +@attrs.define(repr=False, eq=False, hash=False) class Event: """A waitable boolean value useful for inter-task synchronization, inspired by :class:`threading.Event`. @@ -60,8 +60,8 @@ class Event: """ - _tasks: set[Task] = attr.ib(factory=set, init=False) - _flag: bool = attr.ib(default=False, init=False) + _tasks: set[Task] = attrs.field(factory=set, init=False) + _flag: bool = attrs.field(default=False, init=False) def is_set(self) -> bool: """Return the current value of the internal flag.""" @@ -109,11 +109,9 @@ def statistics(self) -> EventStatistics: class _HasAcquireRelease(Protocol): """Only classes with acquire() and release() can use the mixin's implementations.""" - async def acquire(self) -> object: - ... + async def acquire(self) -> object: ... - def release(self) -> object: - ... + def release(self) -> object: ... class AsyncContextManagerMixin: @@ -131,7 +129,7 @@ async def __aexit__( self.release() -@attr.s(frozen=True, slots=True) +@attrs.frozen class CapacityLimiterStatistics: """An object containing debugging information. @@ -150,10 +148,10 @@ class CapacityLimiterStatistics: """ - borrowed_tokens: int = attr.ib() - total_tokens: int | float = attr.ib() - borrowers: list[Task | object] = attr.ib() - tasks_waiting: int = attr.ib() + borrowed_tokens: int + total_tokens: int | float + borrowers: list[Task | object] + tasks_waiting: int # Can be a generic type with a default of Task if/when PEP 696 is released @@ -225,9 +223,7 @@ def __init__(self, total_tokens: int | float): # noqa: PYI041 assert self._total_tokens == total_tokens def __repr__(self) -> str: - return "<trio.CapacityLimiter at {:#x}, {}/{} with {} waiting>".format( - id(self), len(self._borrowers), self._total_tokens, len(self._lot) - ) + return f"<trio.CapacityLimiter at {id(self):#x}, {len(self._borrowers)}/{self._total_tokens} with {len(self._lot)} waiting>" @property def total_tokens(self) -> int | float: @@ -523,7 +519,7 @@ def statistics(self) -> ParkingLotStatistics: return self._lot.statistics() -@attr.s(frozen=True, slots=True) +@attrs.frozen class LockStatistics: """An object containing debugging information for a Lock.
@@ -537,15 +533,15 @@ class LockStatistics: """ - locked: bool = attr.ib() - owner: Task | None = attr.ib() - tasks_waiting: int = attr.ib() + locked: bool + owner: Task | None + tasks_waiting: int -@attr.s(eq=False, hash=False, repr=False) +@attrs.define(eq=False, hash=False, repr=False, slots=False) class _LockImpl(AsyncContextManagerMixin): - _lot: ParkingLot = attr.ib(factory=ParkingLot, init=False) - _owner: Task | None = attr.ib(default=None, init=False) + _lot: ParkingLot = attrs.field(factory=ParkingLot, init=False) + _owner: Task | None = attrs.field(default=None, init=False) def __repr__(self) -> str: if self.locked(): @@ -709,7 +705,7 @@ class StrictFIFOLock(_LockImpl): """ -@attr.s(frozen=True, slots=True) +@attrs.frozen class ConditionStatistics: r"""An object containing debugging information for a Condition. @@ -721,8 +717,9 @@ class ConditionStatistics: :class:`Lock`\s :meth:`~Lock.statistics` method. """ - tasks_waiting: int = attr.ib() - lock_statistics: LockStatistics = attr.ib() + + tasks_waiting: int + lock_statistics: LockStatistics @final diff --git a/trio/_tests/__init__.py b/src/trio/_tests/__init__.py similarity index 100% rename from trio/_tests/__init__.py rename to src/trio/_tests/__init__.py diff --git a/src/trio/_tests/_check_type_completeness.json b/src/trio/_tests/_check_type_completeness.json new file mode 100644 index 0000000000..0bbd47fada --- /dev/null +++ b/src/trio/_tests/_check_type_completeness.json @@ -0,0 +1,58 @@ +{ + "Darwin": [ + "No docstring found for function \"trio._unix_pipes.FdStream.send_all\"", + "No docstring found for function \"trio._unix_pipes.FdStream.wait_send_all_might_not_block\"", + "No docstring found for function \"trio._unix_pipes.FdStream.receive_some\"", + "No docstring found for function \"trio._unix_pipes.FdStream.close\"", + "No docstring found for function \"trio._unix_pipes.FdStream.aclose\"", + "No docstring found for function \"trio._unix_pipes.FdStream.fileno\"" + ], + "Linux": [ + "No docstring found for class \"trio._core._io_epoll._EpollStatistics\"", + "No docstring found for function \"trio._unix_pipes.FdStream.send_all\"", + "No docstring found for function \"trio._unix_pipes.FdStream.wait_send_all_might_not_block\"", + "No docstring found for function \"trio._unix_pipes.FdStream.receive_some\"", + "No docstring found for function \"trio._unix_pipes.FdStream.close\"", + "No docstring found for function \"trio._unix_pipes.FdStream.aclose\"", + "No docstring found for function \"trio._unix_pipes.FdStream.fileno\"" + ], + "Windows": [], + "all": [ + "No docstring found for class \"trio.MemoryReceiveChannel\"", + "No docstring found for class \"trio._channel.MemoryReceiveChannel\"", + "No docstring found for function \"trio._channel.MemoryReceiveChannel.statistics\"", + "No docstring found for class \"trio._channel.MemoryChannelStats\"", + "No docstring found for function \"trio._channel.MemoryReceiveChannel.aclose\"", + "No docstring found for class \"trio.MemorySendChannel\"", + "No docstring found for class \"trio._channel.MemorySendChannel\"", + "No docstring found for function \"trio._channel.MemorySendChannel.statistics\"", + "No docstring found for function \"trio._channel.MemorySendChannel.aclose\"", + "No docstring found for class \"trio._core._run.Task\"", + "No docstring found for class \"trio._socket.SocketType\"", + "No docstring found for function \"trio._highlevel_socket.SocketStream.send_all\"", + "No docstring found for function 
\"trio._highlevel_socket.SocketStream.wait_send_all_might_not_block\"", + "No docstring found for function \"trio._highlevel_socket.SocketStream.send_eof\"", + "No docstring found for function \"trio._highlevel_socket.SocketStream.receive_some\"", + "No docstring found for function \"trio._highlevel_socket.SocketStream.aclose\"", + "No docstring found for function \"trio._subprocess.HasFileno.fileno\"", + "No docstring found for class \"trio._sync.AsyncContextManagerMixin\"", + "No docstring found for function \"trio._sync._HasAcquireRelease.acquire\"", + "No docstring found for function \"trio._sync._HasAcquireRelease.release\"", + "No docstring found for class \"trio._sync._LockImpl\"", + "No docstring found for class \"trio._core._local._NoValue\"", + "No docstring found for class \"trio._core._local.RunVarToken\"", + "No docstring found for class \"trio.lowlevel.RunVarToken\"", + "No docstring found for class \"trio.lowlevel.Task\"", + "No docstring found for class \"trio._core._ki.KIProtectionSignature\"", + "No docstring found for class \"trio.socket.SocketType\"", + "No docstring found for class \"trio.socket.gaierror\"", + "No docstring found for class \"trio.socket.herror\"", + "No docstring found for function \"trio._core._mock_clock.MockClock.start_clock\"", + "No docstring found for function \"trio._core._mock_clock.MockClock.current_time\"", + "No docstring found for function \"trio._core._mock_clock.MockClock.deadline_to_sleep_time\"", + "No docstring found for function \"trio.testing._raises_group._ExceptionInfo.exconly\"", + "No docstring found for function \"trio.testing._raises_group._ExceptionInfo.errisinstance\"", + "No docstring found for function \"trio.testing._raises_group._ExceptionInfo.getrepr\"", + "No docstring found for function \"trio.testing._raises_group.RaisesGroup.expected_type\"" + ] +} diff --git a/trio/_tests/astrill-codesigning-cert.cer b/src/trio/_tests/astrill-codesigning-cert.cer similarity index 100% rename from trio/_tests/astrill-codesigning-cert.cer rename to src/trio/_tests/astrill-codesigning-cert.cer diff --git a/src/trio/_tests/check_type_completeness.py b/src/trio/_tests/check_type_completeness.py new file mode 100755 index 0000000000..fa6ace074f --- /dev/null +++ b/src/trio/_tests/check_type_completeness.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 +"""This is a file that wraps calls to `pyright --verifytypes`, achieving two things: +1. give an error if docstrings are missing. + pyright will give a number of missing docstrings, and error messages, but not exit with a non-zero value. +2. filter out specific errors we don't care about. + this is largely due to 1, but also because Trio does some very complex stuff and --verifytypes has few to no ways of ignoring specific errors. + +If this check is giving you false alarms, you can ignore them by adding logic to `has_docstring_at_runtime`, in the main loop in `check_type`, or by updating the json file. +""" +from __future__ import annotations + +# this file is not run as part of the tests, instead it's run standalone from check.sh +import argparse +import json +import subprocess +import sys +from pathlib import Path + +import trio +import trio.testing + +# not needed if everything is working, but if somebody does something to generate +# tons of errors, we can be nice and stop them from getting 3*tons of output +printed_diagnostics: set[str] = set() + + +# TODO: consider checking manually without `--ignoreexternal`, and/or +# removing it from the below call later on. 
+def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]: + return subprocess.run( + [ + "pyright", + # Specify a platform and version to keep imported modules consistent. + f"--pythonplatform={platform}", + "--pythonversion=3.8", + "--verifytypes=trio", + "--outputjson", + "--ignoreexternal", + ], + capture_output=True, + ) + + +def has_docstring_at_runtime(name: str) -> bool: + """Pyright gives us an object identifier of xx.yy.zz + This function tries to decompose that into its constituent parts, such that we + can resolve it, in order to check whether it has a `__doc__` at runtime and + verifytypes misses it because we're doing overly fancy stuff. + """ + # This assert is solely for stopping isort from removing our imports of trio & trio.testing + # It could also be done with isort:skip, but that'd also disable import sorting and the like. + assert trio.testing + + # figure out what part of the name is the module, so we can "import" it + name_parts = name.split(".") + assert name_parts[0] == "trio" + if name_parts[1] == "tests": + return True + + # traverse down the remaining identifiers with getattr + obj = trio + try: + for obj_name in name_parts[1:]: + obj = getattr(obj, obj_name) + except AttributeError as exc: + # asynciowrapper does funky getattr stuff + if "AsyncIOWrapper" in str(exc) or name in ( + # Symbols not existing on all platforms, so we can't dynamically inspect them. + # Manually confirmed to have docstrings but pyright doesn't see them due to + # export shenanigans. TODO: actually manually confirm that. + # In theory we could verify these at runtime, probably by running the script separately + # on separate platforms. It might also be a decent idea to work the other way around, + # a la test_static_tool_sees_class_members + # darwin + "trio.lowlevel.current_kqueue", + "trio.lowlevel.monitor_kevent", + "trio.lowlevel.wait_kevent", + "trio._core._io_kqueue._KqueueStatistics", + # windows + "trio._socket.SocketType.share", + "trio._core._io_windows._WindowsStatistics", + "trio._core._windows_cffi.Handle", + "trio.lowlevel.current_iocp", + "trio.lowlevel.monitor_completion_key", + "trio.lowlevel.readinto_overlapped", + "trio.lowlevel.register_with_iocp", + "trio.lowlevel.wait_overlapped", + "trio.lowlevel.write_overlapped", + "trio.lowlevel.WaitForSingleObject", + "trio.socket.fromshare", + # linux + # this test will fail on linux, but I don't develop on linux. So the next + # person to do so is very welcome to open a pull request and populate with + # objects + # TODO: these are erroring on all platforms, why? 
+ "trio._highlevel_generic.StapledStream.send_stream", + "trio._highlevel_generic.StapledStream.receive_stream", + "trio._ssl.SSLStream.transport_stream", + "trio._file_io._HasFileNo", + "trio._file_io._HasFileNo.fileno", + ): + return True + + else: + print( + f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).", + file=sys.stderr, + ) + return False + return bool(obj.__doc__) + + +def check_type( + platform: str, full_diagnostics_file: Path | None, expected_errors: list[object] +) -> list[object]: + # convince isort we use the trio import + assert trio + + # run pyright, load output into json + res = run_pyright(platform) + current_result = json.loads(res.stdout) + + if res.stderr: + print(res.stderr, file=sys.stderr) + + if full_diagnostics_file: + with open(full_diagnostics_file, "a") as f: + json.dump(current_result, f, sort_keys=True, indent=4) + + errors = [] + + for symbol in current_result["typeCompleteness"]["symbols"]: + diagnostics = symbol["diagnostics"] + name = symbol["name"] + for diagnostic in diagnostics: + message = diagnostic["message"] + if name in ( + "trio._path.PosixPath", + "trio._path.WindowsPath", + ) and message.startswith("Type of base class "): + continue + + if name.startswith("trio._path.Path"): + if message.startswith("No docstring found for"): + continue + if message.startswith( + "Type is missing type annotation and could be inferred differently by type checkers" + ): + continue + + # ignore errors about missing docstrings if they're available at runtime + if message.startswith("No docstring found for"): + if has_docstring_at_runtime(symbol["name"]): + continue + else: + # Missing docstring messages include the name of the object. + # Other errors don't, so we add it. + message = f"{name}: {message}" + if message not in expected_errors and message not in printed_diagnostics: + print(f"new error: {message}", file=sys.stderr) + errors.append(message) + printed_diagnostics.add(message) + + continue + + return errors + + +def main(args: argparse.Namespace) -> int: + if args.full_diagnostics_file: + full_diagnostics_file = Path(args.full_diagnostics_file) + full_diagnostics_file.write_text("") + else: + full_diagnostics_file = None + + errors_by_platform_file = Path(__file__).parent / "_check_type_completeness.json" + if errors_by_platform_file.exists(): + with open(errors_by_platform_file) as f: + errors_by_platform = json.load(f) + else: + errors_by_platform = {"Linux": [], "Windows": [], "Darwin": [], "all": []} + + changed = False + for platform in "Linux", "Windows", "Darwin": + platform_errors = errors_by_platform[platform] + errors_by_platform["all"] + print("*" * 20, f"\nChecking {platform}...") + errors = check_type(platform, full_diagnostics_file, platform_errors) + + new_errors = [e for e in errors if e not in platform_errors] + missing_errors = [e for e in platform_errors if e not in errors] + + if new_errors: + print( + f"New errors introduced in `pyright --verifytypes`. Fix them, or ignore them by modifying {errors_by_platform_file}, either manually or with '--overwrite-file'.", + file=sys.stderr, + ) + changed = True + if missing_errors: + print( + f"Congratulations, you have resolved existing errors! 
Please remove them from {errors_by_platform_file}, either manually or with '--overwrite-file'.", + file=sys.stderr, + ) + changed = True + print(missing_errors, file=sys.stderr) + + errors_by_platform[platform] = errors + print("*" * 20) + + # cut down the size of the json file by a lot, and make it easier to parse for + # humans, by moving errors that appear on all platforms to a separate category + errors_by_platform["all"] = [] + for e in errors_by_platform["Linux"].copy(): + if e in errors_by_platform["Darwin"] and e in errors_by_platform["Windows"]: + for platform in "Linux", "Windows", "Darwin": + errors_by_platform[platform].remove(e) + errors_by_platform["all"].append(e) + + if changed and args.overwrite_file: + with open(errors_by_platform_file, "w") as f: + json.dump(errors_by_platform, f, indent=4, sort_keys=True) + # newline at end of file + f.write("\n") + + # True -> 1 -> non-zero exit value -> error + return changed + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--overwrite-file", + action="store_true", + default=False, + help="Use this flag to overwrite the current stored results. Either in CI together with a diff check, or to avoid having to manually correct it.", +) +parser.add_argument( + "--full-diagnostics-file", + type=Path, + default=None, + help="Use this for debugging, it will dump the output of all three pyright runs by platform into this file.", +) +args = parser.parse_args() + +assert __name__ == "__main__", "This script should be run standalone" +sys.exit(main(args)) diff --git a/trio/_tests/module_with_deprecations.py b/src/trio/_tests/module_with_deprecations.py similarity index 100% rename from trio/_tests/module_with_deprecations.py rename to src/trio/_tests/module_with_deprecations.py diff --git a/trio/_tests/pytest_plugin.py b/src/trio/_tests/pytest_plugin.py similarity index 100% rename from trio/_tests/pytest_plugin.py rename to src/trio/_tests/pytest_plugin.py diff --git a/trio/_tests/test_abc.py b/src/trio/_tests/test_abc.py similarity index 95% rename from trio/_tests/test_abc.py rename to src/trio/_tests/test_abc.py index e5e8260f9c..74e1ccf424 100644 --- a/trio/_tests/test_abc.py +++ b/src/trio/_tests/test_abc.py @@ -1,6 +1,6 @@ from __future__ import annotations -import attr +import attrs import pytest from .. 
import abc as tabc @@ -30,9 +30,9 @@ def test_instrument_implements_hook_methods() -> None: async def test_AsyncResource_defaults() -> None: - @attr.s + @attrs.define(slots=False) class MyAR(tabc.AsyncResource): - record: list[str] = attr.ib(factory=list) + record: list[str] = attrs.Factory(list) async def aclose(self) -> None: self.record.append("ac") diff --git a/trio/_tests/test_channel.py b/src/trio/_tests/test_channel.py similarity index 95% rename from trio/_tests/test_channel.py rename to src/trio/_tests/test_channel.py index f991fd551c..c82c6767ed 100644 --- a/trio/_tests/test_channel.py +++ b/src/trio/_tests/test_channel.py @@ -13,7 +13,7 @@ async def test_channel() -> None: with pytest.raises(TypeError): open_memory_channel(1.0) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^max_buffer_size must be >= 0$"): open_memory_channel(-1) s, r = open_memory_channel[Union[int, str, None]](2) @@ -76,9 +76,7 @@ async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None: for i in range(10): nursery.start_soon(producer, send_channel.clone(), i) - got = [] - async for value in receive_channel: - got.append(value) + got = [value async for value in receive_channel] got.sort() assert got == list(range(30)) @@ -151,17 +149,17 @@ async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() - s, r = open_memory_channel[None](0) + s2, r2 = open_memory_channel[int](0) async with trio.open_nursery() as nursery: - nursery.start_soon(receive_block, r) + nursery.start_soon(receive_block, r2) await wait_all_tasks_blocked() - await r.aclose() + await r2.aclose() # and it's persistent with pytest.raises(trio.ClosedResourceError): - r.receive_nowait() + r2.receive_nowait() with pytest.raises(trio.ClosedResourceError): - await r.receive() + await r2.receive() async def test_close_sync() -> None: @@ -204,7 +202,7 @@ async def send_block( await s.send(None) # closing receive -> other receive gets ClosedResourceError - async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None: + async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None: with pytest.raises(trio.ClosedResourceError): await r.receive() @@ -291,16 +289,14 @@ async def receive_will_succeed() -> None: async def test_inf_capacity() -> None: - s, r = open_memory_channel[int](float("inf")) + send, receive = open_memory_channel[int](float("inf")) # It's accepted, and we can send all day without blocking - with s: + with send: for i in range(10): - s.send_nowait(i) + send.send_nowait(i) - got = [] - async for i in r: - got.append(i) + got = [i async for i in receive] assert got == list(range(10)) @@ -366,9 +362,9 @@ async def test_channel_fairness() -> None: # But if someone else is waiting to receive, then they "own" the item we # send, so we can't receive it (even though we run first): - result = None + result: int | None = None - async def do_receive(r: trio.MemoryReceiveChannel[int]) -> None: + async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None: nonlocal result result = await r.receive() diff --git a/trio/_tests/test_contextvars.py b/src/trio/_tests/test_contextvars.py similarity index 100% rename from trio/_tests/test_contextvars.py rename to src/trio/_tests/test_contextvars.py diff --git a/trio/_tests/test_deprecate.py b/src/trio/_tests/test_deprecate.py similarity index 83% rename from trio/_tests/test_deprecate.py rename to src/trio/_tests/test_deprecate.py index efacb27b3a..fa5d7cbfef 100644 
--- a/trio/_tests/test_deprecate.py +++ b/src/trio/_tests/test_deprecate.py @@ -38,7 +38,7 @@ def deprecated_thing() -> None: deprecated_thing() filename, lineno = _here() assert len(recwarn_always) == 1 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "ice is deprecated" in got.message.args[0] assert "Trio 1.2" in got.message.args[0] @@ -54,7 +54,7 @@ def test_warn_deprecated_no_instead_or_issue( # Explicitly no instead or issue warn_deprecated("water", "1.3", issue=None, instead=None) assert len(recwarn_always) == 1 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "water is deprecated" in got.message.args[0] assert "no replacement" in got.message.args[0] @@ -70,7 +70,7 @@ def nested2() -> None: filename, lineno = _here() nested1() - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert got.filename == filename assert got.lineno == lineno + 1 @@ -85,7 +85,7 @@ def new() -> None: # pragma: no cover def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None: warn_deprecated(old, "1.0", issue=1, instead=new) - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.old is deprecated" in got.message.args[0] assert "test_deprecate.new instead" in got.message.args[0] @@ -98,7 +98,7 @@ def deprecated_old() -> int: def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None: assert deprecated_old() == 3 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0] assert "1.5" in got.message.args[0] @@ -115,7 +115,7 @@ def method(self) -> int: def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None: f = Foo() assert f.method() == 7 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.Foo.method is deprecated" in got.message.args[0] @@ -129,7 +129,7 @@ def test_deprecated_decorator_with_explicit_thing( recwarn_always: pytest.WarningsRecorder, ) -> None: assert deprecated_with_thing() == 72 - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "the thing is deprecated" in got.message.args[0] @@ -143,7 +143,7 @@ def new_hotness() -> str: def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None: assert old_hotness() == "new hotness" - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "test_deprecate.old_hotness is deprecated" in got.message.args[0] assert "1.23" in got.message.args[0] @@ -168,7 +168,7 @@ def new_hotness_method(self) -> str: def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None: obj = Alias() assert obj.old_hotness_method() == "new hotness method" - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) msg = got.message.args[0] assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg @@ -243,7 +243,7 
@@ def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> No filename, lineno = _here() assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined] - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert got.filename == filename assert got.lineno == lineno + 1 @@ -254,7 +254,7 @@ def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> No assert "value1 instead" in got.message.args[0] assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined] - got = recwarn_always.pop(TrioDeprecationWarning) + got = recwarn_always.pop(DeprecationWarning) assert isinstance(got.message, Warning) assert "instead-string instead" in got.message.args[0] @@ -262,30 +262,15 @@ def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> No module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression" -def test_tests_is_deprecated1() -> None: - with pytest.warns(TrioDeprecationWarning): - from trio import tests # warning on import - - # warning on access of any member - with pytest.warns(TrioDeprecationWarning): - assert tests.test_abc # type: ignore[attr-defined] - - -def test_tests_is_deprecated2() -> None: - # warning on direct import of test since that accesses `__spec__` - with pytest.warns(TrioDeprecationWarning): - import trio.tests - - with pytest.warns(TrioDeprecationWarning): - assert trio.tests.test_deprecate # type: ignore[attr-defined] - - -def test_tests_is_deprecated3() -> None: - import trio +def test_warning_class() -> None: + with pytest.deprecated_call(): + warn_deprecated("foo", "bar", issue=None, instead=None) - # no warning on accessing the submodule - assert trio.tests + # essentially the same as the above check + with pytest.warns(DeprecationWarning): + warn_deprecated("foo", "bar", issue=None, instead=None) - # only when accessing a submodule member with pytest.warns(TrioDeprecationWarning): - assert trio.tests.test_abc # type: ignore[attr-defined] + warn_deprecated( + "foo", "bar", issue=None, instead=None, use_triodeprecationwarning=True + ) diff --git a/src/trio/_tests/test_deprecate_strict_exception_groups_false.py b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py new file mode 100644 index 0000000000..317672bf23 --- /dev/null +++ b/src/trio/_tests/test_deprecate_strict_exception_groups_false.py @@ -0,0 +1,61 @@ +from typing import Awaitable, Callable + +import pytest + +import trio + + +async def test_deprecation_warning_open_nursery() -> None: + with pytest.warns( + trio.TrioDeprecationWarning, match="strict_exception_groups=False" + ) as record: + async with trio.open_nursery(strict_exception_groups=False): + ... + assert len(record) == 1 + async with trio.open_nursery(strict_exception_groups=True): + ... + async with trio.open_nursery(): + ... + + +def test_deprecation_warning_run() -> None: + async def foo() -> None: ... + + async def foo_nursery() -> None: + # this should not raise a warning, even if it's implied loose + async with trio.open_nursery(): + ... + + async def foo_loose_nursery() -> None: + # this should raise a warning, even if specifying the parameter is redundant + async with trio.open_nursery(strict_exception_groups=False): + ... 
+ + def helper(fun: Callable[..., Awaitable[None]], num: int) -> None: + with pytest.warns( + trio.TrioDeprecationWarning, match="strict_exception_groups=False" + ) as record: + trio.run(fun, strict_exception_groups=False) + assert len(record) == num + + helper(foo, 1) + helper(foo_nursery, 1) + helper(foo_loose_nursery, 2) + + +def test_deprecation_warning_start_guest_run() -> None: + # "The simplest possible "host" loop." + from .._core._tests.test_guest_mode import trivial_guest_run + + async def trio_return(in_host: object) -> str: + await trio.lowlevel.checkpoint() + return "ok" + + with pytest.warns( + trio.TrioDeprecationWarning, match="strict_exception_groups=False" + ) as record: + trivial_guest_run( + trio_return, + strict_exception_groups=False, + ) + assert len(record) == 1 diff --git a/trio/_tests/test_dtls.py b/src/trio/_tests/test_dtls.py similarity index 98% rename from trio/_tests/test_dtls.py rename to src/trio/_tests/test_dtls.py index 7e63f6d2c9..d14edae25c 100644 --- a/trio/_tests/test_dtls.py +++ b/src/trio/_tests/test_dtls.py @@ -1,12 +1,11 @@ from __future__ import annotations import random -from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from itertools import count -from typing import NoReturn +from typing import TYPE_CHECKING, NoReturn -import attr +import attrs import pytest from trio._tests.pytest_plugin import skip_if_optional_else_raise @@ -25,13 +24,16 @@ from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + ca = trustme.CA() server_cert = ca.issue_cert("example.com") -server_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined] +server_ctx = SSL.Context(SSL.DTLS_METHOD) server_cert.configure_cert(server_ctx) -client_ctx = SSL.Context(SSL.DTLS_METHOD) # type: ignore[attr-defined] +client_ctx = SSL.Context(SSL.DTLS_METHOD) ca.configure_trust(client_ctx) @@ -96,7 +98,9 @@ async def test_smoke(ipv6: bool) -> None: await client_channel.send(b"goodbye") assert await client_channel.receive() == b"goodbye" - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="^openssl doesn't support sending empty DTLS packets$" + ): await client_channel.send(b"") client_channel.set_ciphertext_mtu(1234) @@ -156,7 +160,7 @@ async def route_packet(packet: UDPPacket) -> None: # elif op == "distort": # payload = bytearray(packet.payload) # payload[r.randrange(len(payload))] ^= 1 << r.randrange(8) - # packet = attr.evolve(packet, payload=payload) + # packet = attrs.evolve(packet, payload=payload) else: assert op == "deliver" print( @@ -276,18 +280,17 @@ async def test_client_multiplex() -> None: with pytest.raises(trio.ClosedResourceError): client_endpoint.connect(address1, client_ctx) + async def null_handler(_: object) -> None: # pragma: no cover + pass + async with trio.open_nursery() as nursery: with pytest.raises(trio.ClosedResourceError): - - async def null_handler(_: object) -> None: # pragma: no cover - pass - await nursery.start(client_endpoint.serve, server_ctx, null_handler) async def test_dtls_over_dgram_only() -> None: with trio.socket.socket() as s: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^DTLS requires a SOCK_DGRAM socket$"): DTLSEndpoint(s) @@ -482,7 +485,7 @@ def route_packet(packet: UDPPacket) -> None: offset = len(payload) - 1 cscope.cancel() payload[offset] ^= 0x01 - packet = attr.evolve(packet, payload=payload) + packet = attrs.evolve(packet, payload=payload) fn.deliver_packet(packet) 
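The ``attrs.evolve`` call just above (part of this PR's broader ``attr`` to ``attrs`` namespace migration) copies a frozen instance with one field replaced; a minimal sketch of the modern API used throughout this diff::

    import attrs

    @attrs.frozen  # modern spelling of @attr.s(frozen=True, slots=True)
    class Packet:
        payload: bytes

    p1 = Packet(b"hello")
    p2 = attrs.evolve(p1, payload=b"jello")  # new instance; p1 is unchanged
    assert (p1.payload, p2.payload) == (b"hello", b"jello")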
diff --git a/trio/_tests/test_exports.py b/src/trio/_tests/test_exports.py similarity index 84% rename from trio/_tests/test_exports.py rename to src/trio/_tests/test_exports.py index 7b38137887..32a2666e48 100644 --- a/trio/_tests/test_exports.py +++ b/src/trio/_tests/test_exports.py @@ -10,10 +10,9 @@ import socket as stdlib_socket import sys import types -from collections.abc import Iterator -from pathlib import Path +from pathlib import Path, PurePath from types import ModuleType -from typing import Iterable, Protocol +from typing import TYPE_CHECKING, Protocol import attrs import pytest @@ -26,6 +25,9 @@ from .._core._tests.tutil import slow from .pytest_plugin import RUN_SLOW +if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + mypy_cache_updated = False @@ -114,7 +116,7 @@ def iter_modules( # they might be using a newer version of Python with additional symbols which # won't be reflected in trio.socket, and this shouldn't cause downstream test # runs to start failing. -@pytest.mark.redistributors_should_skip +@pytest.mark.redistributors_should_skip() # Static analysis tools often have trouble with alpha releases, where Python's # internals are in flux, grammar may not have settled down, etc. @pytest.mark.skipif( @@ -144,13 +146,6 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: if getattr(module, name, None) is getattr(__future__, name): runtime_names.remove(name) - if tool in ("mypy", "pyright_verifytypes"): - # create py.typed file - py_typed_path = Path(trio.__file__).parent / "py.typed" - py_typed_exists = py_typed_path.exists() - if not py_typed_exists: # pragma: no branch - py_typed_path.write_text("") - if tool == "pylint": try: from pylint.lint import PyLinter @@ -162,6 +157,9 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: ast = linter.get_ast(module.__file__, modname) static_names = no_underscores(ast) # type: ignore[arg-type] elif tool == "jedi": + if sys.implementation.name != "cpython": + pytest.skip("jedi does not support pypy") + try: import jedi except ImportError as error: @@ -185,12 +183,13 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: _, modname = (modname + ".").split(".", 1) modname = modname[:-1] mod_cache = trio_cache / modname if modname else trio_cache - if mod_cache.is_dir(): + if mod_cache.is_dir(): # pragma: no coverage mod_cache = mod_cache / "__init__.data.json" else: mod_cache = trio_cache / (modname + ".data.json") - assert mod_cache.exists() and mod_cache.is_file() + assert mod_cache.exists() + assert mod_cache.is_file() with mod_cache.open() as cache_file: cache_json = json.loads(cache_file.read()) static_names = no_underscores( @@ -200,7 +199,8 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: ) elif tool == "pyright_verifytypes": if not RUN_SLOW: # pragma: no cover - pytest.skip("use --run-slow to check against mypy") + pytest.skip("use --run-slow to check against pyright") + try: import pyright # noqa: F401 except ImportError as error: @@ -218,30 +218,9 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: for x in current_result["typeCompleteness"]["symbols"] if x["name"].startswith(modname) } - - # pyright ignores the symbol defined behind `if False` - if modname == "trio": - static_names.add("testing") - - # these are hidden behind `if sys.platform != "win32" or not TYPE_CHECKING` - # so presumably pyright is parsing that if statement, in which case we don't - # care about them being missing. 
- if modname == "trio.socket" and sys.platform == "win32": - ignored_missing_names = {"if_indextoname", "if_nameindex", "if_nametoindex"} - assert static_names.isdisjoint(ignored_missing_names) - static_names.update(ignored_missing_names) - else: # pragma: no cover raise AssertionError() - # remove py.typed file - if tool in ("mypy", "pyright_verifytypes") and not py_typed_exists: - py_typed_path.unlink() - - # mypy handles errors with an `assert` in its branch - if tool == "mypy": - return - # It's expected that the static set will contain more names than the # runtime set: # - static tools are sometimes sloppy and include deleted names @@ -265,7 +244,7 @@ def no_underscores(symbols: Iterable[str]) -> set[str]: # modules, instead of once per class. @slow # see comment on test_static_tool_sees_all_symbols -@pytest.mark.redistributors_should_skip +@pytest.mark.redistributors_should_skip() # Static analysis tools often have trouble with alpha releases, where Python's # internals are in flux, grammar may not have settled down, etc. @pytest.mark.skipif( @@ -287,16 +266,9 @@ def no_hidden(symbols: Iterable[str]) -> set[str]: if (not symbol.startswith("_")) or symbol.startswith("__") } - py_typed_path = Path(trio.__file__).parent / "py.typed" - py_typed_exists = py_typed_path.exists() - if tool == "mypy": if sys.implementation.name != "cpython": pytest.skip("mypy not installed in tests on pypy") - # create py.typed file - # remove this logic when trio is marked with py.typed proper - if not py_typed_exists: # pragma: no branch - py_typed_path.write_text("") cache = Path.cwd() / ".mypy_cache" @@ -312,7 +284,8 @@ def no_hidden(symbols: Iterable[str]) -> set[str]: else: mod_cache = trio_cache / (modname + ".data.json") - assert mod_cache.exists() and mod_cache.is_file() + assert mod_cache.exists() + assert mod_cache.is_file() with mod_cache.open() as cache_file: cache_json = json.loads(cache_file.read()) @@ -329,11 +302,12 @@ def lookup_symbol(symbol: str) -> dict[str, str]: for piece in modname[:-1]: mod_cache /= piece next_cache = mod_cache / modname[-1] - if next_cache.is_dir(): + if next_cache.is_dir(): # pragma: no coverage mod_cache = next_cache / "__init__.data.json" else: mod_cache = mod_cache / (modname[-1] + ".data.json") - + elif mod_cache.is_dir(): + mod_cache /= "__init__.data.json" with mod_cache.open() as f: return json.loads(f.read())["names"][name] # type: ignore[no-any-return] @@ -343,9 +317,9 @@ def lookup_symbol(symbol: str) -> dict[str, str]: continue if module_name == "trio.socket" and class_name in dir(stdlib_socket): continue - # Deprecated classes are exported with a leading underscore - # We don't care about errors in _MultiError as that's on its way out anyway - if class_name.startswith("_"): # pragma: no cover + + # ignore class that does dirty tricks + if class_ is trio.testing.RaisesGroup: continue # dir() and inspect.getmembers doesn't display properties from the metaclass @@ -372,6 +346,11 @@ def lookup_symbol(symbol: str) -> dict[str, str]: "__deepcopy__", } + if type(class_) is type: + # C extension classes don't have these dunders, but Python classes do + ignore_names.add("__firstlineno__") + ignore_names.add("__static_attributes__") + # pypy seems to have some additional dunders that differ if sys.implementation.name == "pypy": ignore_names |= { @@ -451,15 +430,18 @@ def lookup_symbol(symbol: str) -> dict[str, str]: if ( tool == "mypy" and enum.Enum in class_.__mro__ - and sys.version_info >= (3, 11) + and sys.version_info >= (3, 12) ): - 
extra.difference_update({"__copy__", "__deepcopy__"}) + # Another attribute, in 3.12+ only. + extra.remove("__signature__") # TODO: this *should* be visible via `dir`!! if tool == "mypy" and class_ == trio.Nursery: extra.remove("cancel_scope") - # TODO: I'm not so sure about these, but should still be looked at. + # These are (mostly? solely?) *runtime* attributes, often set in + # __init__, which doesn't show up with dir() or inspect.getmembers, + # but we get them in the way we query mypy & jedi EXTRAS = { trio.DTLSChannel: {"peer_address", "endpoint"}, trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"}, @@ -474,18 +456,17 @@ def lookup_symbol(symbol: str) -> dict[str, str]: "send_all_hook", "wait_send_all_might_not_block_hook", }, + trio.testing.Matcher: { + "exception_type", + "match", + "check", + }, } if tool == "mypy" and class_ in EXTRAS: before = len(extra) extra -= EXTRAS[class_] assert len(extra) == before - len(EXTRAS[class_]) - # probably an issue with mypy.... - if tool == "mypy" and class_ == trio.Path and sys.platform == "win32": - before = len(missing) - missing -= {"owner", "group", "is_mount"} - assert len(missing) == before - 3 - # TODO: why is this? Is it a problem? # see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916 if class_ == trio.StapledStream: @@ -508,25 +489,23 @@ def lookup_symbol(symbol: str) -> dict[str, str]: missing.remove("__aiter__") missing.remove("__anext__") - # __getattr__ is intentionally hidden behind type guard. That hook then - # forwards property accesses to PurePath, meaning these names aren't directly on - # the class. - if class_ == trio.Path: - missing.remove("__getattr__") - before = len(extra) - extra -= { - "anchor", - "drive", - "name", - "parent", - "parents", - "parts", - "root", - "stem", - "suffix", - "suffixes", - } - assert len(extra) == before - 10 + if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath): + # These are from inherited subclasses. + missing -= PurePath.__dict__.keys() + # These are unix-only. + if tool == "mypy" and sys.platform == "win32": + missing -= {"owner", "is_mount", "group"} + if tool == "jedi" and sys.platform == "win32": + extra -= {"owner", "is_mount", "group"} + + # not sure why jedi in particular ignores this (static?) method in 3.13 + # (especially given the method is from 3.12....) + if ( + tool == "jedi" + and sys.version_info >= (3, 13) + and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath) + ): + missing.remove("with_segments") if missing or extra: # pragma: no cover errors[f"{module_name}.{class_name}"] = { @@ -534,10 +513,6 @@ def lookup_symbol(symbol: str) -> dict[str, str]: "extra": extra, } - # clean up created py.typed file - if tool == "mypy" and not py_typed_exists: - py_typed_path.unlink() - # `assert not errors` will not print the full content of errors, even with # `--verbose`, so we manually print it if errors: # pragma: no cover @@ -553,7 +528,7 @@ def test_nopublic_is_final() -> None: assert class_is_final(_util.NoPublicConstructor) # This is itself final. for module in ALL_MODULES: - for _name, class_ in module.__dict__.items(): + for class_ in module.__dict__.values(): if isinstance(class_, _util.NoPublicConstructor): assert class_is_final(class_) @@ -588,6 +563,9 @@ def test_classes_are_final() -> None: continue # ... insert other special cases here ... + # The `Path` class needs to support inheritance to allow `WindowsPath` and `PosixPath`. 
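This mirrors the standard library's design, where ``pathlib.Path()`` instantiates a platform-specific subclass; a quick illustration of the pattern that ``trio.Path`` now follows with ``trio.PosixPath`` and ``trio.WindowsPath``::

    import pathlib

    p = pathlib.Path(".")
    # Path() dispatches to a concrete subclass at construction time.
    print(type(p).__name__)  # "PosixPath" on Unix, "WindowsPath" on Windows
    assert isinstance(p, pathlib.PurePath)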
+ if class_ is trio.Path: + continue # don't care about the *Statistics classes if name.endswith("Statistics"): continue diff --git a/src/trio/_tests/test_fakenet.py b/src/trio/_tests/test_fakenet.py new file mode 100644 index 0000000000..bde6db0191 --- /dev/null +++ b/src/trio/_tests/test_fakenet.py @@ -0,0 +1,301 @@ +import errno +import re +import socket +import sys + +import pytest + +import trio +from trio.testing._fake_net import FakeNet + +# ENOTCONN gives different messages on different platforms +if sys.platform == "linux": + ENOTCONN_MSG = r"^\[Errno 107\] (Transport endpoint is|Socket) not connected$" +elif sys.platform == "darwin": + ENOTCONN_MSG = r"^\[Errno 57\] Socket is not connected$" +else: + ENOTCONN_MSG = r"^\[Errno 10057\] Unknown error$" + + +def fn() -> FakeNet: + fn = FakeNet() + fn.enable() + return fn + + +async def test_basic_udp() -> None: + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + + with pytest.raises( + OSError, match=r"^\[\w+ \d+\] Invalid argument$" + ) as exc: # Cannot rebind. + await s1.bind(("192.0.2.1", 0)) + assert exc.value.errno == errno.EINVAL + + # Cannot bind multiple sockets to the same address + with pytest.raises( + OSError, match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$" + ) as exc: + await s2.bind(("127.0.0.1", port)) + assert exc.value.errno == errno.EADDRINUSE + + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"xyz" + assert addr == s2.getsockname() + await s1.sendto(b"abc", s2.getsockname()) + data, addr = await s2.recvfrom(10) + assert data == b"abc" + assert addr == s1.getsockname() + + +async def test_msg_trunc() -> None: + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + await s2.sendto(b"xyz", s1.getsockname()) + data, addr = await s1.recvfrom(10) + + +async def test_recv_methods() -> None: + """Test all recv methods for codecov""" + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + # receiving on an unbound socket is a bad idea (I think?) 
+ with pytest.raises(NotImplementedError, match="code will most likely hang"): + await s2.recv(10) + + await s1.bind(("127.0.0.1", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + + # recvfrom + await s2.sendto(b"abc", s1.getsockname()) + data, addr = await s1.recvfrom(10) + assert data == b"abc" + assert addr == s2.getsockname() + + # recv + await s1.sendto(b"def", s2.getsockname()) + data = await s2.recv(10) + assert data == b"def" + + # recvfrom_into + assert await s1.sendto(b"ghi", s2.getsockname()) == 3 + buf = bytearray(10) + + with pytest.raises(NotImplementedError, match="^partial recvfrom_into$"): + (nbytes, addr) = await s2.recvfrom_into(buf, nbytes=2) + + (nbytes, addr) = await s2.recvfrom_into(buf) + assert nbytes == 3 + assert buf == b"ghi" + b"\x00" * 7 + assert addr == s1.getsockname() + + # recv_into + assert await s1.sendto(b"jkl", s2.getsockname()) == 3 + buf2 = bytearray(10) + nbytes = await s2.recv_into(buf2) + assert nbytes == 3 + assert buf2 == b"jkl" + b"\x00" * 7 + + if sys.platform == "linux" and sys.implementation.name == "cpython": + flags: int = socket.MSG_MORE + else: + flags = 1 + + # Send seems explicitly non-functional + with pytest.raises(OSError, match=ENOTCONN_MSG) as exc: + await s2.send(b"mno") + assert exc.value.errno == errno.ENOTCONN + with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"): + await s2.send(b"mno", flags) + + # sendto errors + # it's successfully used earlier + with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"): + await s2.sendto(b"mno", flags, s1.getsockname()) + with pytest.raises(TypeError, match="wrong number of arguments$"): + await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload] + + +@pytest.mark.skipif( + sys.platform == "win32", reason="functions not in socket on windows" +) +async def test_nonwindows_functionality() -> None: + # mypy doesn't support a good way of aborting typechecking on different platforms + if sys.platform != "win32": # pragma: no branch + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s2.bind(("127.0.0.1", 0)) + + # sendmsg + with pytest.raises(OSError, match=ENOTCONN_MSG) as exc: + await s2.sendmsg([b"mno"]) + assert exc.value.errno == errno.ENOTCONN + + assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3 + (data, ancdata, msg_flags, addr) = await s2.recvmsg(10) + assert data == b"jkl" + assert ancdata == [] + assert msg_flags == 0 + assert addr == s1.getsockname() + + # TODO: recvmsg + + # recvmsg_into + assert await s1.sendto(b"xyzw", s2.getsockname()) == 4 + buf1 = bytearray(2) + buf2 = bytearray(3) + ret = await s2.recvmsg_into([buf1, buf2]) + (nbytes, ancdata, msg_flags, addr) = ret + assert nbytes == 4 + assert buf1 == b"xy" + assert buf2 == b"zw" + b"\x00" + assert ancdata == [] + assert msg_flags == 0 + assert addr == s1.getsockname() + + # recvmsg_into with MSG_TRUNC set + assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5 + buf1 = bytearray(2) + ret = await s2.recvmsg_into([buf1]) + (nbytes, ancdata, msg_flags, addr) = ret + assert nbytes == 2 + assert buf1 == b"xy" + assert ancdata == [] + assert msg_flags == socket.MSG_TRUNC + assert addr == s1.getsockname() + + with pytest.raises( + AttributeError, match="^'FakeSocket' object has no attribute 'share'$" + ): + await s1.share(0) # type: ignore[attr-defined] + + +@pytest.mark.skipif( + sys.platform != "win32", 
reason="windows-specific fakesocket testing" +) +async def test_windows_functionality() -> None: + # mypy doesn't support a good way of aborting typechecking on different platforms + if sys.platform == "win32": # pragma: no branch + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + await s1.bind(("127.0.0.1", 0)) + with pytest.raises( + AttributeError, match="^'FakeSocket' object has no attribute 'sendmsg'$" + ): + await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined] + with pytest.raises( + AttributeError, match="^'FakeSocket' object has no attribute 'recvmsg'$" + ): + s2.recvmsg(0) # type: ignore[attr-defined] + with pytest.raises( + AttributeError, + match="^'FakeSocket' object has no attribute 'recvmsg_into'$", + ): + s2.recvmsg_into([]) # type: ignore[attr-defined] + with pytest.raises(NotImplementedError): + s1.share(0) + + +async def test_basic_tcp() -> None: + fn() + with pytest.raises(NotImplementedError): + trio.socket.socket() + + +async def test_not_implemented_functions() -> None: + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + + # getsockopt + with pytest.raises( + OSError, match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$" + ): + s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY) + + # setsockopt + with pytest.raises( + NotImplementedError, match="^FakeNet always has IPV6_V6ONLY=True$" + ): + s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False) + with pytest.raises( + OSError, match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$" + ): + s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True) + with pytest.raises( + OSError, match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$" + ): + s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + # set_inheritable + s1.set_inheritable(False) + with pytest.raises( + NotImplementedError, match="^FakeNet can't make inheritable sockets$" + ): + s1.set_inheritable(True) + + # get_inheritable + assert not s1.get_inheritable() + + +async def test_getpeername() -> None: + fn() + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + with pytest.raises(OSError, match=ENOTCONN_MSG) as exc: + s1.getpeername() + assert exc.value.errno == errno.ENOTCONN + + await s1.bind(("127.0.0.1", 0)) + + with pytest.raises( + AssertionError, + match="^This method seems to assume that self._binding has a remote UDPEndpoint$", + ): + s1.getpeername() + + +async def test_init() -> None: + fn() + with pytest.raises( + NotImplementedError, + match=re.escape( + f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}" + ), + ): + s1 = trio.socket.socket() + + # getsockname on unbound ipv4 socket + s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) + assert s1.getsockname() == ("0.0.0.0", 0) + + # getsockname on bound ipv4 socket + await s1.bind(("0.0.0.0", 0)) + ip, port = s1.getsockname() + assert ip == "127.0.0.1" + assert port != 0 + + # getsockname on unbound ipv6 socket + s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM) + assert s2.getsockname() == ("::", 0) + + # getsockname on bound ipv6 socket + await s2.bind(("::", 0)) + ip, port, *_ = s2.getsockname() + assert ip == "::1" + assert port != 0 + assert _ == [0, 0] diff --git a/trio/_tests/test_file_io.py b/src/trio/_tests/test_file_io.py similarity index 97% rename from trio/_tests/test_file_io.py rename to src/trio/_tests/test_file_io.py index 85b8324c56..cd02d4f768 100644 --- a/trio/_tests/test_file_io.py +++ 
b/src/trio/_tests/test_file_io.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import importlib import io import os -import pathlib import re -from typing import List, Tuple +from typing import TYPE_CHECKING from unittest import mock from unittest.mock import sentinel @@ -13,6 +14,9 @@ from trio import _core, _file_io from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper +if TYPE_CHECKING: + import pathlib + @pytest.fixture def path(tmp_path: pathlib.Path) -> str: @@ -109,7 +113,7 @@ def test_type_stubs_match_lists() -> None: pytest.fail("No TYPE CHECKING line?") # Now we should be at the type checking block. - found: List[Tuple[str, str]] = [] + found: list[tuple[str, str]] = [] for line in source: # pragma: no branch - expected to break early if line.strip() and not line.startswith(" " * 8): break # Dedented out of the if TYPE_CHECKING block. @@ -222,11 +226,9 @@ async def test_open_context_manager(path: pathlib.Path) -> None: async def test_async_iter() -> None: async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar")) expected = list(async_file.wrapped) - result = [] async_file.wrapped.seek(0) - async for line in async_file: - result.append(line) + result = [line async for line in async_file] assert result == expected @@ -249,7 +251,7 @@ async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None: tmp_file = tmp_path / "filename" tmp_file.touch() # flake8-async does not like opening files in async mode - with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC101 + with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC230 buffered = io.BufferedReader(raw) async_file = trio.wrap_file(buffered) diff --git a/trio/_tests/test_highlevel_generic.py b/src/trio/_tests/test_highlevel_generic.py similarity index 88% rename from trio/_tests/test_highlevel_generic.py rename to src/trio/_tests/test_highlevel_generic.py index 3e9fc212a8..1fcedba749 100644 --- a/trio/_tests/test_highlevel_generic.py +++ b/src/trio/_tests/test_highlevel_generic.py @@ -2,16 +2,16 @@ from typing import NoReturn -import attr +import attrs import pytest from .._highlevel_generic import StapledStream from ..abc import ReceiveStream, SendStream -@attr.s +@attrs.define(slots=False) class RecordSendStream(SendStream): - record: list[str | tuple[str, object]] = attr.ib(factory=list) + record: list[str | tuple[str, object]] = attrs.Factory(list) async def send_all(self, data: object) -> None: self.record.append(("send_all", data)) @@ -23,9 +23,9 @@ async def aclose(self) -> None: self.record.append("aclose") -@attr.s +@attrs.define(slots=False) class RecordReceiveStream(ReceiveStream): - record: list[str | tuple[str, int | None]] = attr.ib(factory=list) + record: list[str | tuple[str, int | None]] = attrs.Factory(list) async def receive_some(self, max_bytes: int | None = None) -> bytes: self.record.append(("receive_some", max_bytes)) @@ -81,16 +81,16 @@ async def test_StapledStream_with_erroring_close() -> None: class BrokenSendStream(RecordSendStream): async def aclose(self) -> NoReturn: await super().aclose() - raise ValueError + raise ValueError("send error") class BrokenReceiveStream(RecordReceiveStream): async def aclose(self) -> NoReturn: await super().aclose() - raise ValueError + raise ValueError("recv error") stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream()) - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo: await stapled.aclose() assert 
isinstance(excinfo.value.__context__, ValueError) diff --git a/trio/_tests/test_highlevel_open_tcp_listeners.py b/src/trio/_tests/test_highlevel_open_tcp_listeners.py similarity index 84% rename from trio/_tests/test_highlevel_open_tcp_listeners.py rename to src/trio/_tests/test_highlevel_open_tcp_listeners.py index 23d6f794e0..3196c9e533 100644 --- a/trio/_tests/test_highlevel_open_tcp_listeners.py +++ b/src/trio/_tests/test_highlevel_open_tcp_listeners.py @@ -6,13 +6,12 @@ from socket import AddressFamily, SocketKind from typing import TYPE_CHECKING, Any, Sequence, overload -import attr +import attrs import pytest import trio from trio import ( SocketListener, - TrioDeprecationWarning, open_tcp_listeners, open_tcp_stream, serve_tcp, @@ -77,7 +76,10 @@ async def test_open_tcp_listeners_ipv6_v6only() -> None: async with ipv6_listener: _, port, *_ = ipv6_listener.socket.getsockname() - with pytest.raises(OSError): + with pytest.raises( + OSError, + match=r"(Error|all attempts to) connect(ing)* to (\(')*127\.0\.0\.1(', |:)\d+(\): Connection refused| failed)$", + ): await open_tcp_stream("127.0.0.1", port) @@ -89,7 +91,10 @@ async def test_open_tcp_listeners_rebind() -> None: # SO_REUSEADDR set with stdlib_socket.socket() as probe: probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1) - with pytest.raises(OSError): + with pytest.raises( + OSError, + match="(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$", + ): probe.bind(sockaddr1) # Now use the first listener to set up some connections in various states, @@ -126,15 +131,15 @@ class FakeOSError(OSError): pass -@attr.s +@attrs.define(slots=False) class FakeSocket(tsocket.SocketType): - _family: AddressFamily = attr.ib(converter=AddressFamily) - _type: SocketKind = attr.ib(converter=SocketKind) - _proto: int = attr.ib() + _family: AddressFamily = attrs.field(converter=AddressFamily) + _type: SocketKind = attrs.field(converter=SocketKind) + _proto: int - closed: bool = attr.ib(default=False) - poison_listen: bool = attr.ib(default=False) - backlog: int | None = attr.ib(default=None) + closed: bool = False + poison_listen: bool = False + backlog: int | None = None @property def type(self) -> SocketKind: @@ -149,12 +154,10 @@ def proto(self) -> int: # pragma: no cover return self._proto @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... + def getsockopt(self, /, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: - ... + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... def getsockopt( self, /, level: int, optname: int, buflen: int | None = None @@ -164,12 +167,12 @@ def getsockopt( raise AssertionError() # pragma: no cover @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... @overload - def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: - ... + def setsockopt( + self, /, level: int, optname: int, value: None, optlen: int + ) -> None: ... 
def setsockopt( self, @@ -195,25 +198,25 @@ def close(self) -> None: self.closed = True -@attr.s +@attrs.define(slots=False) class FakeSocketFactory(SocketFactory): - poison_after: int = attr.ib() - sockets: list[tsocket.SocketType] = attr.ib(factory=list) - raise_on_family: dict[AddressFamily, int] = attr.ib(factory=dict) # family => errno + poison_after: int + sockets: list[tsocket.SocketType] = attrs.Factory(list) + raise_on_family: dict[AddressFamily, int] = attrs.Factory(dict) # family => errno def socket( self, family: AddressFamily | int | None = None, - type: SocketKind | int | None = None, + type_: SocketKind | int | None = None, proto: int = 0, ) -> tsocket.SocketType: assert family is not None - assert type is not None + assert type_ is not None if isinstance(family, int) and not isinstance(family, AddressFamily): family = AddressFamily(family) # pragma: no cover if family in self.raise_on_family: raise OSError(self.raise_on_family[family], "nope") - sock = FakeSocket(family, type, proto) + sock = FakeSocket(family, type_, proto) self.poison_after -= 1 if self.poison_after == 0: sock.poison_listen = True @@ -221,13 +224,13 @@ def socket( return sock -@attr.s +@attrs.define(slots=False) class FakeHostnameResolver(HostnameResolver): - family_addr_pairs: Sequence[tuple[AddressFamily, str]] = attr.ib() + family_addr_pairs: Sequence[tuple[AddressFamily, str]] async def getaddrinfo( self, - host: bytes | str | None, + host: bytes | None, port: bytes | str | int | None, family: int = 0, type: int = 0, @@ -323,16 +326,15 @@ async def test_open_tcp_listeners_some_address_families_unavailable( should_succeed = try_families - fail_families if not should_succeed: - with pytest.raises(OSError) as exc_info: + with pytest.raises(OSError, match="This system doesn't support") as exc_info: await open_tcp_listeners(80, host="example.org") - assert "This system doesn't support" in str(exc_info.value) - if isinstance(exc_info.value.__cause__, BaseExceptionGroup): - for subexc in exc_info.value.__cause__.exceptions: - assert "nope" in str(subexc) - else: - assert isinstance(exc_info.value.__cause__, OSError) - assert "nope" in str(exc_info.value.__cause__) + # open_listeners always creates an exceptiongroup with the + # unsupported address families, regardless of the value of + # strict_exception_groups or number of unsupported families. 
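+ # (each subexception here is the OSError("nope") that FakeSocketFactory raised for one of the unavailable address families)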
+ assert isinstance(exc_info.value.__cause__, BaseExceptionGroup) + for subexc in exc_info.value.__cause__.exceptions: + assert "nope" in str(subexc) else: listeners = await open_tcp_listeners(80) for listener in listeners: @@ -353,7 +355,7 @@ async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None: FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]) ) - with pytest.raises(OSError) as exc_info: + with pytest.raises(OSError, match="nope") as exc_info: await open_tcp_listeners(80, host="example.org") assert exc_info.value.errno == errno.EINVAL assert exc_info.value.__cause__ is None @@ -384,17 +386,6 @@ async def test_open_tcp_listeners_backlog() -> None: assert listener.socket.backlog == expected # type: ignore[attr-defined] -async def test_open_tcp_listeners_backlog_inf_warning() -> None: - fsf = FakeSocketFactory(99) - tsocket.set_custom_socket_factory(fsf) - with pytest.warns(TrioDeprecationWarning): - listeners = await open_tcp_listeners(0, backlog=float("inf")) # type: ignore[arg-type] - assert listeners - for listener in listeners: - # `backlog` only exists on FakeSocket - assert listener.socket.backlog == 0xFFFF # type: ignore[attr-defined] - - async def test_open_tcp_listeners_backlog_float_error() -> None: fsf = FakeSocketFactory(99) tsocket.set_custom_socket_factory(fsf) diff --git a/trio/_tests/test_highlevel_open_tcp_stream.py b/src/trio/_tests/test_highlevel_open_tcp_stream.py similarity index 92% rename from trio/_tests/test_highlevel_open_tcp_stream.py rename to src/trio/_tests/test_highlevel_open_tcp_stream.py index 79dd8b0f78..ce1b1ac1de 100644 --- a/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/src/trio/_tests/test_highlevel_open_tcp_stream.py @@ -5,7 +5,7 @@ from socket import AddressFamily, SocketKind from typing import TYPE_CHECKING, Any, Sequence -import attr +import attrs import pytest import trio @@ -16,6 +16,7 @@ reorder_for_rfc_6555_section_5_4, ) from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType +from trio.testing import Matcher, RaisesGroup if TYPE_CHECKING: from trio.testing import MockClock @@ -33,7 +34,7 @@ def close(self) -> None: class CloseKiller(SocketType): def close(self) -> None: - raise OSError + raise OSError("os error text") c: CloseMe = CloseMe() with close_all() as to_close: @@ -48,7 +49,7 @@ def close(self) -> None: assert c.closed c = CloseMe() - with pytest.raises(OSError): + with pytest.raises(OSError, match="os error text"): with close_all() as to_close: to_close.add(CloseKiller()) to_close.add(c) @@ -122,7 +123,7 @@ async def test_open_tcp_stream_real_socket_smoketest() -> None: async def test_open_tcp_stream_input_validation() -> None: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^host must be str or bytes, not None$"): await open_tcp_stream(None, 80) # type: ignore[arg-type] with pytest.raises(TypeError): await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type] @@ -170,7 +171,9 @@ async def test_local_address_real() -> None: # Trying to connect to an ipv4 address with the ipv6 wildcard # local_address should fail - with pytest.raises(OSError): + with pytest.raises( + OSError, match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$" + ): await open_tcp_stream(*listener.getsockname(), local_address="::") # But the ipv4 wildcard address should work @@ -185,18 +188,18 @@ async def test_local_address_real() -> None: # Now, thorough tests using fake sockets -@attr.s(eq=False) +@attrs.define(eq=False, slots=False) class 
FakeSocket(trio.socket.SocketType): - scenario: Scenario = attr.ib() - _family: AddressFamily = attr.ib() - _type: SocketKind = attr.ib() - _proto: int = attr.ib() + scenario: Scenario + _family: AddressFamily + _type: SocketKind + _proto: int - ip: str | int | None = attr.ib(default=None) - port: str | int | None = attr.ib(default=None) - succeeded: bool = attr.ib(default=False) - closed: bool = attr.ib(default=False) - failing: bool = attr.ib(default=False) + ip: str | int | None = None + port: str | int | None = None + succeeded: bool = False + closed: bool = False + failing: bool = False @property def type(self) -> SocketKind: @@ -262,20 +265,18 @@ def __init__( def socket( self, family: AddressFamily | int | None = None, - type: SocketKind | int | None = None, + type_: SocketKind | int | None = None, proto: int | None = None, ) -> SocketType: assert isinstance(family, AddressFamily) - assert isinstance(type, SocketKind) + assert isinstance(type_, SocketKind) assert proto is not None if family not in self.supported_families: raise OSError("pretending not to support this family") self.socket_count += 1 - return FakeSocket(self, family, type, proto) + return FakeSocket(self, family, type_, proto) - def _ip_to_gai_entry( - self, ip: str - ) -> tuple[ + def _ip_to_gai_entry(self, ip: str) -> tuple[ AddressFamily, SocketKind, int, @@ -293,7 +294,7 @@ def _ip_to_gai_entry( async def getaddrinfo( self, - host: str | bytes | None, + host: bytes | None, port: bytes | str | int | None, family: int = -1, type: int = -1, @@ -472,6 +473,28 @@ async def test_custom_delay(autojump_clock: MockClock) -> None: } + +async def test_none_default(autojump_clock: MockClock) -> None: + """Copy of test_basic_fallthrough, but specifying happy_eyeballs_delay=None.""" + sock, scenario = await run_scenario( + 80, + [ + ("1.1.1.1", 1, "success"), + ("2.2.2.2", 1, "success"), + ("3.3.3.3", 0.2, "success"), + ], + happy_eyeballs_delay=None, + ) + assert isinstance(sock, FakeSocket) + assert sock.ip == "3.3.3.3" + # current time is default delay + default delay + connection time + assert trio.current_time() == (0.250 + 0.250 + 0.2) + assert scenario.connect_times == { + "1.1.1.1": 0, + "2.2.2.2": 0.250, + "3.3.3.3": 0.500, + } + + async def test_custom_errors_expedite(autojump_clock: MockClock) -> None: sock, scenario = await run_scenario( 80, @@ -506,8 +529,12 @@ async def test_all_fail(autojump_clock: MockClock) -> None: expect_error=OSError, ) assert isinstance(exc, OSError) - assert isinstance(exc.__cause__, BaseExceptionGroup) - assert len(exc.__cause__.exceptions) == 4 + + subexceptions = (Matcher(OSError, match="^sorry$"),) * 4 + assert RaisesGroup( + *subexceptions, match="all attempts to connect to test.example.com:80 failed" + ).matches(exc.__cause__) + assert trio.current_time() == (0.1 + 0.2 + 10) assert scenario.connect_times == { "1.1.1.1": 0, diff --git a/trio/_tests/test_highlevel_open_unix_stream.py b/src/trio/_tests/test_highlevel_open_unix_stream.py similarity index 81% rename from trio/_tests/test_highlevel_open_unix_stream.py rename to src/trio/_tests/test_highlevel_open_unix_stream.py index 0ff11209a7..38c31b8a4a 100644 --- a/trio/_tests/test_highlevel_open_unix_stream.py +++ b/src/trio/_tests/test_highlevel_open_unix_stream.py @@ -11,10 +11,12 @@ assert not TYPE_CHECKING or sys.platform != "win32" -if not hasattr(socket, "AF_UNIX"): - pytestmark = pytest.mark.skip("Needs unix socket support") +skip_if_not_unix = pytest.mark.skipif( + not hasattr(socket, "AF_UNIX"), reason="Needs unix socket support" +)
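+# Applied per-test rather than as a module-wide pytestmark, so that +# test_error_on_no_unix at the bottom of this file can still run on +# platforms without AF_UNIX.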
+@skip_if_not_unix def test_close_on_error() -> None: class CloseMe: closed = False @@ -32,12 +34,14 @@ def close(self) -> None: assert c.closed +@skip_if_not_unix @pytest.mark.parametrize("filename", [4, 4.5]) async def test_open_with_bad_filename_type(filename: float) -> None: with pytest.raises(TypeError): await open_unix_socket(filename) # type: ignore[arg-type] +@skip_if_not_unix async def test_open_bad_socket() -> None: # mktemp is marked as insecure, but that's okay, we don't want the file to # exist @@ -46,6 +50,7 @@ async def test_open_bad_socket() -> None: await open_unix_socket(name) +@skip_if_not_unix async def test_open_unix_socket() -> None: for name_type in [Path, str]: name = tempfile.mktemp() @@ -69,3 +74,11 @@ async def test_open_unix_socket() -> None: assert received == b"response" finally: os.unlink(name) + + +@pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms") +async def test_error_on_no_unix() -> None: + with pytest.raises( + RuntimeError, match="^Unix sockets are not supported on this platform$" + ): + await open_unix_socket("") diff --git a/trio/_tests/test_highlevel_serve_listeners.py b/src/trio/_tests/test_highlevel_serve_listeners.py similarity index 86% rename from trio/_tests/test_highlevel_serve_listeners.py rename to src/trio/_tests/test_highlevel_serve_listeners.py index 86fe1af2ad..a1457de3d8 100644 --- a/trio/_tests/test_highlevel_serve_listeners.py +++ b/src/trio/_tests/test_highlevel_serve_listeners.py @@ -2,37 +2,42 @@ import errno from functools import partial -from typing import Awaitable, Callable, NoReturn +from typing import TYPE_CHECKING, Awaitable, Callable, NoReturn -import attr -import pytest +import attrs import trio from trio import Nursery, StapledStream, TaskStatus -from trio._channel import MemoryReceiveChannel, MemorySendChannel -from trio.abc import Stream from trio.testing import ( + Matcher, MemoryReceiveStream, MemorySendStream, MockClock, + RaisesGroup, memory_stream_pair, wait_all_tasks_blocked, ) +if TYPE_CHECKING: + import pytest + + from trio._channel import MemoryReceiveChannel, MemorySendChannel + from trio.abc import Stream + # types are somewhat tentative - I just bruteforced them until I got something that didn't # give errors StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream] -@attr.s(hash=False, eq=False) +@attrs.define(hash=False, eq=False, slots=False) class MemoryListener(trio.abc.Listener[StapledMemoryStream]): - closed: bool = attr.ib(default=False) - accepted_streams: list[trio.abc.Stream] = attr.ib(factory=list) + closed: bool = False + accepted_streams: list[trio.abc.Stream] = attrs.Factory(list) queued_streams: tuple[ MemorySendChannel[StapledMemoryStream], MemoryReceiveChannel[StapledMemoryStream], - ] = attr.ib(factory=(lambda: trio.open_memory_channel[StapledMemoryStream](1))) - accept_hook: Callable[[], Awaitable[object]] | None = attr.ib(default=None) + ] = attrs.Factory(lambda: trio.open_memory_channel[StapledMemoryStream](1)) + accept_hook: Callable[[], Awaitable[object]] | None = None async def connect(self) -> StapledMemoryStream: assert not self.closed @@ -108,11 +113,13 @@ async def test_serve_listeners_accept_unrecognized_error() -> None: async def raise_error() -> NoReturn: raise error # noqa: B023 # Set from loop + def check_error(e: BaseException) -> bool: + return e is error # noqa: B023 + listener.accept_hook = raise_error - with pytest.raises(type(error)) as excinfo: + with RaisesGroup(Matcher(check=check_error)): await 
trio.serve_listeners(None, [listener]) # type: ignore[arg-type] - assert excinfo.value is error async def test_serve_listeners_accept_capacity_error( @@ -156,7 +163,8 @@ async def connection_watcher( assert len(nursery.child_tasks) == 10 raise Done - with pytest.raises(Done): + # the exception is wrapped twice because we open two nested nurseries + with RaisesGroup(RaisesGroup(Done)): async with trio.open_nursery() as nursery: handler_nursery: trio.Nursery = await nursery.start(connection_watcher) await nursery.start( diff --git a/trio/_tests/test_highlevel_socket.py b/src/trio/_tests/test_highlevel_socket.py similarity index 92% rename from trio/_tests/test_highlevel_socket.py rename to src/trio/_tests/test_highlevel_socket.py index 61e891e94b..976a3b5e04 100644 --- a/trio/_tests/test_highlevel_socket.py +++ b/src/trio/_tests/test_highlevel_socket.py @@ -26,7 +26,9 @@ async def test_SocketStream_basics() -> None: # DGRAM socket bad with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="^SocketStream requires a SOCK_STREAM socket$" + ): # TODO: does not raise an error? SocketStream(sock) @@ -152,7 +154,9 @@ async def test_SocketListener() -> None: # Not a SOCK_STREAM with tsocket.socket(type=tsocket.SOCK_DGRAM) as s: await s.bind(("127.0.0.1", 0)) - with pytest.raises(ValueError) as excinfo: + with pytest.raises( + ValueError, match="^SocketListener requires a SOCK_STREAM socket$" + ) as excinfo: SocketListener(s) excinfo.match(r".*SOCK_STREAM") @@ -161,7 +165,9 @@ async def test_SocketListener() -> None: if sys.platform != "darwin": with tsocket.socket() as s: await s.bind(("127.0.0.1", 0)) - with pytest.raises(ValueError) as excinfo: + with pytest.raises( + ValueError, match="^SocketListener requires a listening socket$" + ) as excinfo: SocketListener(s) excinfo.match(r".*listen") @@ -218,14 +224,12 @@ def __init__(self, events: Sequence[SocketType | BaseException]) -> None: # Fool the check for SO_ACCEPTCONN in SocketListener.__init__ @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... + def getsockopt(self, /, level: int, optname: int) -> int: ... @overload def getsockopt( # noqa: F811 self, /, level: int, optname: int, buflen: int - ) -> bytes: - ... + ) -> bytes: ... def getsockopt( # noqa: F811 self, /, level: int, optname: int, buflen: int | None = None @@ -233,14 +237,14 @@ def getsockopt( # noqa: F811 return True @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... + def setsockopt( + self, /, level: int, optname: int, value: int | Buffer + ) -> None: ... @overload def setsockopt( # noqa: F811 self, /, level: int, optname: int, value: None, optlen: int - ) -> None: - ... + ) -> None: ... 
def setsockopt( # noqa: F811 self, @@ -281,9 +285,13 @@ async def accept(self) -> tuple[SocketType, object]: stream = await listener.accept() assert stream.socket is fake_server_sock - for code in [errno.EMFILE, errno.EFAULT, errno.ENOBUFS]: + for code, match in { + errno.EMFILE: r"\[\w+ \d+\] Out of file descriptors$", + errno.EFAULT: r"\[\w+ \d+\] attempt to write to read-only memory$", + errno.ENOBUFS: r"\[\w+ \d+\] out of buffers$", + }.items(): with assert_checkpoints(): - with pytest.raises(OSError) as excinfo: + with pytest.raises(OSError, match=match) as excinfo: await listener.accept() assert excinfo.value.errno == code diff --git a/trio/_tests/test_highlevel_ssl_helpers.py b/src/trio/_tests/test_highlevel_ssl_helpers.py similarity index 92% rename from trio/_tests/test_highlevel_ssl_helpers.py rename to src/trio/_tests/test_highlevel_ssl_helpers.py index 8e90adb3d2..53f687d7c3 100644 --- a/trio/_tests/test_highlevel_ssl_helpers.py +++ b/src/trio/_tests/test_highlevel_ssl_helpers.py @@ -1,29 +1,33 @@ from __future__ import annotations from functools import partial -from socket import AddressFamily, SocketKind -from ssl import SSLContext -from typing import Any, NoReturn +from typing import TYPE_CHECKING, Any, NoReturn -import attr +import attrs import pytest import trio import trio.testing -from trio.abc import Stream from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM -from .._highlevel_socket import SocketListener from .._highlevel_ssl_helpers import ( open_ssl_over_tcp_listeners, open_ssl_over_tcp_stream, serve_ssl_over_tcp, ) -from .._ssl import SSLListener # using noqa because linters don't understand how pytest fixtures work. from .test_ssl import SERVER_CTX, client_ctx # noqa: F401 +if TYPE_CHECKING: + from socket import AddressFamily, SocketKind + from ssl import SSLContext + + from trio.abc import Stream + + from .._highlevel_socket import SocketListener + from .._ssl import SSLListener + async def echo_handler(stream: Stream) -> None: async with stream: @@ -39,13 +43,13 @@ async def echo_handler(stream: Stream) -> None: # Resolver that always returns the given sockaddr, no matter what host/port # you ask for. 
-@attr.s +@attrs.define(slots=False) class FakeHostnameResolver(trio.abc.HostnameResolver): - sockaddr: tuple[str, int] | tuple[str, int, int, int] = attr.ib() + sockaddr: tuple[str, int] | tuple[str, int, int, int] async def getaddrinfo( self, - host: bytes | str | None, + host: bytes | None, port: bytes | str | int | None, family: int = 0, type: int = 0, diff --git a/trio/_tests/test_path.py b/src/trio/_tests/test_path.py similarity index 84% rename from trio/_tests/test_path.py rename to src/trio/_tests/test_path.py index 30158f8b18..af29a0604b 100644 --- a/trio/_tests/test_path.py +++ b/src/trio/_tests/test_path.py @@ -2,14 +2,15 @@ import os import pathlib -from collections.abc import Awaitable, Callable -from typing import Any, Type, Union +from typing import TYPE_CHECKING, Type, Union import pytest import trio from trio._file_io import AsyncIOWrapper -from trio._path import AsyncAutoWrapperType as WrapperType + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable @pytest.fixture @@ -26,6 +27,16 @@ def method_pair( return getattr(sync_path, method_name), getattr(async_path, method_name) +@pytest.mark.skipif(os.name == "nt", reason="OS is not posix") +async def test_instantiate_posix() -> None: + assert isinstance(trio.Path(), trio.PosixPath) + + +@pytest.mark.skipif(os.name != "nt", reason="OS is not Windows") +async def test_instantiate_windows() -> None: + assert isinstance(trio.Path(), trio.WindowsPath) + + async def test_open_is_async_context_manager(path: trio.Path) -> None: async with await path.open("w") as f: assert isinstance(f, AsyncIOWrapper) @@ -49,7 +60,7 @@ async def test_magic() -> None: ] -@pytest.mark.parametrize("cls_a,cls_b", cls_pairs) +@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs) async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None: a, b = cls_a(""), cls_b("") assert a == b @@ -76,7 +87,7 @@ async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None: ] -@pytest.mark.parametrize("cls_a,cls_b", cls_pairs_str) +@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str) async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None: a, b = cls_a("a"), cls_b("b") @@ -87,7 +98,7 @@ async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None: @pytest.mark.parametrize( - "cls_a,cls_b", [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] + ("cls_a", "cls_b"), [(trio.Path, pathlib.Path), (trio.Path, trio.Path)] ) @pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"]) async def test_hash_magic( @@ -111,7 +122,7 @@ async def test_async_method_signature(path: trio.Path) -> None: assert path.resolve.__qualname__ == "Path.resolve" assert path.resolve.__doc__ is not None - assert "pathlib.Path.resolve" in path.resolve.__doc__ + assert path.resolve.__qualname__ in path.resolve.__doc__ @pytest.mark.parametrize("method_name", ["is_dir", "is_file"]) @@ -166,41 +177,6 @@ async def test_repr() -> None: assert repr(path) == "trio.Path('.')" -class MockWrapped: - unsupported = "unsupported" - _private = "private" - - -class _MockWrapper: - _forwards = MockWrapped - _wraps = MockWrapped - - -MockWrapper: Any = _MockWrapper # Disable type checking, it's a mock. 
- - -async def test_type_forwards_unsupported() -> None: - with pytest.raises(TypeError): - WrapperType.generate_forwards(MockWrapper, {}) - - -async def test_type_wraps_unsupported() -> None: - with pytest.raises(TypeError): - WrapperType.generate_wraps(MockWrapper, {}) - - -async def test_type_forwards_private() -> None: - WrapperType.generate_forwards(MockWrapper, {"unsupported": None}) - - assert not hasattr(MockWrapper, "_private") - - -async def test_type_wraps_private() -> None: - WrapperType.generate_wraps(MockWrapper, {"unsupported": None}) - - assert not hasattr(MockWrapper, "_private") - - @pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath]) async def test_path_wraps_path( path: trio.Path, @@ -276,3 +252,21 @@ async def test_classmethods() -> None: # Wrapped method has docstring assert trio.Path.home.__doc__ + + +@pytest.mark.parametrize( + "wrapper", + [ + trio._path._wraps_async, + trio._path._wrap_method, + trio._path._wrap_method_path, + trio._path._wrap_method_path_iterable, + ], +) +def test_wrapping_without_docstrings( + wrapper: Callable[[Callable[[], None]], Callable[[], None]] +) -> None: + @wrapper + def func_without_docstring() -> None: ... # pragma: no cover + + assert func_without_docstring.__doc__ is None diff --git a/src/trio/_tests/test_repl.py b/src/trio/_tests/test_repl.py new file mode 100644 index 0000000000..fbfdb07a05 --- /dev/null +++ b/src/trio/_tests/test_repl.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import subprocess +import sys +from typing import Protocol + +import pytest + +import trio._repl + + +class RawInput(Protocol): + def __call__(self, prompt: str = "") -> str: ... + + +def build_raw_input(cmds: list[str]) -> RawInput: + """ + Pass in a list of strings. + Returns a callable that returns each string in turn, each time it's called. + When there are no more strings to return, it raises EOFError. + """ + cmds_iter = iter(cmds) + prompts = [] + + def _raw_helper(prompt: str = "") -> str: + prompts.append(prompt) + try: + return next(cmds_iter) + except StopIteration: + raise EOFError from None + + return _raw_helper + + +def test_build_raw_input() -> None: + """Quick test of our helper function.""" + raw_input = build_raw_input(["cmd1"]) + assert raw_input() == "cmd1" + with pytest.raises(EOFError): + raw_input() + + +# In 3.10 or later, types.FunctionType (used internally) will automatically +# attach __builtins__ to the function objects. However, we need to explicitly +# include it for 3.8 & 3.9 +def build_locals() -> dict[str, object]: + return {"__builtins__": __builtins__} + + +async def test_basic_interaction( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + """ + Run some basic commands through the interpreter while capturing stdout. + Ensure that the interpreter prints the expected results.
+ """ + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + # evaluate simple expression and recall the value + "x = 1", + "print(f'{x=}')", + # Literal gets printed + "'hello'", + # define and call sync function + "def func():", + " print(x + 1)", + "", + "func()", + # define and call async function + "async def afunc():", + " return 4", + "", + "await afunc()", + # import works + "import sys", + "sys.stdout.write('hello stdout\\n')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"] + + +async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None: + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + "raise SystemExit", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + with pytest.raises(SystemExit): + await trio._repl.run_repl(console) + + +async def test_KI_interrupts( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + "from trio._util import signal_raise", + "import signal, trio, trio.lowlevel", + "async def f():", + " trio.lowlevel.spawn_system_task(" + " trio.to_thread.run_sync," + " signal_raise,signal.SIGINT," + " )", # just awaiting this kills the test runner?! + " await trio.sleep_forever()", + " print('should not see this')", + "", + "await f()", + "print('AFTER KeyboardInterrupt')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + assert "KeyboardInterrupt" in err + assert "should" not in out + assert "AFTER KeyboardInterrupt" in out + + +async def test_system_exits_in_exc_group( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + "import sys", + "if sys.version_info < (3, 11):", + " from exceptiongroup import BaseExceptionGroup", + "", + "raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])", + "print('AFTER BaseExceptionGroup')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + # assert that raise SystemExit in an exception group + # doesn't quit + assert "AFTER BaseExceptionGroup" in out + + +async def test_system_exits_in_nested_exc_group( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + "import sys", + "if sys.version_info < (3, 11):", + " from exceptiongroup import BaseExceptionGroup", + "", + "raise BaseExceptionGroup(", + " '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])", + "print('AFTER BaseExceptionGroup')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + # assert that raise SystemExit in an exception group + # doesn't quit + assert "AFTER BaseExceptionGroup" in out + + +async def test_base_exception_captured( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + console = 
trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + # The statement after raise should still get executed + "raise BaseException", + "print('AFTER BaseException')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + assert "_threads.py" not in err + assert "_repl.py" not in err + assert "AFTER BaseException" in out + + +async def test_exc_group_captured( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + # The statement after raise should still get executed + "raise ExceptionGroup('', [KeyError()])", + "print('AFTER ExceptionGroup')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + assert "AFTER ExceptionGroup" in out + + +async def test_base_exception_capture_from_coroutine( + capsys: pytest.CaptureFixture[str], + monkeypatch: pytest.MonkeyPatch, +) -> None: + console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals()) + raw_input = build_raw_input( + [ + "async def async_func_raises_base_exception():", + " raise BaseException", + "", + # This will raise, but the statement after should still + # be executed + "await async_func_raises_base_exception()", + "print('AFTER BaseException')", + ] + ) + monkeypatch.setattr(console, "raw_input", raw_input) + await trio._repl.run_repl(console) + out, err = capsys.readouterr() + assert "_threads.py" not in err + assert "_repl.py" not in err + assert "AFTER BaseException" in out + + +def test_main_entrypoint() -> None: + """ + Basic smoke test when running via the package __main__ entrypoint. + """ + repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()") + assert repl.returncode == 0 diff --git a/trio/_tests/test_scheduler_determinism.py b/src/trio/_tests/test_scheduler_determinism.py similarity index 80% rename from trio/_tests/test_scheduler_determinism.py rename to src/trio/_tests/test_scheduler_determinism.py index 7e0a8e98de..3c2299a015 100644 --- a/trio/_tests/test_scheduler_determinism.py +++ b/src/trio/_tests/test_scheduler_determinism.py @@ -1,9 +1,12 @@ from __future__ import annotations -from pytest import MonkeyPatch +from typing import TYPE_CHECKING import trio +if TYPE_CHECKING: + import pytest + async def scheduler_trace() -> tuple[tuple[str, int], ...]: """Returns a scheduler-dependent value we can use to check determinism.""" @@ -12,25 +15,23 @@ async def scheduler_trace() -> tuple[tuple[str, int], ...]: async def tracer(name: str) -> None: for i in range(50): trace.append((name, i)) - await trio.sleep(0) + await trio.lowlevel.checkpoint() async with trio.open_nursery() as nursery: for i in range(5): - nursery.start_soon(tracer, i) + nursery.start_soon(tracer, str(i)) return tuple(trace) def test_the_trio_scheduler_is_not_deterministic() -> None: # At least, not yet. 
See https://github.com/python-trio/trio/issues/32 - traces = [] - for _ in range(10): - traces.append(trio.run(scheduler_trace)) + traces = [trio.run(scheduler_trace) for _ in range(10)] assert len(set(traces)) == len(traces) def test_the_trio_scheduler_is_deterministic_if_seeded( - monkeypatch: MonkeyPatch, + monkeypatch: pytest.MonkeyPatch, ) -> None: monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True) traces = [] diff --git a/trio/_tests/test_signals.py b/src/trio/_tests/test_signals.py similarity index 96% rename from trio/_tests/test_signals.py rename to src/trio/_tests/test_signals.py index d0f1bd1c74..5e639652ef 100644 --- a/trio/_tests/test_signals.py +++ b/src/trio/_tests/test_signals.py @@ -1,17 +1,20 @@ from __future__ import annotations import signal -from types import FrameType -from typing import NoReturn +from typing import TYPE_CHECKING, NoReturn import pytest import trio +from trio.testing import RaisesGroup from .. import _core from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver from .._util import signal_raise +if TYPE_CHECKING: + from types import FrameType + async def test_open_signal_receiver() -> None: orig = signal.getsignal(signal.SIGILL) @@ -39,7 +42,9 @@ async def test_open_signal_receiver() -> None: async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match="(signal number out of range|invalid signal value)$" + ): with open_signal_receiver(signal.SIGILL, 1234567): pass # pragma: no cover # Still restored even if we errored out @@ -70,7 +75,7 @@ async def naughty() -> None: async def test_open_signal_receiver_conflict() -> None: - with pytest.raises(trio.BusyResourceError): + with RaisesGroup(trio.BusyResourceError): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: nursery.start_soon(receiver.__anext__) diff --git a/trio/_tests/test_socket.py b/src/trio/_tests/test_socket.py similarity index 93% rename from trio/_tests/test_socket.py rename to src/trio/_tests/test_socket.py index 3d37ab1c25..b98b3246e9 100644 --- a/trio/_tests/test_socket.py +++ b/src/trio/_tests/test_socket.py @@ -9,18 +9,19 @@ from socket import AddressFamily, SocketKind from typing import TYPE_CHECKING, Any, Callable, List, Tuple, Union -import attr +import attrs import pytest from .. 
import _core, socket as tsocket from .._core._tests.tutil import binds_ipv6, creates_ipv6 -from .._highlevel_socket import SocketStream from .._socket import _NUMERIC_ONLY, SocketType, _SocketType, _try_sync from ..testing import assert_checkpoints, wait_all_tasks_blocked if TYPE_CHECKING: from typing_extensions import TypeAlias + from .._highlevel_socket import SocketStream + GaiTuple: TypeAlias = Tuple[ AddressFamily, SocketKind, @@ -133,8 +134,8 @@ def interesting_fields( tuple[str, int] | tuple[str, int, int] | tuple[str, int, int, int], ]: # (family, type, proto, canonname, sockaddr) - family, type, proto, canonname, sockaddr = gai_tup - return (family, type, sockaddr) + family, type_, proto, canonname, sockaddr = gai_tup + return (family, type_, sockaddr) def filtered( gai_list: GetAddrInfoResponse, @@ -320,10 +321,11 @@ async def test_sniff_sockopts() -> None: from socket import AF_INET, AF_INET6, SOCK_DGRAM, SOCK_STREAM # generate the combinations of families/types we're testing: - sockets = [] - for family in [AF_INET, AF_INET6]: - for type in [SOCK_DGRAM, SOCK_STREAM]: - sockets.append(stdlib_socket.socket(family, type)) + sockets = [ + stdlib_socket.socket(family, type_) + for family in [AF_INET, AF_INET6] + for type_ in [SOCK_DGRAM, SOCK_STREAM] + ] for socket in sockets: # regular Trio socket constructor tsocket_socket = tsocket.socket(fileno=socket.fileno()) @@ -463,7 +465,7 @@ async def test_SocketType_shutdown() -> None: @pytest.mark.parametrize( - "address, socket_type", + ("address", "socket_type"), [ ("127.0.0.1", tsocket.AF_INET), pytest.param("::1", tsocket.AF_INET6, marks=binds_ipv6), @@ -511,17 +513,17 @@ def gai_without_v4mapped_is_buggy() -> bool: # pragma: no cover return True -@attr.s +@attrs.define(slots=False) class Addresses: - bind_all: str = attr.ib() - localhost: str = attr.ib() - arbitrary: str = attr.ib() - broadcast: str = attr.ib() + bind_all: str + localhost: str + arbitrary: str + broadcast: str # Direct thorough tests of the implicit resolver helpers @pytest.mark.parametrize( - "socket_type, addrs", + ("socket_type", "addrs"), [ ( tsocket.AF_INET, @@ -579,12 +581,14 @@ def assert_eq( for local in [False, True]: async def res( - args: tuple[str, int] - | tuple[str, int, int] - | tuple[str, int, int, int] - | tuple[str, str] - | tuple[str, str, int] - | tuple[str, str, int, int] + args: ( + tuple[str, int] + | tuple[str, int, int] + | tuple[str, int, int, int] + | tuple[str, str] + | tuple[str, str, int] + | tuple[str, str, int, int] + ) ) -> Any: return await sock._resolve_address_nocp( args, @@ -623,8 +627,8 @@ async def res( sock.setsockopt(tsocket.IPPROTO_IPV6, tsocket.IPV6_V6ONLY, True) with pytest.raises(tsocket.gaierror) as excinfo: await res(("1.2.3.4", 80)) - # Windows, macOS - expected_errnos = {tsocket.EAI_NONAME} + # Windows, macOS, musl/Linux + expected_errnos = {tsocket.EAI_NONAME, tsocket.EAI_NODATA} # Linux if hasattr(tsocket, "EAI_ADDRFAMILY"): expected_errnos.add(tsocket.EAI_ADDRFAMILY) @@ -647,11 +651,15 @@ async def res( ) netlink_sock.close() - with pytest.raises(ValueError): + address = r"^address should be a \(host, port(, \[flowinfo, \[scopeid\]\])*\) tuple$" + with pytest.raises(ValueError, match=address): await res("1.2.3.4") # type: ignore[arg-type] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=address): await res(("1.2.3.4",)) # type: ignore[arg-type] - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=address, + ): if v6: await res(("1.2.3.4", 80, 0, 0, 0)) # type: 
ignore[arg-type] else: @@ -755,7 +763,10 @@ async def t2() -> None: # This tests the complicated paths through connect async def test_SocketType_connect_paths() -> None: with tsocket.socket() as sock: - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match=r"^address should be a \(host, port(, \[flowinfo, \[scopeid\]\])*\) tuple$", + ): # Should be a tuple await sock.connect("localhost") @@ -800,7 +811,10 @@ def connect(self, *args: Any, **kwargs: Any) -> None: # Failed connect (hopefully after raising BlockingIOError) with tsocket.socket() as sock: - with pytest.raises(OSError): + with pytest.raises( + OSError, + match=r"^\[\w+ \d+\] Error connecting to \('127\.0\.0\.\d', \d+\): (Connection refused|Unknown error)$", + ): # TCP port 2 is not assigned. Pretty sure nothing will be # listening there. (We used to bind a port and then *not* call # listen() to ensure nothing was listening there, but it turns @@ -815,10 +829,11 @@ def connect(self, *args: Any, **kwargs: Any) -> None: async def test_address_in_socket_error() -> None: address = "127.0.0.1" with tsocket.socket() as sock: - try: + with pytest.raises( + OSError, + match=rf"^\[\w+ \d+\] Error connecting to \({address!r}, 2\): (Connection refused|Unknown error)$", + ): await sock.connect((address, 2)) - except OSError as e: - assert any(address in str(arg) for arg in e.args) async def test_resolve_address_exception_in_connect_closes_socket() -> None: @@ -965,7 +980,13 @@ async def test_custom_hostname_resolver(monkeygai: MonkeypatchedGAI) -> None: # This intentionally breaks the signatures used in HostnameResolver class CustomResolver: async def getaddrinfo( - self, host: str, port: str, family: int, type: int, proto: int, flags: int + self, + host: str, + port: str, + family: int, + type: int, + proto: int, + flags: int, ) -> tuple[str, str, str, int, int, int, int]: return ("custom_gai", host, port, family, type, proto, flags) @@ -1017,7 +1038,10 @@ async def getnameinfo( async def test_custom_socket_factory() -> None: class CustomSocketFactory: def socket( - self, family: AddressFamily, type: SocketKind, proto: int + self, + family: AddressFamily, + type: SocketKind, + proto: int, ) -> tuple[str, AddressFamily, SocketKind, int]: return ("hi", family, type, proto) @@ -1111,19 +1135,23 @@ async def receiver() -> None: async def test_many_sockets() -> None: total = 5000 # Must be more than MAX_AFD_GROUP_SIZE sockets = [] - for _x in range(total // 2): + # Open at most `total` sockets, two at a time (as socketpairs) + for opened in range(0, total, 2): try: a, b = stdlib_socket.socketpair() - except OSError as e: # pragma: no cover - assert e.errno in (errno.EMFILE, errno.ENFILE) + except OSError as exc: # pragma: no cover + # Semi-expecting the following errors (sockets are files): + # EMFILE: "Too many open files" (reached kernel cap) + # ENFILE: "File table overflow" (beyond kernel cap) + assert exc.errno in (errno.EMFILE, errno.ENFILE) # noqa: PT017 + print(f"Unable to open more than {opened} sockets.") + # Stop opening any more sockets if too many are open break sockets += [a, b] async with _core.open_nursery() as nursery: - for s in sockets: - nursery.start_soon(_core.wait_readable, s) + for socket in sockets: + nursery.start_soon(_core.wait_readable, socket) await _core.wait_all_tasks_blocked() nursery.cancel_scope.cancel() - for sock in sockets: - sock.close() - if _x != total // 2 - 1: # pragma: no cover - print(f"Unable to open more than {(_x-1)*2} sockets.") + for socket in sockets: + socket.close() diff --git a/trio/_tests/test_ssl.py
b/src/trio/_tests/test_ssl.py similarity index 93% rename from trio/_tests/test_ssl.py rename to src/trio/_tests/test_ssl.py index 13decd5c72..8e780b2f9c 100644 --- a/trio/_tests/test_ssl.py +++ b/src/trio/_tests/test_ssl.py @@ -8,16 +8,27 @@ from contextlib import asynccontextmanager, contextmanager, suppress from functools import partial from ssl import SSLContext -from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, NoReturn +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Iterator, + NoReturn, +) import pytest from trio import StapledStream -from trio._core import MockClock -from trio._ssl import T_Stream from trio._tests.pytest_plugin import skip_if_optional_else_raise from trio.abc import ReceiveStream, SendStream -from trio.testing import MemoryReceiveStream, MemorySendStream +from trio.testing import ( + Matcher, + MemoryReceiveStream, + MemorySendStream, + RaisesGroup, +) try: import trustme @@ -30,7 +41,6 @@ from .. import _core, socket as tsocket from .._abc import Stream from .._core import BrokenResourceError, ClosedResourceError -from .._core._run import CancelScope from .._core._tests.tutil import slow from .._highlevel_generic import aclose_forcefully from .._highlevel_open_tcp_stream import open_tcp_stream @@ -48,6 +58,11 @@ if TYPE_CHECKING: from typing_extensions import TypeAlias + from trio._core import MockClock + from trio._ssl import T_Stream + + from .._core._run import CancelScope + # We have two different kinds of echo server fixtures we use for testing. The # first is a real server written using the stdlib ssl module and blocking # sockets. It runs in a thread and we talk to it over a real socketpair(), to @@ -342,33 +357,21 @@ async def test_PyOpenSSLEchoStream_gives_resource_busy_errors() -> None: # PyOpenSSLEchoStream, so this makes sure that if we do have a bug then # PyOpenSSLEchoStream will notice and complain. - s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(s.send_all, b"x") - nursery.start_soon(s.send_all, b"x") - assert "simultaneous" in str(excinfo.value) - - s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(s.send_all, b"x") - nursery.start_soon(s.wait_send_all_might_not_block) - assert "simultaneous" in str(excinfo.value) - - s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(s.wait_send_all_might_not_block) - nursery.start_soon(s.wait_send_all_might_not_block) - assert "simultaneous" in str(excinfo.value) + async def do_test( + func1: str, args1: tuple[object, ...], func2: str, args2: tuple[object, ...] 
+ ) -> None: + s = PyOpenSSLEchoStream() + with RaisesGroup(Matcher(_core.BusyResourceError, "simultaneous")): + async with _core.open_nursery() as nursery: + nursery.start_soon(getattr(s, func1), *args1) + nursery.start_soon(getattr(s, func2), *args2) - s = PyOpenSSLEchoStream() - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(s.receive_some, 1) - nursery.start_soon(s.receive_some, 1) - assert "simultaneous" in str(excinfo.value) + await do_test("send_all", (b"x",), "send_all", (b"x",)) + await do_test("send_all", (b"x",), "wait_send_all_might_not_block", ()) + await do_test( + "wait_send_all_might_not_block", (), "wait_send_all_might_not_block", () + ) + await do_test("receive_some", (1,), "receive_some", (1,)) @contextmanager # type: ignore[misc] # decorated contains Any @@ -406,9 +409,10 @@ def ssl_wrap_pair( MemoryStapledStream: TypeAlias = StapledStream[MemorySendStream, MemoryReceiveStream] -def ssl_memory_stream_pair( - client_ctx: SSLContext, **kwargs: Any -) -> tuple[SSLStream[MemoryStapledStream], SSLStream[MemoryStapledStream],]: +def ssl_memory_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[ + SSLStream[MemoryStapledStream], + SSLStream[MemoryStapledStream], +]: client_transport, server_transport = memory_stream_pair() return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) @@ -416,9 +420,10 @@ def ssl_memory_stream_pair( MyStapledStream: TypeAlias = StapledStream[SendStream, ReceiveStream] -def ssl_lockstep_stream_pair( - client_ctx: SSLContext, **kwargs: Any -) -> tuple[SSLStream[MyStapledStream], SSLStream[MyStapledStream],]: +def ssl_lockstep_stream_pair(client_ctx: SSLContext, **kwargs: Any) -> tuple[ + SSLStream[MyStapledStream], + SSLStream[MyStapledStream], +]: client_transport, server_transport = lockstep_stream_pair() return ssl_wrap_pair(client_ctx, client_transport, server_transport, **kwargs) @@ -725,45 +730,35 @@ async def sleeper_with_slow_wait_writable_and_expect(method: str) -> None: async def test_resource_busy_errors(client_ctx: SSLContext) -> None: - async def do_send_all() -> None: + S: TypeAlias = trio.SSLStream[ + trio.StapledStream[trio.abc.SendStream, trio.abc.ReceiveStream] + ] + + async def do_send_all(s: S) -> None: with assert_checkpoints(): await s.send_all(b"x") - async def do_receive_some() -> None: + async def do_receive_some(s: S) -> None: with assert_checkpoints(): await s.receive_some(1) - async def do_wait_send_all_might_not_block() -> None: + async def do_wait_send_all_might_not_block(s: S) -> None: with assert_checkpoints(): await s.wait_send_all_might_not_block() - s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(do_send_all) - nursery.start_soon(do_send_all) - assert "another task" in str(excinfo.value) - - s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(do_receive_some) - nursery.start_soon(do_receive_some) - assert "another task" in str(excinfo.value) - - s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(do_send_all) - nursery.start_soon(do_wait_send_all_might_not_block) - assert "another task" in str(excinfo.value) + async def do_test( + func1: Callable[[S], 
Awaitable[None]], func2: Callable[[S], Awaitable[None]] + ) -> None: + s, _ = ssl_lockstep_stream_pair(client_ctx) + with RaisesGroup(Matcher(_core.BusyResourceError, "another task")): + async with _core.open_nursery() as nursery: + nursery.start_soon(func1, s) + nursery.start_soon(func2, s) - s, _ = ssl_lockstep_stream_pair(client_ctx) - with pytest.raises(_core.BusyResourceError) as excinfo: - async with _core.open_nursery() as nursery: - nursery.start_soon(do_wait_send_all_might_not_block) - nursery.start_soon(do_wait_send_all_might_not_block) - assert "another task" in str(excinfo.value) + await do_test(do_send_all, do_send_all) + await do_test(do_receive_some, do_receive_some) + await do_test(do_send_all, do_wait_send_all_might_not_block) + await do_test(do_wait_send_all_might_not_block, do_wait_send_all_might_not_block) async def test_wait_writable_calls_underlying_wait_writable() -> None: diff --git a/trio/_tests/test_subprocess.py b/src/trio/_tests/test_subprocess.py similarity index 88% rename from trio/_tests/test_subprocess.py rename to src/trio/_tests/test_subprocess.py index c901f6f29e..0a70e7a974 100644 --- a/trio/_tests/test_subprocess.py +++ b/src/trio/_tests/test_subprocess.py @@ -1,5 +1,6 @@ from __future__ import annotations +import gc import os import random import signal @@ -9,7 +10,6 @@ from functools import partial from pathlib import Path as SyncPath from signal import Signals -from types import FrameType from typing import ( TYPE_CHECKING, Any, @@ -20,10 +20,11 @@ ) import pytest -from pytest import MonkeyPatch, WarningsRecorder + +import trio +from trio.testing import Matcher, RaisesGroup from .. import ( - ClosedResourceError, Event, Process, _core, @@ -33,14 +34,17 @@ sleep, sleep_forever, ) -from .._abc import Stream from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow from ..lowlevel import open_process from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked if TYPE_CHECKING: + from types import FrameType + from typing_extensions import TypeAlias + from .._abc import ReceiveStream + if sys.platform == "win32": SignalType: TypeAlias = None else: @@ -71,7 +75,7 @@ def python(code: str) -> list[str]: if posix: def SLEEP(seconds: int) -> list[str]: - return ["/bin/sleep", str(seconds)] + return ["sleep", str(seconds)] else: @@ -166,35 +170,6 @@ async def test_multi_wait(background_process: BackgroundProcessType) -> None: proc.kill() -# Test for deprecated 'async with process:' semantics -async def test_async_with_basics_deprecated(recwarn: WarningsRecorder) -> None: - async with await open_process( - CAT, stdin=subprocess.PIPE, stdout=subprocess.PIPE - ) as proc: - pass - assert proc.returncode is not None - assert proc.stdin is not None - assert proc.stdout is not None - with pytest.raises(ClosedResourceError): - await proc.stdin.send_all(b"x") - with pytest.raises(ClosedResourceError): - await proc.stdout.receive_some() - - -# Test for deprecated 'async with process:' semantics -async def test_kill_when_context_cancelled(recwarn: WarningsRecorder) -> None: - with move_on_after(100) as scope: - async with await open_process(SLEEP(10)) as proc: - assert proc.poll() is None - scope.cancel() - await sleep_forever() - assert scope.cancelled_caught - assert got_signal(proc, SIGKILL) - assert repr(proc) == "".format( - SLEEP(10), "exited with signal 9" if posix else "exited with status 1" - ) - - COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python( "data = sys.stdin.buffer.read(); " "sys.stdout.buffer.write(data); " @@ -217,12 
+192,15 @@ async def feed_input() -> None: await proc.stdin.send_all(msg) await proc.stdin.aclose() - async def check_output(stream: Stream, expected: bytes) -> None: + async def check_output(stream: ReceiveStream, expected: bytes) -> None: seen = bytearray() async for chunk in stream: seen += chunk assert seen == expected + assert proc.stdout is not None + assert proc.stderr is not None + async with _core.open_nursery() as nursery: # fail eventually if something is broken nursery.cancel_scope.deadline = _core.current_time() + 30.0 @@ -267,7 +245,9 @@ async def test_interactive(background_process: BackgroundProcessType) -> None: async def expect(idx: int, request: int) -> None: async with _core.open_nursery() as nursery: - async def drain_one(stream: Stream, count: int, digit: int) -> None: + async def drain_one( + stream: ReceiveStream, count: int, digit: int + ) -> None: while count > 0: result = await stream.receive_some(count) assert result == (f"{digit}".encode() * len(result)) @@ -275,6 +255,8 @@ async def drain_one(stream: Stream, count: int, digit: int) -> None: assert count == 0 assert await stream.receive_some(len(newline)) == newline + assert proc.stdout is not None + assert proc.stderr is not None nursery.start_soon(drain_one, proc.stdout, request, idx * 2) nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1) @@ -336,15 +318,25 @@ async def test_run() -> None: # invalid combinations with pytest.raises(UnicodeError): await run_process(CAT, stdin="oh no, it's text") - with pytest.raises(ValueError): + + pipe_stdout_error = r"^stdout=subprocess\.PIPE is only valid with nursery\.start, since that's the only way to access the pipe(; use nursery\.start or pass the data you want to write directly)*$" + with pytest.raises(ValueError, match=pipe_stdout_error): await run_process(CAT, stdin=subprocess.PIPE) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=pipe_stdout_error): await run_process(CAT, stdout=subprocess.PIPE) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, match=pipe_stdout_error.replace("stdout", "stderr", 1) + ): await run_process(CAT, stderr=subprocess.PIPE) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="^can't specify both stdout and capture_stdout$", + ): await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="^can't specify both stderr and capture_stderr$", + ): await run_process(CAT, capture_stderr=True, stderr=None) @@ -573,7 +565,25 @@ async def custom_deliver_cancel(proc: Process) -> None: assert custom_deliver_cancel_called -async def test_warn_on_failed_cancel_terminate(monkeypatch: MonkeyPatch) -> None: +def test_bad_deliver_cancel() -> None: + async def custom_deliver_cancel(proc: Process) -> None: + proc.terminate() + raise ValueError("foo") + + async def do_stuff() -> None: + async with _core.open_nursery() as nursery: + nursery.start_soon( + partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel) + ) + await wait_all_tasks_blocked() + nursery.cancel_scope.cancel() + + # double wrap from our nursery + the internal nursery + with RaisesGroup(RaisesGroup(Matcher(ValueError, "^foo$"))): + _core.run(do_stuff, strict_exception_groups=True) + + +async def test_warn_on_failed_cancel_terminate(monkeypatch: pytest.MonkeyPatch) -> None: original_terminate = Process.terminate def broken_terminate(self: Process) -> NoReturn: @@ -591,7 +601,7 @@ def broken_terminate(self: 
Process) -> NoReturn: @pytest.mark.skipif(not posix, reason="posix only") async def test_warn_on_cancel_SIGKILL_escalation( - autojump_clock: MockClock, monkeypatch: MonkeyPatch + autojump_clock: MockClock, monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setattr(Process, "terminate", lambda *args: None) @@ -605,11 +615,9 @@ async def test_warn_on_cancel_SIGKILL_escalation( # the background_process_param exercises a lot of run_process cases, but it uses # check=False, so lets have a test that uses check=True as well async def test_run_process_background_fail() -> None: - with pytest.raises(subprocess.CalledProcessError): + with RaisesGroup(subprocess.CalledProcessError): async with _core.open_nursery() as nursery: - proc: subprocess.CompletedProcess[bytes] = await nursery.start( - run_process, EXIT_FALSE - ) + proc: Process = await nursery.start(run_process, EXIT_FALSE) assert proc.returncode == 1 @@ -618,6 +626,8 @@ async def test_run_process_background_fail() -> None: reason="requires a way to iterate through open files", ) async def test_for_leaking_fds() -> None: + gc.collect() # address possible flakiness on PyPy + starting_fds = set(SyncPath("/dev/fd").iterdir()) await run_process(EXIT_TRUE) assert set(SyncPath("/dev/fd").iterdir()) == starting_fds @@ -631,6 +641,17 @@ async def test_for_leaking_fds() -> None: assert set(SyncPath("/dev/fd").iterdir()) == starting_fds +async def test_run_process_internal_error(monkeypatch: pytest.MonkeyPatch) -> None: + # There's probably less extreme ways of triggering errors inside the nursery + # in run_process. + async def very_broken_open(*args: object, **kwargs: object) -> str: + return "oops" + + monkeypatch.setattr(trio._subprocess, "open_process", very_broken_open) + with RaisesGroup(AttributeError, AttributeError): + await run_process(EXIT_TRUE, capture_stdout=True) + + # regression test for #2209 async def test_subprocess_pidfd_unnotified() -> None: noticed_exit = None diff --git a/trio/_tests/test_sync.py b/src/trio/_tests/test_sync.py similarity index 96% rename from trio/_tests/test_sync.py rename to src/trio/_tests/test_sync.py index 9179c8a5ae..e4d04202cb 100644 --- a/trio/_tests/test_sync.py +++ b/src/trio/_tests/test_sync.py @@ -47,7 +47,7 @@ async def child() -> None: async def test_CapacityLimiter() -> None: with pytest.raises(TypeError): CapacityLimiter(1.0) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^total_tokens must be >= 1$"): CapacityLimiter(-1) c = CapacityLimiter(2) repr(c) # smoke test @@ -135,10 +135,10 @@ async def test_CapacityLimiter_change_total_tokens() -> None: with pytest.raises(TypeError): c.total_tokens = 1.0 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^total_tokens must be >= 1$"): c.total_tokens = 0 - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^total_tokens must be >= 1$"): c.total_tokens = -10 assert c.total_tokens == 2 @@ -183,7 +183,7 @@ async def test_CapacityLimiter_memleak_548() -> None: async def test_Semaphore() -> None: with pytest.raises(TypeError): Semaphore(1.0) # type: ignore[arg-type] - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^initial value must be >= 0$"): Semaphore(-1) s = Semaphore(1) repr(s) # smoke test @@ -231,12 +231,12 @@ async def do_acquire(s: Semaphore) -> None: async def test_Semaphore_bounded() -> None: with pytest.raises(TypeError): Semaphore(1, max_value=1.0) # type: ignore[arg-type] - with pytest.raises(ValueError): + with pytest.raises(ValueError, 
match="^max_values must be >= initial_value$"): Semaphore(2, max_value=1) bs = Semaphore(1, max_value=1) assert bs.max_value == 1 repr(bs) # smoke test - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^semaphore released too many times$"): bs.release() assert bs.value == 1 bs.acquire_nowait() @@ -546,7 +546,7 @@ async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None: record = [] LOOPS = 5 - async def loopy(name: str, lock_like: LockLike) -> None: + async def loopy(name: int, lock_like: LockLike) -> None: # Record the order each task was initially scheduled in initial_order.append(name) for _ in range(LOOPS): diff --git a/trio/_tests/test_testing.py b/src/trio/_tests/test_testing.py similarity index 95% rename from trio/_tests/test_testing.py rename to src/trio/_tests/test_testing.py index d4348a8c59..0f2778dc15 100644 --- a/trio/_tests/test_testing.py +++ b/src/trio/_tests/test_testing.py @@ -2,12 +2,11 @@ # XX this should get broken up, like testing.py did import tempfile +from typing import TYPE_CHECKING import pytest -from pytest import WarningsRecorder -from trio import Nursery -from trio.abc import ReceiveStream, SendStream +from trio.testing import RaisesGroup from .. import _core, sleep, socket as tsocket from .._core._tests.tutil import can_bind_ipv6 @@ -17,6 +16,10 @@ from ..testing._check_streams import _assert_raises from ..testing._memory_streams import _UnboundedByteQueue +if TYPE_CHECKING: + from trio import Nursery + from trio.abc import ReceiveStream, SendStream + async def test_wait_all_tasks_blocked() -> None: record = [] @@ -109,7 +112,7 @@ async def wait_big_cushion() -> None: ################################################################ -async def test_assert_checkpoints(recwarn: WarningsRecorder) -> None: +async def test_assert_checkpoints(recwarn: pytest.WarningsRecorder) -> None: with assert_checkpoints(): await _core.checkpoint() @@ -135,7 +138,7 @@ async def test_assert_checkpoints(recwarn: WarningsRecorder) -> None: await _core.cancel_shielded_checkpoint() -async def test_assert_no_checkpoints(recwarn: WarningsRecorder) -> None: +async def test_assert_no_checkpoints(recwarn: pytest.WarningsRecorder) -> None: with assert_no_checkpoints(): 1 + 1 # noqa: B018 # "useless expression" @@ -233,8 +236,6 @@ async def child(i: int) -> None: ################################################################ - - async def test__assert_raises() -> None: with pytest.raises(AssertionError): with _assert_raises(RuntimeError): @@ -291,7 +292,7 @@ async def getter(expect: bytes) -> None: nursery.start_soon(putter, b"xyz") # Two gets at the same time -> BusyResourceError - with pytest.raises(_core.BusyResourceError): + with RaisesGroup(_core.BusyResourceError): async with _core.open_nursery() as nursery: nursery.start_soon(getter, b"asdf") nursery.start_soon(getter, b"asdf") @@ -425,7 +426,7 @@ async def do_receive_some(max_bytes: int | None) -> bytes: mrs.put_data(b"abc") assert await do_receive_some(None) == b"abc" - with pytest.raises(_core.BusyResourceError): + with RaisesGroup(_core.BusyResourceError): async with _core.open_nursery() as nursery: nursery.start_soon(do_receive_some, 10) nursery.start_soon(do_receive_some, 10) @@ -647,7 +648,8 @@ async def check(listener: SocketListener) -> None: sock.listen(10) await check(SocketListener(sock)) - if can_bind_ipv6: + # true on all CI systems + if can_bind_ipv6: # pragma: no branch # Listener bound to IPv6 wildcard (needs special handling) sock = 
tsocket.socket(family=tsocket.AF_INET6) await sock.bind(("::", 0)) @@ -664,3 +666,14 @@ async def check(listener: SocketListener) -> None: await sock.bind(path) sock.listen(10) await check(SocketListener(sock)) + + +def test_trio_test() -> None: + async def busy_kitchen( + *, mock_clock: object, autojump_clock: object + ) -> None: ... # pragma: no cover + + with pytest.raises(ValueError, match="^too many clocks spoil the broth!$"): + trio_test(busy_kitchen)( + mock_clock=MockClock(), autojump_clock=MockClock(autojump_threshold=0) + ) diff --git a/src/trio/_tests/test_testing_raisesgroup.py b/src/trio/_tests/test_testing_raisesgroup.py new file mode 100644 index 0000000000..1e96d38e52 --- /dev/null +++ b/src/trio/_tests/test_testing_raisesgroup.py @@ -0,0 +1,378 @@ +from __future__ import annotations + +import re +import sys +from types import TracebackType +from typing import Any + +import pytest + +import trio +from trio.testing import Matcher, RaisesGroup + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +def wrap_escape(s: str) -> str: + return "^" + re.escape(s) + "$" + + +def test_raises_group() -> None: + with pytest.raises( + ValueError, + match=wrap_escape( + f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.' + ), + ): + RaisesGroup(TypeError()) + + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (ValueError(),)) + + with RaisesGroup(SyntaxError): + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (SyntaxError(),)) + + # multiple exceptions + with RaisesGroup(ValueError, SyntaxError): + raise ExceptionGroup("foo", (ValueError(), SyntaxError())) + + # order doesn't matter + with RaisesGroup(SyntaxError, ValueError): + raise ExceptionGroup("foo", (ValueError(), SyntaxError())) + + # nested exceptions + with RaisesGroup(RaisesGroup(ValueError)): + raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),)) + + with RaisesGroup( + SyntaxError, + RaisesGroup(ValueError), + RaisesGroup(RuntimeError), + ): + raise ExceptionGroup( + "foo", + ( + SyntaxError(), + ExceptionGroup("bar", (ValueError(),)), + ExceptionGroup("", (RuntimeError(),)), + ), + ) + + # will error if there's excess exceptions + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ValueError(), ValueError())) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (RuntimeError(), ValueError())) + + # will error if there's missing exceptions + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, ValueError): + raise ExceptionGroup("", (ValueError(),)) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, SyntaxError): + raise ExceptionGroup("", (ValueError(),)) + + +def test_flatten_subgroups() -> None: + # loose semantics, as with expect* + with RaisesGroup(ValueError, flatten_subgroups=True): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + with RaisesGroup(ValueError, TypeError, flatten_subgroups=True): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(), TypeError())),)) + with RaisesGroup(ValueError, TypeError, flatten_subgroups=True): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()]), TypeError()]) + + # mixed loose is possible if you want it to be at least N deep + with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + with RaisesGroup(RaisesGroup(ValueError, 
flatten_subgroups=True)): + raise ExceptionGroup( + "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),) + ) + with pytest.raises(ExceptionGroup): + with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)): + raise ExceptionGroup("", (ValueError(),)) + + # but not the other way around + with pytest.raises( + ValueError, + match="^You cannot specify a nested structure inside a RaisesGroup with", + ): + RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore[call-overload] + + +def test_catch_unwrapped_exceptions() -> None: + # Catches lone exceptions with strict=False + # just as except* would + with RaisesGroup(ValueError, allow_unwrapped=True): + raise ValueError + + # expecting multiple unwrapped exceptions is not possible + with pytest.raises( + ValueError, match="^You cannot specify multiple exceptions with" + ): + RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload] + # if users want one of several exception types they need to use a Matcher + # (which the error message suggests) + with RaisesGroup( + Matcher(check=lambda e: isinstance(e, (SyntaxError, ValueError))), + allow_unwrapped=True, + ): + raise ValueError + + # Unwrapped nested `RaisesGroup` is likely a user error, so we raise an error. + with pytest.raises(ValueError, match="has no effect when expecting"): + RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore[call-overload] + + # But it *can* be used to check for nesting level +- 1 if they move it to + # the nested RaisesGroup. Users should probably use `Matcher`s instead though. + with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])]) + with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)): + raise ExceptionGroup("", [ValueError()]) + + # with allow_unwrapped=False (default) it will not be caught + with pytest.raises(ValueError, match="^value error text$"): + with RaisesGroup(ValueError): + raise ValueError("value error text") + + # allow_unwrapped on it's own won't match against nested groups + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, allow_unwrapped=True): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])]) + + # for that you need both allow_unwrapped and flatten_subgroups + with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True): + raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])]) + + # code coverage + with pytest.raises(TypeError): + with RaisesGroup(ValueError, allow_unwrapped=True): + raise TypeError + + +def test_match() -> None: + # supports match string + with RaisesGroup(ValueError, match="bar"): + raise ExceptionGroup("bar", (ValueError(),)) + + # now also works with ^$ + with RaisesGroup(ValueError, match="^bar$"): + raise ExceptionGroup("bar", (ValueError(),)) + + # it also includes notes + with RaisesGroup(ValueError, match="my note"): + e = ExceptionGroup("bar", (ValueError(),)) + e.add_note("my note") + raise e + + # and technically you can match it all with ^$ + # but you're probably better off using a Matcher at that point + with RaisesGroup(ValueError, match="^bar\nmy note$"): + e = ExceptionGroup("bar", (ValueError(),)) + e.add_note("my note") + raise e + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, match="foo"): + raise ExceptionGroup("bar", (ValueError(),)) + + +def test_check() -> None: + exc = ExceptionGroup("", (ValueError(),)) + with RaisesGroup(ValueError, check=lambda x: x 
is exc): + raise exc + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, check=lambda x: x is exc): + raise ExceptionGroup("", (ValueError(),)) + + +def test_unwrapped_match_check() -> None: + def my_check(e: object) -> bool: # pragma: no cover + return True + + msg = ( + "`allow_unwrapped=True` bypasses the `match` and `check` parameters" + " if the exception is unwrapped. If you intended to match/check the" + " exception you should use a `Matcher` object. If you want to match/check" + " the exceptiongroup when the exception *is* wrapped you need to" + " do e.g. `if isinstance(exc.value, ExceptionGroup):" + " assert RaisesGroup(...).matches(exc.value)` afterwards." + ) + with pytest.raises(ValueError, match=re.escape(msg)): + RaisesGroup(ValueError, allow_unwrapped=True, match="foo") # type: ignore[call-overload] + with pytest.raises(ValueError, match=re.escape(msg)): + RaisesGroup(ValueError, allow_unwrapped=True, check=my_check) # type: ignore[call-overload] + + # Users should instead use a Matcher + rg = RaisesGroup(Matcher(ValueError, match="^foo$"), allow_unwrapped=True) + with rg: + raise ValueError("foo") + with rg: + raise ExceptionGroup("", [ValueError("foo")]) + + # or if they wanted to match/check the group, do a conditional `.matches()` + with RaisesGroup(ValueError, allow_unwrapped=True) as exc: + raise ExceptionGroup("bar", [ValueError("foo")]) + if isinstance(exc.value, ExceptionGroup): # pragma: no branch + assert RaisesGroup(ValueError, match="bar").matches(exc.value) + + +def test_RaisesGroup_matches() -> None: + rg = RaisesGroup(ValueError) + assert not rg.matches(None) + assert not rg.matches(ValueError()) + assert rg.matches(ExceptionGroup("", (ValueError(),))) + + +def test_message() -> None: + def check_message(message: str, body: RaisesGroup[Any]) -> None: + with pytest.raises( + AssertionError, + match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$", + ): + with body: + ... 
+ + # basic + check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError)) + # multiple exceptions + check_message( + "ExceptionGroup(ValueError, ValueError)", RaisesGroup(ValueError, ValueError) + ) + # nested + check_message( + "ExceptionGroup(ExceptionGroup(ValueError))", + RaisesGroup(RaisesGroup(ValueError)), + ) + + # Matcher + check_message( + "ExceptionGroup(Matcher(ValueError, match='my_str'))", + RaisesGroup(Matcher(ValueError, "my_str")), + ) + check_message( + "ExceptionGroup(Matcher(match='my_str'))", + RaisesGroup(Matcher(match="my_str")), + ) + + # BaseExceptionGroup + check_message( + "BaseExceptionGroup(KeyboardInterrupt)", RaisesGroup(KeyboardInterrupt) + ) + # BaseExceptionGroup with type inside Matcher + check_message( + "BaseExceptionGroup(Matcher(KeyboardInterrupt))", + RaisesGroup(Matcher(KeyboardInterrupt)), + ) + # Base-ness transfers to parent containers + check_message( + "BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt))", + RaisesGroup(RaisesGroup(KeyboardInterrupt)), + ) + # but not to child containers + check_message( + "BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt), ExceptionGroup(ValueError))", + RaisesGroup(RaisesGroup(KeyboardInterrupt), RaisesGroup(ValueError)), + ) + + +def test_matcher() -> None: + with pytest.raises( + ValueError, match="^You must specify at least one parameter to match on.$" + ): + Matcher() # type: ignore[call-overload] + with pytest.raises( + ValueError, + match=f"^exception_type {re.escape(repr(object))} must be a subclass of BaseException$", + ): + Matcher(object) # type: ignore[type-var] + + with RaisesGroup(Matcher(ValueError)): + raise ExceptionGroup("", (ValueError(),)) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(TypeError)): + raise ExceptionGroup("", (ValueError(),)) + + +def test_matcher_match() -> None: + with RaisesGroup(Matcher(ValueError, "foo")): + raise ExceptionGroup("", (ValueError("foo"),)) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(ValueError, "foo")): + raise ExceptionGroup("", (ValueError("bar"),)) + + # Can be used without specifying the type + with RaisesGroup(Matcher(match="foo")): + raise ExceptionGroup("", (ValueError("foo"),)) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(match="foo")): + raise ExceptionGroup("", (ValueError("bar"),)) + + # check ^$ + with RaisesGroup(Matcher(ValueError, match="^bar$")): + raise ExceptionGroup("", [ValueError("bar")]) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(ValueError, match="^bar$")): + raise ExceptionGroup("", [ValueError("barr")]) + + +def test_Matcher_check() -> None: + def check_oserror_and_errno_is_5(e: BaseException) -> bool: + return isinstance(e, OSError) and e.errno == 5 + + with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)): + raise ExceptionGroup("", (OSError(5, ""),)) + + # specifying exception_type narrows the parameter type to the callable + def check_errno_is_5(e: OSError) -> bool: + return e.errno == 5 + + with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): + raise ExceptionGroup("", (OSError(5, ""),)) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): + raise ExceptionGroup("", (OSError(6, ""),)) + + +def test_matcher_tostring() -> None: + assert str(Matcher(ValueError)) == "Matcher(ValueError)" + assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')" + pattern_no_flags = re.compile("noflag", 0) + assert str(Matcher(match=pattern_no_flags)) == 
"Matcher(match='noflag')" + pattern_flags = re.compile("noflag", re.IGNORECASE) + assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})" + assert ( + str(Matcher(ValueError, match="re", check=bool)) + == f"Matcher(ValueError, match='re', check={bool!r})" + ) + + +def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + trio.testing._raises_group, + "ExceptionInfo", + trio.testing._raises_group._ExceptionInfo, + ) + with trio.testing.RaisesGroup(ValueError) as excinfo: + raise ExceptionGroup("", (ValueError("hello"),)) + assert excinfo.type is ExceptionGroup + assert excinfo.value.exceptions[0].args == ("hello",) + assert isinstance(excinfo.tb, TracebackType) + + +def test_deprecated_strict() -> None: + """`strict` has been replaced with `flatten_subgroups`""" + # parameter is not included in overloaded signatures at all + with pytest.deprecated_call(): + RaisesGroup(ValueError, strict=False) # type: ignore[call-overload] + with pytest.deprecated_call(): + RaisesGroup(ValueError, strict=True) # type: ignore[call-overload] diff --git a/trio/_tests/test_threads.py b/src/trio/_tests/test_threads.py similarity index 90% rename from trio/_tests/test_threads.py rename to src/trio/_tests/test_threads.py index 86d59b4dbb..b4a5842ff0 100644 --- a/trio/_tests/test_threads.py +++ b/src/trio/_tests/test_threads.py @@ -9,6 +9,7 @@ import weakref from functools import partial from typing import ( + TYPE_CHECKING, AsyncGenerator, Awaitable, Callable, @@ -22,8 +23,6 @@ import pytest import sniffio -from outcome import Outcome -from pytest import MonkeyPatch from .. import ( CancelScope, @@ -36,17 +35,23 @@ sleep_forever, ) from .._core._tests.test_ki import ki_self -from .._core._tests.tutil import buggy_pypy_asyncgens, slow +from .._core._tests.tutil import slow from .._threads import ( + active_thread_count, current_default_thread_limiter, from_thread_check_cancelled, from_thread_run, from_thread_run_sync, to_thread_run_sync, + wait_all_threads_completed, ) -from ..lowlevel import Task from ..testing import wait_all_tasks_blocked +if TYPE_CHECKING: + from outcome import Outcome + + from ..lowlevel import Task + RecordType = List[Tuple[str, Union[threading.Thread, Type[BaseException]]]] T = TypeVar("T") @@ -160,8 +165,8 @@ def external_thread_fn() -> None: thread = threading.Thread(target=external_thread_fn) thread.start() print("waiting") - while thread.is_alive(): - await sleep(0.01) + while thread.is_alive(): # noqa: ASYNC110 + await sleep(0.01) # Fine to poll in tests. 
print("waited, joining") thread.join() print("done") @@ -233,9 +238,17 @@ def _get_thread_name(ident: int | None = None) -> str | None: libpthread_path = ctypes.util.find_library("pthread") if not libpthread_path: - print(f"no pthread on {sys.platform})") + # musl includes pthread functions directly in libc.so + # (but note that find_library("c") does not work on musl, + # see: https://github.com/python/cpython/issues/65821) + # so try that library instead + # if it doesn't exist, CDLL() will fail below + libpthread_path = "libc.so" + try: + libpthread = ctypes.CDLL(libpthread_path) + except Exception: + print(f"no pthread on {sys.platform}") return None - libpthread = ctypes.CDLL(libpthread_path) pthread_getname_np = getattr(libpthread, "pthread_getname_np", None) @@ -322,7 +335,9 @@ def f(x: T) -> tuple[T, threading.Thread]: def g() -> NoReturn: raise ValueError(threading.current_thread()) - with pytest.raises(ValueError) as excinfo: + with pytest.raises( + ValueError, match=r"^$" + ) as excinfo: await to_thread_run_sync(g) print(excinfo.value.args) assert excinfo.value.args[0] != trio_thread @@ -337,10 +352,10 @@ def f(q: stdlib_queue.Queue[str]) -> None: q.get() register[0] = "finished" - async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: + async def child(q: stdlib_queue.Queue[None], abandon_on_cancel: bool) -> None: record.append("start") try: - return await to_thread_run_sync(f, q, cancellable=cancellable) + return await to_thread_run_sync(f, q, abandon_on_cancel=abandon_on_cancel) finally: record.append("exit") @@ -360,7 +375,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: # Put the thread out of its misery: q.put(None) while register[0] != "finished": - time.sleep(0.01) # noqa: ASYNC101 # Need to wait for OS thread + time.sleep(0.01) # noqa: ASYNC251 # Need to wait for OS thread # This one can't be cancelled record = [] @@ -389,7 +404,7 @@ async def child(q: stdlib_queue.Queue[None], cancellable: bool) -> None: # handled gracefully. (Requires that the thread result machinery be prepared # for call_soon to raise RunFinishedError.) def test_run_in_worker_thread_abandoned( - capfd: pytest.CaptureFixture[str], monkeypatch: MonkeyPatch + capfd: pytest.CaptureFixture[str], monkeypatch: pytest.MonkeyPatch ) -> None: monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01) @@ -402,7 +417,7 @@ def thread_fn() -> None: async def main() -> None: async def child() -> None: - await to_thread_run_sync(thread_fn, cancellable=True) + await to_thread_run_sync(thread_fn, abandon_on_cancel=True) async with _core.open_nursery() as nursery: nursery.start_soon(child) @@ -491,7 +506,10 @@ def thread_fn(cancel_scope: CancelScope) -> None: async def run_thread(event: Event) -> None: with _core.CancelScope() as cancel_scope: await to_thread_run_sync( - thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel + thread_fn, + cancel_scope, + abandon_on_cancel=cancel, + limiter=limiter_arg, ) print("run_thread finished, cancelled:", cancel_scope.cancelled_caught) event.set() @@ -517,7 +535,9 @@ async def run_thread(event: Event) -> None: # sure no-one is sneaking past, and to make sure the high_water # check below won't fail due to scheduling issues. (It could still # fail if too many threads are let through here.) 
- while state.parked != MAX or c.statistics().tasks_waiting != MAX: + while ( # noqa: ASYNC110 + state.parked != MAX or c.statistics().tasks_waiting != MAX + ): await sleep(0.01) # pragma: no cover # Then release the threads gate.set() @@ -528,7 +548,7 @@ async def run_thread(event: Event) -> None: # Some threads might still be running; need to wait to them to # finish before checking that all threads ran. We can do this # using the CapacityLimiter. - while c.borrowed_tokens > 0: + while c.borrowed_tokens > 0: # noqa: ASYNC110 await sleep(0.01) # pragma: no cover assert state.ran == COUNT @@ -566,11 +586,11 @@ async def acquire_on_behalf_of(self, borrower: Task) -> None: def release_on_behalf_of(self, borrower: Task) -> NoReturn: record.append("release") - raise ValueError + raise ValueError("release on behalf") bs = BadCapacityLimiter() - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match="^release on behalf$") as excinfo: await to_thread_run_sync(lambda: None, limiter=bs) # type: ignore[arg-type] assert excinfo.value.__context__ is None assert record == ["acquire", "release"] @@ -579,13 +599,15 @@ def release_on_behalf_of(self, borrower: Task) -> NoReturn: # If the original function raised an error, then the semaphore error # chains with it d: dict[str, object] = {} - with pytest.raises(ValueError) as excinfo: + with pytest.raises(ValueError, match="^release on behalf$") as excinfo: await to_thread_run_sync(lambda: d["x"], limiter=bs) # type: ignore[arg-type] assert isinstance(excinfo.value.__context__, KeyError) assert record == ["acquire", "release"] -async def test_run_in_worker_thread_fail_to_spawn(monkeypatch: MonkeyPatch) -> None: +async def test_run_in_worker_thread_fail_to_spawn( + monkeypatch: pytest.MonkeyPatch, +) -> None: # Test the unlikely but possible case where trying to spawn a thread fails def bad_start(self: object, *args: object) -> NoReturn: raise RuntimeError("the engines canna take it captain") @@ -846,7 +868,6 @@ def not_called() -> None: # pragma: no cover from_thread_run_sync(not_called, trio_token=trio_token) -@pytest.mark.skipif(buggy_pypy_asyncgens, reason="pypy 7.2.0 is buggy") def test_from_thread_run_during_shutdown() -> None: save = [] record = [] @@ -881,7 +902,7 @@ async def test_trio_token_weak_referenceable() -> None: assert token is weak_reference() -async def test_unsafe_cancellable_kwarg() -> None: +async def test_unsafe_abandon_on_cancel_kwarg() -> None: # This is a stand in for a numpy ndarray or other objects # that (maybe surprisingly) lack a notion of truthiness class BadBool: @@ -889,7 +910,7 @@ def __bool__(self) -> bool: raise NotImplementedError with pytest.raises(NotImplementedError): - await to_thread_run_sync(int, cancellable=BadBool()) # type: ignore[arg-type] + await to_thread_run_sync(int, abandon_on_cancel=BadBool()) # type: ignore[arg-type] async def test_from_thread_reuses_task() -> None: @@ -909,7 +930,7 @@ def get_tid_then_reenter() -> int: nonlocal tid tid = threading.get_ident() # The nesting of wrapper functions loses the return value of threading.get_ident - return from_thread_run(to_thread_run_sync, threading.get_ident) # type: ignore[return-value] + return from_thread_run(to_thread_run_sync, threading.get_ident) # type: ignore[no-any-return] assert tid != await to_thread_run_sync(get_tid_then_reenter) @@ -933,7 +954,7 @@ def sync_check() -> None: assert not queue.get_nowait() with _core.CancelScope() as cancel_scope: - await to_thread_run_sync(sync_check, cancellable=True) + await 
to_thread_run_sync(sync_check, abandon_on_cancel=True) assert cancel_scope.cancelled_caught assert not await to_thread_run_sync(partial(queue.get, timeout=1)) @@ -957,7 +978,7 @@ def async_check() -> None: assert not queue.get_nowait() with _core.CancelScope() as cancel_scope: - await to_thread_run_sync(async_check, cancellable=True) + await to_thread_run_sync(async_check, abandon_on_cancel=True) assert cancel_scope.cancelled_caught assert not await to_thread_run_sync(partial(queue.get, timeout=1)) @@ -976,11 +997,11 @@ async def async_time_bomb() -> None: async def test_from_thread_check_cancelled() -> None: q: stdlib_queue.Queue[str] = stdlib_queue.Queue() - async def child(cancellable: bool, scope: CancelScope) -> None: + async def child(abandon_on_cancel: bool, scope: CancelScope) -> None: with scope: record.append("start") try: - return await to_thread_run_sync(f, cancellable=cancellable) + return await to_thread_run_sync(f, abandon_on_cancel=abandon_on_cancel) except _core.Cancelled: record.append("cancel") raise @@ -1009,7 +1030,7 @@ def f() -> None: # implicit assertion, Cancelled not raised via nursery assert record[1] == "exit" - # cancellable=False case: a cancel will pop out but be handled by + # abandon_on_cancel=False case: a cancel will pop out but be handled by # the appropriate cancel scope record = [] ev = threading.Event() @@ -1025,7 +1046,7 @@ def f() -> None: assert "cancel" in record assert record[-1] == "exit" - # cancellable=True case: slightly different thread behavior needed + # abandon_on_cancel=True case: slightly different thread behavior needed # check thread is cancelled "soon" after abandonment def f() -> None: # type: ignore[no-redef] # noqa: F811 ev.wait() @@ -1068,9 +1089,56 @@ async def test_reentry_doesnt_deadlock() -> None: async def child() -> None: while True: - await to_thread_run_sync(from_thread_run, sleep, 0, cancellable=False) + await to_thread_run_sync(from_thread_run, sleep, 0, abandon_on_cancel=False) with move_on_after(2): async with _core.open_nursery() as nursery: for _ in range(4): nursery.start_soon(child) + + +async def test_wait_all_threads_completed() -> None: + no_threads_left = False + e1 = Event() + e2 = Event() + + e1_exited = Event() + e2_exited = Event() + + async def wait_event(e: Event, e_exit: Event) -> None: + def thread() -> None: + from_thread_run(e.wait) + + await to_thread_run_sync(thread) + e_exit.set() + + async def wait_no_threads_left() -> None: + nonlocal no_threads_left + await wait_all_threads_completed() + no_threads_left = True + + async with _core.open_nursery() as nursery: + nursery.start_soon(wait_event, e1, e1_exited) + nursery.start_soon(wait_event, e2, e2_exited) + await wait_all_tasks_blocked() + nursery.start_soon(wait_no_threads_left) + await wait_all_tasks_blocked() + assert not no_threads_left + assert active_thread_count() == 2 + + e1.set() + await e1_exited.wait() + await wait_all_tasks_blocked() + assert not no_threads_left + assert active_thread_count() == 1 + + e2.set() + await e2_exited.wait() + await wait_all_tasks_blocked() + assert no_threads_left + assert active_thread_count() == 0 + + +async def test_wait_all_threads_completed_no_threads() -> None: + await wait_all_threads_completed() + assert active_thread_count() == 0 diff --git a/trio/_tests/test_timeouts.py b/src/trio/_tests/test_timeouts.py similarity index 92% rename from trio/_tests/test_timeouts.py rename to src/trio/_tests/test_timeouts.py index c6def0bf9e..98c3d18def 100644 --- a/trio/_tests/test_timeouts.py +++ 
b/src/trio/_tests/test_timeouts.py @@ -109,7 +109,10 @@ async def test_timeouts_raise_value_error() -> None: (sleep, nan), (sleep_until, nan), ): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="^(duration|deadline|timeout) must (not )*be (non-negative|NaN)$", + ): await fun(val) for cm, val in ( @@ -120,6 +123,9 @@ async def test_timeouts_raise_value_error() -> None: (move_on_after, nan), (move_on_at, nan), ): - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="^(duration|deadline|timeout) must (not )*be (non-negative|NaN)$", + ): with cm(val): pass # pragma: no cover diff --git a/trio/_tests/test_tracing.py b/src/trio/_tests/test_tracing.py similarity index 100% rename from trio/_tests/test_tracing.py rename to src/trio/_tests/test_tracing.py diff --git a/trio/_tests/test_unix_pipes.py b/src/trio/_tests/test_unix_pipes.py similarity index 94% rename from trio/_tests/test_unix_pipes.py rename to src/trio/_tests/test_unix_pipes.py index c258cd97cc..6f8fa6e02e 100644 --- a/trio/_tests/test_unix_pipes.py +++ b/src/trio/_tests/test_unix_pipes.py @@ -7,7 +7,6 @@ from typing import TYPE_CHECKING import pytest -from pytest import MonkeyPatch from .. import _core from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken @@ -108,7 +107,7 @@ async def test_pipe_errors() -> None: r, w = os.pipe() os.close(w) async with FdStream(r) as s: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="^max_bytes must be integer >= 1$"): await s.receive_some(0) @@ -118,11 +117,11 @@ async def test_del() -> None: del w, r gc_collect_harder() - with pytest.raises(OSError) as excinfo: + with pytest.raises(OSError, match="Bad file descriptor$") as excinfo: os.close(f1) assert excinfo.value.errno == errno.EBADF - with pytest.raises(OSError) as excinfo: + with pytest.raises(OSError, match="Bad file descriptor$") as excinfo: os.close(f2) assert excinfo.value.errno == errno.EBADF @@ -135,11 +134,11 @@ async def test_async_with() -> None: assert w.fileno() == -1 assert r.fileno() == -1 - with pytest.raises(OSError) as excinfo: + with pytest.raises(OSError, match="Bad file descriptor$") as excinfo: os.close(w.fileno()) assert excinfo.value.errno == errno.EBADF - with pytest.raises(OSError) as excinfo: + with pytest.raises(OSError, match="Bad file descriptor$") as excinfo: os.close(r.fileno()) assert excinfo.value.errno == errno.EBADF @@ -182,7 +181,9 @@ async def expect_eof() -> None: os.close(w2_fd) -async def test_close_at_bad_time_for_receive_some(monkeypatch: MonkeyPatch) -> None: +async def test_close_at_bad_time_for_receive_some( + monkeypatch: pytest.MonkeyPatch, +) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: @@ -210,7 +211,7 @@ async def patched_wait_readable(*args, **kwargs) -> None: await s.send_all(b"x") -async def test_close_at_bad_time_for_send_all(monkeypatch: MonkeyPatch) -> None: +async def test_close_at_bad_time_for_send_all(monkeypatch: pytest.MonkeyPatch) -> None: # We used to have race conditions where if one task was using the pipe, # and another closed it at *just* the wrong moment, it would give an # unexpected error instead of ClosedResourceError: diff --git a/trio/_tests/test_util.py b/src/trio/_tests/test_util.py similarity index 95% rename from trio/_tests/test_util.py rename to src/trio/_tests/test_util.py index 40c2fd11bb..3e62eb622e 100644 --- 
a/trio/_tests/test_util.py +++ b/src/trio/_tests/test_util.py @@ -6,6 +6,7 @@ import pytest import trio +from trio.testing import Matcher, RaisesGroup from .. import _core from .._core._tests.tutil import ( @@ -49,21 +50,19 @@ async def test_ConflictDetector() -> None: with ul2: print("ok") - with pytest.raises(_core.BusyResourceError) as excinfo: + with pytest.raises(_core.BusyResourceError, match="ul1"): with ul1: with ul1: pass # pragma: no cover - assert "ul1" in str(excinfo.value) async def wait_with_ul1() -> None: with ul1: await wait_all_tasks_blocked() - with pytest.raises(_core.BusyResourceError) as excinfo: + with RaisesGroup(Matcher(_core.BusyResourceError, "ul1")): async with _core.open_nursery() as nursery: nursery.start_soon(wait_with_ul1) nursery.start_soon(wait_with_ul1) - assert "ul1" in str(excinfo.value) def test_module_metadata_is_fixed_up() -> None: @@ -162,8 +161,8 @@ async def async_gen(_: object) -> Any: # pragma: no cover def test_generic_function() -> None: - @generic_function - def test_func(arg: T) -> T: + @generic_function # Decorated function contains "Any". + def test_func(arg: T) -> T: # type: ignore[misc] """Look, a docstring!""" return arg @@ -270,7 +269,7 @@ def test_fixup_module_metadata() -> None: assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined] assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined] # Make coverage happy. - non_trio_module.some_func() # type: ignore[no-untyped-call] - mod.some_func() # type: ignore[no-untyped-call] - mod._private() # type: ignore[no-untyped-call] + non_trio_module.some_func() + mod.some_func() + mod._private() mod.SomeClass().method() diff --git a/trio/_tests/test_wait_for_object.py b/src/trio/_tests/test_wait_for_object.py similarity index 97% rename from trio/_tests/test_wait_for_object.py rename to src/trio/_tests/test_wait_for_object.py index b41bcba3a5..54bbb77567 100644 --- a/trio/_tests/test_wait_for_object.py +++ b/src/trio/_tests/test_wait_for_object.py @@ -54,7 +54,7 @@ async def test_WaitForMultipleObjects_sync() -> None: handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) kernel32.CloseHandle(handle1) - with pytest.raises(OSError): + with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle2) print("test_WaitForMultipleObjects_sync close first OK") @@ -63,7 +63,7 @@ async def test_WaitForMultipleObjects_sync() -> None: handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) kernel32.CloseHandle(handle2) - with pytest.raises(OSError): + with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"): WaitForMultipleObjects_sync(handle1, handle2) kernel32.CloseHandle(handle1) print("test_WaitForMultipleObjects_sync close second OK") @@ -147,7 +147,7 @@ async def test_WaitForSingleObject() -> None: # Test already closed handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL) kernel32.CloseHandle(handle) - with pytest.raises(OSError): + with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"): await WaitForSingleObject(handle) # should return at once print("test_WaitForSingleObject already closed OK") diff --git a/trio/_tests/test_windows_pipes.py b/src/trio/_tests/test_windows_pipes.py similarity index 95% rename from trio/_tests/test_windows_pipes.py 
rename to src/trio/_tests/test_windows_pipes.py index f0783b7b06..38a25cdc54 100644 --- a/trio/_tests/test_windows_pipes.py +++ b/src/trio/_tests/test_windows_pipes.py @@ -45,9 +45,9 @@ async def test_pipe_error_on_close() -> None: assert kernel32.CloseHandle(_handle(r)) assert kernel32.CloseHandle(_handle(w)) - with pytest.raises(OSError): + with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"): await send_stream.aclose() - with pytest.raises(OSError): + with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"): await receive_stream.aclose() diff --git a/trio/_tests/tools/__init__.py b/src/trio/_tests/tools/__init__.py similarity index 100% rename from trio/_tests/tools/__init__.py rename to src/trio/_tests/tools/__init__.py diff --git a/trio/_tests/tools/test_gen_exports.py b/src/trio/_tests/tools/test_gen_exports.py similarity index 90% rename from trio/_tests/tools/test_gen_exports.py rename to src/trio/_tests/tools/test_gen_exports.py index 9c0b5f625d..19158451f7 100644 --- a/trio/_tests/tools/test_gen_exports.py +++ b/src/trio/_tests/tools/test_gen_exports.py @@ -19,7 +19,6 @@ create_passthrough_args, get_public_methods, process, - run_black, run_linters, run_ruff, ) @@ -123,23 +122,6 @@ def test_process(tmp_path: Path, imports: str) -> None: assert excinfo.value.code == 1 -@skip_lints -def test_run_black(tmp_path: Path) -> None: - """Test that processing properly fails if black does.""" - try: - import black # noqa: F401 - except ImportError as error: # pragma: no cover - skip_if_optional_else_raise(error) - - file = File(tmp_path / "module.py", "module") - - success, _ = run_black(file, "class not valid code ><") - assert not success - - success, _ = run_black(file, "import waffle\n;import trio") - assert not success - - @skip_lints def test_run_ruff(tmp_path: Path) -> None: """Test that processing properly fails if ruff does.""" diff --git a/trio/_tests/tools/test_mypy_annotate.py b/src/trio/_tests/tools/test_mypy_annotate.py similarity index 97% rename from trio/_tests/tools/test_mypy_annotate.py rename to src/trio/_tests/tools/test_mypy_annotate.py index 28ebebb592..0ff4babb99 100644 --- a/trio/_tests/tools/test_mypy_annotate.py +++ b/src/trio/_tests/tools/test_mypy_annotate.py @@ -2,15 +2,18 @@ import io import sys -from pathlib import Path +from typing import TYPE_CHECKING import pytest from trio._tools.mypy_annotate import Result, export, main, process_line +if TYPE_CHECKING: + from pathlib import Path + @pytest.mark.parametrize( - "src, expected", + ("src", "expected"), [ ("", None), ("a regular line\n", None), diff --git a/trio/_tests/type_tests/check_wraps.py b/src/trio/_tests/type_tests/check_wraps.py similarity index 74% rename from trio/_tests/type_tests/check_wraps.py rename to src/trio/_tests/type_tests/check_wraps.py index 5692738be4..058e9d0609 100644 --- a/trio/_tests/type_tests/check_wraps.py +++ b/src/trio/_tests/type_tests/check_wraps.py @@ -1,9 +1,9 @@ # https://github.com/python-trio/trio/issues/2775#issuecomment-1702267589 # (except platform independent...) 
import trio -import typing_extensions +from typing_extensions import assert_type async def fn(s: trio.SocketStream) -> None: result = await s.socket.sendto(b"a", "h") - typing_extensions.assert_type(result, int) + assert_type(result, int) diff --git a/src/trio/_tests/type_tests/open_memory_channel.py b/src/trio/_tests/type_tests/open_memory_channel.py new file mode 100644 index 0000000000..e37f59c2c8 --- /dev/null +++ b/src/trio/_tests/type_tests/open_memory_channel.py @@ -0,0 +1,4 @@ +# https://github.com/python-trio/trio/issues/2873 +import trio + +s, r = trio.open_memory_channel[int](0) diff --git a/trio/_tests/type_tests/path.py b/src/trio/_tests/type_tests/path.py similarity index 84% rename from trio/_tests/type_tests/path.py rename to src/trio/_tests/type_tests/path.py index 321fd1043b..15d25ae954 100644 --- a/trio/_tests/type_tests/path.py +++ b/src/trio/_tests/type_tests/path.py @@ -1,4 +1,5 @@ """Path wrapping is quite complex, ensure all methods are understood as wrapped correctly.""" + import io import os import pathlib @@ -6,7 +7,7 @@ from typing import IO, Any, BinaryIO, List, Tuple import trio -from trio._path import _AsyncIOWrapper +from trio._file_io import AsyncIOWrapper from typing_extensions import assert_type @@ -38,7 +39,7 @@ def sync_attrs(path: trio.Path) -> None: assert_type(path.drive, str) assert_type(path.root, str) assert_type(path.anchor, str) - assert_type(path.parents[3], pathlib.Path) + assert_type(path.parents[3], trio.Path) assert_type(path.parent, trio.Path) assert_type(path.name, str) assert_type(path.suffix, str) @@ -118,16 +119,16 @@ async def async_attrs(path: trio.Path) -> None: async def open_results(path: trio.Path, some_int: int, some_str: str) -> None: # Check the overloads. - assert_type(await path.open(), _AsyncIOWrapper[io.TextIOWrapper]) - assert_type(await path.open("r"), _AsyncIOWrapper[io.TextIOWrapper]) - assert_type(await path.open("r+"), _AsyncIOWrapper[io.TextIOWrapper]) - assert_type(await path.open("w"), _AsyncIOWrapper[io.TextIOWrapper]) - assert_type(await path.open("rb", buffering=0), _AsyncIOWrapper[io.FileIO]) - assert_type(await path.open("rb+"), _AsyncIOWrapper[io.BufferedRandom]) - assert_type(await path.open("wb"), _AsyncIOWrapper[io.BufferedWriter]) - assert_type(await path.open("rb"), _AsyncIOWrapper[io.BufferedReader]) - assert_type(await path.open("rb", buffering=some_int), _AsyncIOWrapper[BinaryIO]) - assert_type(await path.open(some_str), _AsyncIOWrapper[IO[Any]]) + assert_type(await path.open(), AsyncIOWrapper[io.TextIOWrapper]) + assert_type(await path.open("r"), AsyncIOWrapper[io.TextIOWrapper]) + assert_type(await path.open("r+"), AsyncIOWrapper[io.TextIOWrapper]) + assert_type(await path.open("w"), AsyncIOWrapper[io.TextIOWrapper]) + assert_type(await path.open("rb", buffering=0), AsyncIOWrapper[io.FileIO]) + assert_type(await path.open("rb+"), AsyncIOWrapper[io.BufferedRandom]) + assert_type(await path.open("wb"), AsyncIOWrapper[io.BufferedWriter]) + assert_type(await path.open("rb"), AsyncIOWrapper[io.BufferedReader]) + assert_type(await path.open("rb", buffering=some_int), AsyncIOWrapper[BinaryIO]) + assert_type(await path.open(some_str), AsyncIOWrapper[IO[Any]]) # Check they produce the right types. 
file_bin = await path.open("rb+") @@ -138,4 +139,5 @@ async def open_results(path: trio.Path, some_int: int, some_str: str) -> None: file_text = await path.open("r+t") assert_type(await file_text.read(), str) assert_type(await file_text.write("test"), int) + # TODO: report mypy bug: equiv to https://github.com/microsoft/pyright/issues/6833 assert_type(await file_text.readlines(), List[str]) diff --git a/src/trio/_tests/type_tests/raisesgroup.py b/src/trio/_tests/type_tests/raisesgroup.py new file mode 100644 index 0000000000..fe4053ebc5 --- /dev/null +++ b/src/trio/_tests/type_tests/raisesgroup.py @@ -0,0 +1,252 @@ +"""The typing of RaisesGroup involves a lot of deception and lies, since AFAIK what we +actually want to achieve is ~impossible. This is because we specify what we expect with +instances of RaisesGroup and exception classes, but excinfo.value will be instances of +[Base]ExceptionGroup and instances of exceptions. So we need to "translate" from +RaisesGroup to ExceptionGroup. + +The way it currently works is that RaisesGroup[E] corresponds to +ExceptionInfo[BaseExceptionGroup[E]], so the top-level group will be correct. But +RaisesGroup[RaisesGroup[ValueError]] will become +ExceptionInfo[BaseExceptionGroup[RaisesGroup[ValueError]]]. To get around that we specify +RaisesGroup as a subclass of BaseExceptionGroup during type checking - which should mean +that most static type checking for end users should be mostly correct. +""" + +from __future__ import annotations + +import sys +from typing import Union + +from trio.testing import Matcher, RaisesGroup +from typing_extensions import assert_type + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + +# split into functions to isolate the different scopes + + +def check_inheritance_and_assignments() -> None: + # Check inheritance + _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError) + _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore + + a: BaseExceptionGroup[BaseExceptionGroup[ValueError]] + a = RaisesGroup(RaisesGroup(ValueError)) + a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),)) + assert a + + +def check_matcher_typevar_default(e: Matcher) -> object: + assert e.exception_type is not None + exc: type[BaseException] = e.exception_type + # this would previously pass, as the type would be `Any` + e.exception_type().blah() # type: ignore + return exc # Silence Pyright unused var warning + + +def check_basic_contextmanager() -> None: + # One level of Group is correctly translated - except it's a BaseExceptionGroup + # instead of an ExceptionGroup. 
+ with RaisesGroup(ValueError) as e: + raise ExceptionGroup("foo", (ValueError(),)) + assert_type(e.value, BaseExceptionGroup[ValueError]) + + +def check_basic_matches() -> None: + # check that matches gets rid of the naked ValueError in the union + exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup("", (ValueError(),)) + if RaisesGroup(ValueError).matches(exc): + assert_type(exc, BaseExceptionGroup[ValueError]) + + +def check_matches_with_different_exception_type() -> None: + # This should probably raise some type error somewhere, since + # ValueError != KeyboardInterrupt + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + if RaisesGroup(ValueError).matches(e): + assert_type(e, BaseExceptionGroup[ValueError]) + + +def check_matcher_init() -> None: + def check_exc(exc: BaseException) -> bool: + return isinstance(exc, ValueError) + + # Check various combinations of constructor signatures. + # At least 1 arg must be provided. + Matcher() # type: ignore + Matcher(ValueError) + Matcher(ValueError, "regex") + Matcher(ValueError, "regex", check_exc) + Matcher(exception_type=ValueError) + Matcher(match="regex") + Matcher(check=check_exc) + Matcher(ValueError, match="regex") + Matcher(match="regex", check=check_exc) + + def check_filenotfound(exc: FileNotFoundError) -> bool: + return not exc.filename.endswith(".tmp") + + # If exception_type is provided, that narrows the `check` method's argument. + Matcher(FileNotFoundError, check=check_filenotfound) + Matcher(ValueError, check=check_filenotfound) # type: ignore + Matcher(check=check_filenotfound) # type: ignore + Matcher(FileNotFoundError, match="regex", check=check_filenotfound) + + +def raisesgroup_check_type_narrowing() -> None: + """Check type narrowing on the `check` argument to `RaisesGroup`. + All `type: ignore`s are correctly pointing out type errors, except + where otherwise noted. + + + """ + + def handle_exc(e: BaseExceptionGroup[BaseException]) -> bool: + return True + + def handle_kbi(e: BaseExceptionGroup[KeyboardInterrupt]) -> bool: + return True + + def handle_value(e: BaseExceptionGroup[ValueError]) -> bool: + return True + + RaisesGroup(BaseException, check=handle_exc) + RaisesGroup(BaseException, check=handle_kbi) # type: ignore + + RaisesGroup(Exception, check=handle_exc) + RaisesGroup(Exception, check=handle_value) # type: ignore + + RaisesGroup(KeyboardInterrupt, check=handle_exc) + RaisesGroup(KeyboardInterrupt, check=handle_kbi) + RaisesGroup(KeyboardInterrupt, check=handle_value) # type: ignore + + RaisesGroup(ValueError, check=handle_exc) + RaisesGroup(ValueError, check=handle_kbi) # type: ignore + RaisesGroup(ValueError, check=handle_value) + + RaisesGroup(ValueError, KeyboardInterrupt, check=handle_exc) + RaisesGroup(ValueError, KeyboardInterrupt, check=handle_kbi) # type: ignore + RaisesGroup(ValueError, KeyboardInterrupt, check=handle_value) # type: ignore + + +def raisesgroup_narrow_baseexceptiongroup() -> None: + """Check type narrowing specifically for the container exceptiongroup. + This is not currently working, and after playing around with it for a bit + I think the only way is to introduce a subclass `NonBaseRaisesGroup`, and overload + `__new__` in Raisesgroup to return the subclass when exceptions are non-base. 
+ (or make current class BaseRaisesGroup and introduce RaisesGroup for non-base) + I encountered problems trying to type this though, see + https://github.com/python/mypy/issues/17251 + That is probably possible to work around by entirely using `__new__` instead of + `__init__`, but........ ugh. + """ + + def handle_group(e: ExceptionGroup[Exception]) -> bool: + return True + + def handle_group_value(e: ExceptionGroup[ValueError]) -> bool: + return True + + # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup + RaisesGroup(ValueError, check=handle_group_value) # type: ignore + + # should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup + RaisesGroup(Exception, check=handle_group) # type: ignore + + +def check_matcher_transparent() -> None: + with RaisesGroup(Matcher(ValueError)) as e: + ... + _: BaseExceptionGroup[ValueError] = e.value + assert_type(e.value, BaseExceptionGroup[ValueError]) + + +def check_nested_raisesgroups_contextmanager() -> None: + with RaisesGroup(RaisesGroup(ValueError)) as excinfo: + raise ExceptionGroup("foo", (ValueError(),)) + + # thanks to inheritance this assignment works + _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value + # and it can mostly be treated like an exceptiongroup + print(excinfo.value.exceptions[0].exceptions[0]) + + # but assert_type reveals the lies + print(type(excinfo.value)) # would print "ExceptionGroup" + # typing says it's a BaseExceptionGroup + assert_type( + excinfo.value, + BaseExceptionGroup[RaisesGroup[ValueError]], + ) + + print(type(excinfo.value.exceptions[0])) # would print "ExceptionGroup" + # but type checkers are utterly confused + assert_type( + excinfo.value.exceptions[0], + Union[RaisesGroup[ValueError], BaseExceptionGroup[RaisesGroup[ValueError]]], + ) + + +def check_nested_raisesgroups_matches() -> None: + """Check nested RaisesGroups with .matches""" + exc: ExceptionGroup[ExceptionGroup[ValueError]] = ExceptionGroup( + "", (ExceptionGroup("", (ValueError(),)),) + ) + # has the same problems as check_nested_raisesgroups_contextmanager + if RaisesGroup(RaisesGroup(ValueError)).matches(exc): + assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]]) + + +def check_multiple_exceptions_1() -> None: + a = RaisesGroup(ValueError, ValueError) + b = RaisesGroup(Matcher(ValueError), Matcher(ValueError)) + c = RaisesGroup(ValueError, Matcher(ValueError)) + + d: BaseExceptionGroup[ValueError] + d = a + d = b + d = c + assert d + + +def check_multiple_exceptions_2() -> None: + # This previously failed due to lack of covariance in the TypeVar + a = RaisesGroup(Matcher(ValueError), Matcher(TypeError)) + b = RaisesGroup(Matcher(ValueError), TypeError) + c = RaisesGroup(ValueError, TypeError) + + d: BaseExceptionGroup[Exception] + d = a + d = b + d = c + assert d + + +def check_raisesgroup_overloads() -> None: + # allow_unwrapped=True does not allow: + # multiple exceptions + RaisesGroup(ValueError, TypeError, allow_unwrapped=True) # type: ignore + # nested RaisesGroup + RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore + # specifying match + RaisesGroup(ValueError, match="foo", allow_unwrapped=True) # type: ignore + # specifying check + RaisesGroup(ValueError, check=bool, allow_unwrapped=True) # type: ignore + # allowed variants + RaisesGroup(ValueError, allow_unwrapped=True) + RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True) + RaisesGroup(Matcher(ValueError), allow_unwrapped=True) + + # flatten_subgroups=True does not allow 
nested RaisesGroup + RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore + # but rest is plenty fine + RaisesGroup(ValueError, TypeError, flatten_subgroups=True) + RaisesGroup(ValueError, match="foo", flatten_subgroups=True) + RaisesGroup(ValueError, check=bool, flatten_subgroups=True) + RaisesGroup(ValueError, flatten_subgroups=True) + RaisesGroup(Matcher(ValueError), flatten_subgroups=True) + + # if they're both false we can of course specify nested raisesgroup + RaisesGroup(RaisesGroup(ValueError)) diff --git a/src/trio/_tests/type_tests/task_status.py b/src/trio/_tests/type_tests/task_status.py new file mode 100644 index 0000000000..90cfc6957f --- /dev/null +++ b/src/trio/_tests/type_tests/task_status.py @@ -0,0 +1,29 @@ +"""Check that started() can only be called for TaskStatus[None].""" + +from trio import TaskStatus +from typing_extensions import assert_type + + +async def check_status( + none_status_explicit: TaskStatus[None], + none_status_implicit: TaskStatus, + int_status: TaskStatus[int], +) -> None: + assert_type(none_status_explicit, TaskStatus[None]) + assert_type(none_status_implicit, TaskStatus[None]) # Default typevar + assert_type(int_status, TaskStatus[int]) + + # Omitting the parameter is only allowed for None. + none_status_explicit.started() + none_status_implicit.started() + int_status.started() # type: ignore + + # Explicit None is allowed. + none_status_explicit.started(None) + none_status_implicit.started(None) + int_status.started(None) # type: ignore + + none_status_explicit.started(42) # type: ignore + none_status_implicit.started(42) # type: ignore + int_status.started(42) + int_status.started(True) diff --git a/trio/_threads.py b/src/trio/_threads.py similarity index 74% rename from trio/_threads.py rename to src/trio/_threads.py index 30c3fd835e..a04b737292 100644 --- a/trio/_threads.py +++ b/src/trio/_threads.py @@ -2,31 +2,35 @@ import contextlib import contextvars -import functools import inspect import queue as stdlib_queue import threading -from collections.abc import Awaitable, Callable from itertools import count -from typing import Generic, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar -import attr +import attrs import outcome +from attrs import define from sniffio import current_async_library_cvar import trio -from trio._core._traps import RaiseCancelT from ._core import ( RunVar, TrioToken, + checkpoint, disable_ki_protection, enable_ki_protection, start_thread_soon, ) -from ._sync import CapacityLimiter +from ._sync import CapacityLimiter, Event from ._util import coroutine_or_error +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable, Generator + + from trio._core._traps import RaiseCancelT + RetT = TypeVar("RetT") @@ -49,6 +53,72 @@ class _ParentTaskData(threading.local): _thread_counter = count() +@define +class _ActiveThreadCount: + count: int + event: Event + + +_active_threads_local: RunVar[_ActiveThreadCount] = RunVar("active_threads") + + +@contextlib.contextmanager +def _track_active_thread() -> Generator[None, None, None]: + try: + active_threads_local = _active_threads_local.get() + except LookupError: + active_threads_local = _ActiveThreadCount(0, Event()) + _active_threads_local.set(active_threads_local) + + active_threads_local.count += 1 + try: + yield + finally: + active_threads_local.count -= 1 + if active_threads_local.count == 0: + active_threads_local.event.set() + active_threads_local.event = Event() + + +async def wait_all_threads_completed() -> None: + """Wait until no 
threads are still running tasks.
+
+    This is intended to be used when testing code with trio.to_thread to
+    make sure no tasks are still making progress in a thread. See the
+    following code for a usage example::
+
+        async def wait_all_settled():
+            while True:
+                await trio.testing.wait_all_threads_completed()
+                await trio.testing.wait_all_tasks_blocked()
+                if trio.testing.active_thread_count() == 0:
+                    break
+    """
+
+    await checkpoint()
+
+    try:
+        active_threads_local = _active_threads_local.get()
+    except LookupError:
+        # If there had been any active threads,
+        # _active_threads_local would have been set
+        return
+
+    while active_threads_local.count != 0:
+        await active_threads_local.event.wait()
+
+
+def active_thread_count() -> int:
+    """Returns the number of threads that are currently running a task.
+
+    See `trio.testing.wait_all_threads_completed`
+    """
+    try:
+        return _active_threads_local.get().count
+    except LookupError:
+        return 0
+
+
 def current_default_thread_limiter() -> CapacityLimiter:
     """Get the default `~trio.CapacityLimiter` used by
     `trio.to_thread.run_sync`.
@@ -69,18 +139,20 @@ def current_default_thread_limiter() -> CapacityLimiter:
 # system; see https://github.com/python-trio/trio/issues/182
 # But for now we just need an object to stand in for the thread, so we can
 # keep track of who's holding the CapacityLimiter's token.
-@attr.s(frozen=True, eq=False, hash=False)
+@attrs.frozen(eq=False, hash=False, slots=False)
 class ThreadPlaceholder:
-    name: str = attr.ib()
+    name: str


 # Types for the to_thread_run_sync message loop
-@attr.s(frozen=True, eq=False)
+@attrs.frozen(eq=False, slots=False)
 class Run(Generic[RetT]):
-    afn: Callable[..., Awaitable[RetT]] = attr.ib()
-    args: tuple[object, ...] = attr.ib()
-    context: contextvars.Context = attr.ib()
-    queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib(
+    afn: Callable[..., Awaitable[RetT]]
+    args: tuple[object, ...]
+    context: contextvars.Context = attrs.field(
+        init=False, factory=contextvars.copy_context
+    )
+    queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attrs.field(
         init=False, factory=stdlib_queue.SimpleQueue
     )
@@ -128,18 +200,20 @@ def in_trio_thread() -> None:
         token.run_sync_soon(in_trio_thread)

-@attr.s(frozen=True, eq=False)
+@attrs.frozen(eq=False, slots=False)
 class RunSync(Generic[RetT]):
-    fn: Callable[..., RetT] = attr.ib()
-    args: tuple[object, ...] = attr.ib()
-    context: contextvars.Context = attr.ib()
-    queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attr.ib(
+    fn: Callable[..., RetT]
+    args: tuple[object, ...]
+    context: contextvars.Context = attrs.field(
+        init=False, factory=contextvars.copy_context
+    )
+    queue: stdlib_queue.SimpleQueue[outcome.Outcome[RetT]] = attrs.field(
         init=False, factory=stdlib_queue.SimpleQueue
     )

     @disable_ki_protection
     def unprotected_fn(self) -> RetT:
-        ret = self.fn(*self.args)
+        ret = self.context.run(self.fn, *self.args)

         if inspect.iscoroutine(ret):
             # Manually close coroutine to avoid RuntimeWarnings
@@ -152,9 +226,7 @@ def unprotected_fn(self) -> RetT:
         return ret

     def run_sync(self) -> None:
-        # Two paramspecs + overload is a bit too hard for mypy to handle. Tell it what to infer.
- runner: Callable[[Callable[[], RetT]], RetT] = self.context.run - result = outcome.capture(runner, self.unprotected_fn) + result = outcome.capture(self.unprotected_fn) self.queue.put_nowait(result) def run_in_host_task(self, token: TrioToken) -> None: @@ -176,7 +248,7 @@ async def to_thread_run_sync( # type: ignore[misc] sync_fn: Callable[..., RetT], *args: object, thread_name: str | None = None, - cancellable: bool = False, + abandon_on_cancel: bool = False, limiter: CapacityLimiter | None = None, ) -> RetT: """Convert a blocking operation into an async operation using a thread. @@ -198,8 +270,8 @@ async def to_thread_run_sync( # type: ignore[misc] sync_fn: An arbitrary synchronous callable. *args: Positional arguments to pass to sync_fn. If you need keyword arguments, use :func:`functools.partial`. - cancellable (bool): Whether to allow cancellation of this operation. See - discussion below. + abandon_on_cancel (bool): Whether to abandon this thread upon + cancellation of this operation. See discussion below. thread_name (str): Optional string to set the name of the thread. Will always set `threading.Thread.name`, but only set the os name if pthread.h is available (i.e. most POSIX installations). @@ -225,17 +297,17 @@ async def to_thread_run_sync( # type: ignore[misc] starting the thread. But once the thread is running, there are two ways it can handle being cancelled: - * If ``cancellable=False``, the function ignores the cancellation and + * If ``abandon_on_cancel=False``, the function ignores the cancellation and keeps going, just like if we had called ``sync_fn`` synchronously. This is the default behavior. - * If ``cancellable=True``, then this function immediately raises + * If ``abandon_on_cancel=True``, then this function immediately raises `~trio.Cancelled`. In this case **the thread keeps running in background** – we just abandon it to do whatever it's going to do, and silently discard any return value or errors that it raises. Only use this if you know that the operation is safe and side-effect free. (For example: :func:`trio.socket.getaddrinfo` uses a thread with - ``cancellable=True``, because it doesn't really affect anything if a + ``abandon_on_cancel=True``, because it doesn't really affect anything if a stray hostname lookup keeps running in the background.) The ``limiter`` is only released after the thread has *actually* @@ -263,7 +335,9 @@ async def to_thread_run_sync( # type: ignore[misc] """ await trio.lowlevel.checkpoint_if_cancelled() - abandon_on_cancel = bool(cancellable) # raise early if cancellable.__bool__ raises + # raise early if abandon_on_cancel.__bool__ raises + # and give a new name to ensure mypy knows it's never None + abandon_bool = bool(abandon_on_cancel) if limiter is None: limiter = current_default_thread_limiter() @@ -300,18 +374,12 @@ def do_release_then_return_result() -> RetT: thread_name = f"{getattr(sync_fn, '__name__', None)} from {trio.lowlevel.current_task().name}" def worker_fn() -> RetT: - # Trio doesn't use current_async_library_cvar, but if someone - # else set it, it would now shine through since - # snifio.thread_local isn't set in the new thread. Make sure - # the new thread sees that it's not running in async context. 
- current_async_library_cvar.set(None) - PARENT_TASK_DATA.token = current_trio_token - PARENT_TASK_DATA.abandon_on_cancel = abandon_on_cancel + PARENT_TASK_DATA.abandon_on_cancel = abandon_bool PARENT_TASK_DATA.cancel_register = cancel_register PARENT_TASK_DATA.task_register = task_register try: - ret = sync_fn(*args) + ret = context.run(sync_fn, *args) if inspect.iscoroutine(ret): # Manually close coroutine to avoid RuntimeWarnings @@ -329,66 +397,67 @@ def worker_fn() -> RetT: del PARENT_TASK_DATA.task_register context = contextvars.copy_context() - # Partial confuses type checkers, coerce to a callable. - contextvars_aware_worker_fn: Callable[[], RetT] = functools.partial(context.run, worker_fn) # type: ignore[assignment] + # Trio doesn't use current_async_library_cvar, but if someone + # else set it, it would now shine through since + # sniffio.thread_local isn't set in the new thread. Make sure + # the new thread sees that it's not running in async context. + context.run(current_async_library_cvar.set, None) def deliver_worker_fn_result(result: outcome.Outcome[RetT]) -> None: - # The entire run finished, so the task we're trying to contact is + # If the entire run finished, the task we're trying to contact is # certainly long gone -- it must have been cancelled and abandoned - # us. + # us. Just ignore the error in this case. with contextlib.suppress(trio.RunFinishedError): current_trio_token.run_sync_soon(report_back_in_trio_thread_fn, result) await limiter.acquire_on_behalf_of(placeholder) - try: - start_thread_soon( - contextvars_aware_worker_fn, deliver_worker_fn_result, thread_name - ) - except: - limiter.release_on_behalf_of(placeholder) - raise - - def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: - # fill so from_thread_check_cancelled can raise - cancel_register[0] = raise_cancel - if abandon_on_cancel: - # empty so report_back_in_trio_thread_fn cannot reschedule - task_register[0] = None - return trio.lowlevel.Abort.SUCCEEDED - else: - return trio.lowlevel.Abort.FAILED - - while True: - # wait_task_rescheduled return value cannot be typed - msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[ - object - ] = await trio.lowlevel.wait_task_rescheduled(abort) - if isinstance(msg_from_thread, outcome.Outcome): - return msg_from_thread.unwrap() - elif isinstance(msg_from_thread, Run): - await msg_from_thread.run() - elif isinstance(msg_from_thread, RunSync): - msg_from_thread.run_sync() - else: # pragma: no cover, internal debugging guard TODO: use assert_never - raise TypeError( - "trio.to_thread.run_sync received unrecognized thread message {!r}." 
- "".format(msg_from_thread) + with _track_active_thread(): + try: + start_thread_soon(worker_fn, deliver_worker_fn_result, thread_name) + except: + limiter.release_on_behalf_of(placeholder) + raise + + def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort: + # fill so from_thread_check_cancelled can raise + cancel_register[0] = raise_cancel + if abandon_bool: + # empty so report_back_in_trio_thread_fn cannot reschedule + task_register[0] = None + return trio.lowlevel.Abort.SUCCEEDED + else: + return trio.lowlevel.Abort.FAILED + + while True: + # wait_task_rescheduled return value cannot be typed + msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = ( + await trio.lowlevel.wait_task_rescheduled(abort) ) - del msg_from_thread + if isinstance(msg_from_thread, outcome.Outcome): + return msg_from_thread.unwrap() + elif isinstance(msg_from_thread, Run): + await msg_from_thread.run() + elif isinstance(msg_from_thread, RunSync): + msg_from_thread.run_sync() + else: # pragma: no cover, internal debugging guard TODO: use assert_never + raise TypeError( + f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}." + ) + del msg_from_thread def from_thread_check_cancelled() -> None: """Raise `trio.Cancelled` if the associated Trio task entered a cancelled status. Only applicable to threads spawned by `trio.to_thread.run_sync`. Poll to allow - ``cancellable=False`` threads to raise :exc:`~trio.Cancelled` at a suitable - place, or to end abandoned ``cancellable=True`` threads sooner than they may + ``abandon_on_cancel=False`` threads to raise :exc:`~trio.Cancelled` at a suitable + place, or to end abandoned ``abandon_on_cancel=True`` threads sooner than they may otherwise. Raises: Cancelled: If the corresponding call to `trio.to_thread.run_sync` has had a delivery of cancellation attempted against it, regardless of the value of - ``cancellable`` supplied as an argument to it. + ``abandon_on_cancel`` supplied as an argument to it. RuntimeError: If this thread is not spawned from `trio.to_thread.run_sync`. .. note:: @@ -406,31 +475,29 @@ def from_thread_check_cancelled() -> None: """ try: raise_cancel = PARENT_TASK_DATA.cancel_register[0] - except AttributeError as exc: + except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, can't check for cancellation" - ) from exc + ) from None if raise_cancel is not None: raise_cancel() -def _check_token(trio_token: TrioToken | None) -> TrioToken: - """Raise a RuntimeError if this function is called within a trio run. - - Avoids deadlock by making sure we're not called from inside a context - that we might be waiting for and blocking it. - """ - - if trio_token is not None and not isinstance(trio_token, TrioToken): - raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") +def _send_message_to_trio( + trio_token: TrioToken | None, message_to_trio: Run[RetT] | RunSync[RetT] +) -> RetT: + """Shared logic of from_thread functions""" + token_provided = trio_token is not None - if trio_token is None: + if not token_provided: try: trio_token = PARENT_TASK_DATA.token except AttributeError: raise RuntimeError( "this thread wasn't created by Trio, pass kwarg trio_token=..." 
) from None + elif not isinstance(trio_token, TrioToken): + raise RuntimeError("Passed kwarg trio_token is not of type TrioToken") # Avoid deadlock by making sure we're not called from Trio thread try: @@ -440,7 +507,12 @@ def _check_token(trio_token: TrioToken | None) -> TrioToken: else: raise RuntimeError("this is a blocking function; call it from a thread") - return trio_token + if token_provided or PARENT_TASK_DATA.abandon_on_cancel: + message_to_trio.run_in_system_nursery(trio_token) + else: + message_to_trio.run_in_host_task(trio_token) + + return message_to_trio.queue.get().unwrap() def from_thread_run( @@ -464,9 +536,7 @@ def from_thread_run( Cancelled: If the original call to :func:`trio.to_thread.run_sync` is cancelled (if *trio_token* is None) or the call to :func:`trio.run` completes (if *trio_token* is not None) while ``afn(*args)`` is running, - then *afn* is likely to raise - completes while ``afn(*args)`` is running, then ``afn`` is likely to raise - :exc:`trio.Cancelled`. + then *afn* is likely to raise :exc:`trio.Cancelled`. RuntimeError: if you try calling this from inside the Trio thread, which would otherwise cause a deadlock, or if no ``trio_token`` was provided, and we can't infer one from context. @@ -483,19 +553,10 @@ def from_thread_run( "foreign" thread, spawned using some other framework, and still want to enter Trio, or if you want to use a new system task to call ``afn``, maybe to avoid the cancellation context of a corresponding - `trio.to_thread.run_sync` task. + `trio.to_thread.run_sync` task. You can get this token from + :func:`trio.lowlevel.current_trio_token`. """ - token_provided = trio_token is not None - trio_token = _check_token(trio_token) - - message_to_trio = Run(afn, args, contextvars.copy_context()) - - if token_provided or PARENT_TASK_DATA.abandon_on_cancel: - message_to_trio.run_in_system_nursery(trio_token) - else: - message_to_trio.run_in_host_task(trio_token) - - return message_to_trio.queue.get().unwrap() + return _send_message_to_trio(trio_token, Run(afn, args)) def from_thread_run_sync( @@ -533,14 +594,4 @@ def from_thread_run_sync( maybe to avoid the cancellation context of a corresponding `trio.to_thread.run_sync` task. 
""" - token_provided = trio_token is not None - trio_token = _check_token(trio_token) - - message_to_trio = RunSync(fn, args, contextvars.copy_context()) - - if token_provided or PARENT_TASK_DATA.abandon_on_cancel: - message_to_trio.run_in_system_nursery(trio_token) - else: - message_to_trio.run_in_host_task(trio_token) - - return message_to_trio.queue.get().unwrap() + return _send_message_to_trio(trio_token, RunSync(fn, args)) diff --git a/trio/_timeouts.py b/src/trio/_timeouts.py similarity index 100% rename from trio/_timeouts.py rename to src/trio/_timeouts.py diff --git a/trio/_tools/__init__.py b/src/trio/_tools/__init__.py similarity index 100% rename from trio/_tools/__init__.py rename to src/trio/_tools/__init__.py diff --git a/trio/_tools/gen_exports.py b/src/trio/_tools/gen_exports.py similarity index 92% rename from trio/_tools/gen_exports.py rename to src/trio/_tools/gen_exports.py index 2eec02dfc1..4ecb29511e 100755 --- a/trio/_tools/gen_exports.py +++ b/src/trio/_tools/gen_exports.py @@ -10,14 +10,15 @@ import os import subprocess import sys -from collections.abc import Iterable, Iterator from pathlib import Path from textwrap import indent from typing import TYPE_CHECKING -import attr +import attrs if TYPE_CHECKING: + from collections.abc import Iterable, Iterator + from typing_extensions import TypeGuard # keep these imports up to date with conditional imports in test_gen_exports @@ -31,11 +32,13 @@ # ************************************************************* from __future__ import annotations +import sys + from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT """ -TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True +TEMPLATE = """sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: @@ -43,21 +46,19 @@ """ -@attr.define +@attrs.define class File: path: Path modname: str - platform: str = attr.field(default="", kw_only=True) - imports: str = attr.field(default="", kw_only=True) + platform: str = attrs.field(default="", kw_only=True) + imports: str = attrs.field(default="", kw_only=True) def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node is either a function or an async function """ - if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): - return True - return False + return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: @@ -90,13 +91,11 @@ def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> Example input: ast.parse("def f(a, *, b): ...") Example output: "(a, b=b)" """ - call_args = [] - for arg in funcdef.args.args: - call_args.append(arg.arg) + call_args = [arg.arg for arg in funcdef.args.args] if funcdef.args.vararg: call_args.append("*" + funcdef.args.vararg.arg) for arg in funcdef.args.kwonlyargs: - call_args.append(arg.arg + "=" + arg.arg) + call_args.append(arg.arg + "=" + arg.arg) # noqa: PERF401 # clarity if funcdef.args.kwarg: call_args.append("**" + funcdef.args.kwarg.arg) return "({})".format(", ".join(call_args)) @@ -156,7 +155,7 @@ def run_ruff(file: File, source: str) -> tuple[bool, str]: "ruff", "check", "--fix", - "--output-format=text", + "--unsafe-fixes", "--stdin-filename", file.path, "-", @@ -218,10 +217,12 @@ def gen_public_wrappers_source(file: File) -> str: generated = ["".join(header)] source = astor.code_to_ast.parse_file(file.path) + 
method_names = [] for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] + method_names.append(method.name) for dec in method.decorator_list: # pragma: no cover if isinstance(dec, ast.Name) and dec.id == "contextmanager": @@ -261,6 +262,10 @@ def gen_public_wrappers_source(file: File) -> str: # Append the snippet to the corresponding module generated.append(snippet) + + method_names.sort() + # Insert after the header, before function definitions + generated.insert(1, f"__all__ = {method_names!r}") return "\n\n".join(generated) @@ -311,7 +316,7 @@ def main() -> None: # pragma: no cover source_root = Path.cwd() # Double-check we found the right directory assert (source_root / "LICENSE").exists() - core = source_root / "trio/_core" + core = source_root / "src/trio/_core" to_wrap = [ File(core / "_run.py", "runner", imports=IMPORTS_RUN), File( @@ -344,7 +349,7 @@ def main() -> None: # pragma: no cover IMPORTS_RUN = """\ from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, TYPE_CHECKING from outcome import Outcome import contextvars @@ -352,6 +357,10 @@ def main() -> None: # pragma: no cover from ._run import _NO_SEND, RunStatistics, Task from ._entry_queue import TrioToken from .._abc import Clock + +if TYPE_CHECKING: + from typing_extensions import Unpack + from ._run import PosArgT """ IMPORTS_INSTRUMENT = """\ from ._instrumentation import Instrument diff --git a/trio/_tools/mypy_annotate.py b/src/trio/_tools/mypy_annotate.py similarity index 99% rename from trio/_tools/mypy_annotate.py rename to src/trio/_tools/mypy_annotate.py index 0ae84d61e3..6bd20f401c 100644 --- a/trio/_tools/mypy_annotate.py +++ b/src/trio/_tools/mypy_annotate.py @@ -6,6 +6,7 @@ mypy_annotate.dat. After all platforms run, we run this again, which prints the messages in GitHub's format but with cross-platform failures deduplicated. 
""" + from __future__ import annotations import argparse diff --git a/trio/_unix_pipes.py b/src/trio/_unix_pipes.py similarity index 98% rename from trio/_unix_pipes.py rename to src/trio/_unix_pipes.py index 476d91f6bc..34340d2b36 100644 --- a/trio/_unix_pipes.py +++ b/src/trio/_unix_pipes.py @@ -179,13 +179,13 @@ async def receive_some(self, max_bytes: int | None = None) -> bytes: data = os.read(self._fd_holder.fd, max_bytes) except BlockingIOError: await trio.lowlevel.wait_readable(self._fd_holder.fd) - except OSError as e: - if e.errno == errno.EBADF: + except OSError as exc: + if exc.errno == errno.EBADF: raise trio.ClosedResourceError( "file was already closed" ) from None else: - raise trio.BrokenResourceError from e + raise trio.BrokenResourceError from exc else: break diff --git a/trio/_util.py b/src/trio/_util.py similarity index 90% rename from trio/_util.py rename to src/trio/_util.py index e9864cf2fd..7c9e194d19 100644 --- a/trio/_util.py +++ b/src/trio/_util.py @@ -6,29 +6,40 @@ import os import signal import threading -import typing as t from abc import ABCMeta from functools import update_wrapper -from types import AsyncGeneratorType, TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Generic, + NoReturn, + Sequence, + TypeVar, + final as std_final, +) from sniffio import thread_local as sniffio_loop import trio -CallT = t.TypeVar("CallT", bound=t.Callable[..., t.Any]) -T = t.TypeVar("T") -RetT = t.TypeVar("RetT") +CallT = TypeVar("CallT", bound=Callable[..., Any]) +T = TypeVar("T") +RetT = TypeVar("RetT") -if t.TYPE_CHECKING: - from typing_extensions import ParamSpec, Self +if TYPE_CHECKING: + from types import AsyncGeneratorType, TracebackType + + from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack ArgsT = ParamSpec("ArgsT") + PosArgsT = TypeVarTuple("PosArgsT") -if t.TYPE_CHECKING: +if TYPE_CHECKING: # Don't type check the implementation below, pthread_kill does not exist on Windows. - def signal_raise(signum: int) -> None: - ... + def signal_raise(signum: int) -> None: ... # Equivalent to the C function raise(), which Python doesn't wrap @@ -101,10 +112,10 @@ def is_main_thread() -> bool: # Call the function and get the coroutine object, while giving helpful # errors for common mistakes. Returns coroutine object. ###### -# TODO: Use TypeVarTuple here. def coroutine_or_error( - async_fn: t.Callable[..., t.Awaitable[RetT]], *args: t.Any -) -> collections.abc.Coroutine[object, t.NoReturn, RetT]: + async_fn: Callable[[Unpack[PosArgsT]], Awaitable[RetT]], + *args: Unpack[PosArgsT], +) -> collections.abc.Coroutine[object, NoReturn, RetT]: def _return_value_looks_like_wrong_library(value: object) -> bool: # Returned by legacy @asyncio.coroutine functions, which includes # a surprising proportion of asyncio builtins. @@ -116,9 +127,7 @@ def _return_value_looks_like_wrong_library(value: object) -> bool: # This janky check catches tornado Futures and twisted Deferreds. # By the time we're calling this function, we already know # something has gone wrong, so a heuristic is pretty safe. 
- if value.__class__.__name__ in ("Future", "Deferred"): - return True - return False + return value.__class__.__name__ in ("Future", "Deferred") # Make sure a sync-fn-that-returns-coroutine still sees itself as being # in trio context @@ -229,18 +238,14 @@ def async_wraps( cls: type[object], wrapped_cls: type[object], attr_name: str, -) -> t.Callable[[CallT], CallT]: +) -> Callable[[CallT], CallT]: """Similar to wraps, but for async wrappers of non-async functions.""" def decorator(func: CallT) -> CallT: func.__name__ = attr_name func.__qualname__ = ".".join((cls.__qualname__, attr_name)) - func.__doc__ = """Like :meth:`~{}.{}.{}`, but async. - - """.format( - wrapped_cls.__module__, wrapped_cls.__qualname__, attr_name - ) + func.__doc__ = f"Like :meth:`~{wrapped_cls.__module__}.{wrapped_cls.__qualname__}.{attr_name}`, but async." return func @@ -281,7 +286,7 @@ def fix_one(qualname: str, name: str, obj: object) -> None: # We need ParamSpec to type this "properly", but that requires a runtime typing_extensions import # to use as a class base. This is only used at runtime and isn't correct for type checkers anyway, # so don't bother. -class generic_function(t.Generic[RetT]): +class generic_function(Generic[RetT]): """Decorator that makes a function indexable, to communicate non-inferrable generic type parameters to a static type checker. @@ -298,18 +303,18 @@ def open_memory_channel(max_buffer_size: int) -> Tuple[ but at least it becomes possible to write those. """ - def __init__(self, fn: t.Callable[..., RetT]) -> None: + def __init__(self, fn: Callable[..., RetT]) -> None: update_wrapper(self, fn) self._fn = fn - def __call__(self, *args: t.Any, **kwargs: t.Any) -> RetT: + def __call__(self, *args: Any, **kwargs: Any) -> RetT: return self._fn(*args, **kwargs) def __getitem__(self, subscript: object) -> Self: return self -def _init_final_cls(cls: type[object]) -> t.NoReturn: +def _init_final_cls(cls: type[object]) -> NoReturn: """Raises an exception when a final class is subclassed.""" raise TypeError(f"{cls.__module__}.{cls.__qualname__} does not support subclassing") @@ -333,10 +338,10 @@ class SomeClass: # matter what the original did (if anything). decorated.__init_subclass__ = classmethod(_init_final_cls) # type: ignore[assignment] # Apply the typing decorator, in 3.11+ it adds a __final__ marker attribute. - return t.final(decorated) + return std_final(decorated) -if t.TYPE_CHECKING: +if TYPE_CHECKING: from typing import final else: final = _final_impl @@ -372,7 +377,7 @@ def _create(cls: type[T], *args: object, **kwargs: object) -> T: return super().__call__(*args, **kwargs) # type: ignore -def name_asyncgen(agen: AsyncGeneratorType[object, t.NoReturn]) -> str: +def name_asyncgen(agen: AsyncGeneratorType[object, NoReturn]) -> str: """Return the fully-qualified name of the async generator function that produced the async generator iterator *agen*. """ @@ -390,15 +395,14 @@ def name_asyncgen(agen: AsyncGeneratorType[object, t.NoReturn]) -> str: # work around a pyright error -if t.TYPE_CHECKING: - Fn = t.TypeVar("Fn", bound=t.Callable[..., object]) +if TYPE_CHECKING: + Fn = TypeVar("Fn", bound=Callable[..., object]) def wraps( - wrapped: t.Callable[..., object], - assigned: t.Sequence[str] = ..., - updated: t.Sequence[str] = ..., - ) -> t.Callable[[Fn], Fn]: - ... + wrapped: Callable[..., object], + assigned: Sequence[str] = ..., + updated: Sequence[str] = ..., + ) -> Callable[[Fn], Fn]: ... 
else: from functools import wraps # noqa: F401 # this is re-exported diff --git a/src/trio/_version.py b/src/trio/_version.py new file mode 100644 index 0000000000..b777fa4efe --- /dev/null +++ b/src/trio/_version.py @@ -0,0 +1,3 @@ +# This file is imported from __init__.py and parsed by setuptools + +__version__ = "0.26.0+dev" diff --git a/trio/_wait_for_object.py b/src/trio/_wait_for_object.py similarity index 98% rename from trio/_wait_for_object.py rename to src/trio/_wait_for_object.py index d2193d9c86..53832513a3 100644 --- a/trio/_wait_for_object.py +++ b/src/trio/_wait_for_object.py @@ -45,7 +45,7 @@ async def WaitForSingleObject(obj: int | CData) -> None: WaitForMultipleObjects_sync, handle, cancel_handle, - cancellable=True, + abandon_on_cancel=True, limiter=trio.CapacityLimiter(math.inf), ) finally: diff --git a/trio/_windows_pipes.py b/src/trio/_windows_pipes.py similarity index 100% rename from trio/_windows_pipes.py rename to src/trio/_windows_pipes.py diff --git a/trio/abc.py b/src/trio/abc.py similarity index 100% rename from trio/abc.py rename to src/trio/abc.py diff --git a/trio/from_thread.py b/src/trio/from_thread.py similarity index 99% rename from trio/from_thread.py rename to src/trio/from_thread.py index 0de0023941..50f3bac28b 100644 --- a/trio/from_thread.py +++ b/src/trio/from_thread.py @@ -3,7 +3,6 @@ an external thread by means of a Trio Token present in Thread Local Storage """ - from ._threads import ( from_thread_check_cancelled as check_cancelled, from_thread_run as run, diff --git a/trio/lowlevel.py b/src/trio/lowlevel.py similarity index 94% rename from trio/lowlevel.py rename to src/trio/lowlevel.py index 964dabb556..1df7019637 100644 --- a/trio/lowlevel.py +++ b/src/trio/lowlevel.py @@ -3,7 +3,11 @@ but useful for extending Trio's functionality. """ +# imports are renamed with leading underscores to indicate they are not part of the public API + import select as _select + +# static checkers don't understand if importing this as _sys, so it's deleted later import sys import typing as _t diff --git a/trio/py.typed b/src/trio/py.typed similarity index 100% rename from trio/py.typed rename to src/trio/py.typed diff --git a/trio/socket.py b/src/trio/socket.py similarity index 94% rename from trio/socket.py rename to src/trio/socket.py index 3fd9e3ce91..e38501fb60 100644 --- a/trio/socket.py +++ b/src/trio/socket.py @@ -1,3 +1,5 @@ +from __future__ import annotations + # This is a public namespace, so we don't want to expose any non-underscored # attributes that aren't actually part of our public API. But it's very # annoying to carefully always use underscored names for module-level @@ -5,25 +7,23 @@ # implementation in an underscored module, and then re-export the public parts # here. # We still have some underscore names though but only a few. - - -# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) - -# Dynamically re-export whatever constants this particular Python happens to -# have: import socket as _stdlib_socket + +# static checkers don't understand if importing this as _sys, so it's deleted later import sys import typing as _t from . 
import _socket -_bad_symbols: _t.Set[str] = set() +_bad_symbols: set[str] = set() if sys.platform == "win32": # See https://github.com/python-trio/trio/issues/39 # Do not import for windows platform # (you can still get it from stdlib socket, of course, if you want it) _bad_symbols.add("SO_REUSEADDR") +# Dynamically re-export whatever constants this particular Python happens to +# have: globals().update( { _name: getattr(_stdlib_socket, _name) @@ -35,6 +35,7 @@ # import the overwrites from contextlib import suppress as _suppress +# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625) from ._socket import ( SocketType as SocketType, from_stdlib_socket as from_stdlib_socket, @@ -67,13 +68,24 @@ ntohs as ntohs, ) +if sys.implementation.name == "cpython": + from socket import ( + if_indextoname as if_indextoname, + if_nametoindex as if_nametoindex, + ) + + # For android devices, if_nameindex support was introduced in API 24, + # so it doesn't exist for any version prior. + with _suppress(ImportError): + from socket import ( + if_nameindex as if_nameindex, + ) + + # not always available so expose only if if sys.platform != "win32" or not _t.TYPE_CHECKING: with _suppress(ImportError): from socket import ( - if_indextoname as if_indextoname, - if_nameindex as if_nameindex, - if_nametoindex as if_nametoindex, sethostname as sethostname, ) @@ -93,7 +105,7 @@ # re-export them. Since the exact set of constants varies depending on Python # version, platform, the libc installed on the system where Python was built, # etc., we figure out which constants to re-export dynamically at runtime (see -# below). But that confuses static analysis tools like jedi and mypy. So this +# above). But that confuses static analysis tools like jedi and mypy. So this # import statement statically lists every constant that *could* be # exported. There's a test in test_exports.py to make sure that the list is # kept up to date. 
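To make the consumer-visible effect of this two-layer scheme concrete, here is a minimal sketch (illustrative only, assuming nothing beyond what the comments above describe: constants are copied at runtime with ``globals().update()``, while the static import list is what type checkers see)::

    import socket
    import trio.socket

    # Re-exported constants are copied straight from the stdlib module:
    assert trio.socket.AF_INET == socket.AF_INET

    # Constants this platform/build lacks are simply absent, so
    # feature-detect instead of importing them unconditionally:
    if hasattr(trio.socket, "AF_BLUETOOTH"):
        print(trio.socket.AF_BLUETOOTH)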
@@ -109,6 +121,7 @@ AF_BRIDGE as AF_BRIDGE, AF_CAN as AF_CAN, AF_ECONET as AF_ECONET, + AF_HYPERV as AF_HYPERV, AF_INET as AF_INET, AF_INET6 as AF_INET6, AF_IPX as AF_IPX, @@ -237,6 +250,17 @@ HCI_DATA_DIR as HCI_DATA_DIR, HCI_FILTER as HCI_FILTER, HCI_TIME_STAMP as HCI_TIME_STAMP, + HV_GUID_BROADCAST as HV_GUID_BROADCAST, + HV_GUID_CHILDREN as HV_GUID_CHILDREN, + HV_GUID_LOOPBACK as HV_GUID_LOOPBACK, + HV_GUID_PARENT as HV_GUID_PARENT, + HV_GUID_WILDCARD as HV_GUID_WILDCARD, + HV_GUID_ZERO as HV_GUID_ZERO, + HV_PROTOCOL_RAW as HV_PROTOCOL_RAW, + HVSOCKET_ADDRESS_FLAG_PASSTHRU as HVSOCKET_ADDRESS_FLAG_PASSTHRU, + HVSOCKET_CONNECT_TIMEOUT as HVSOCKET_CONNECT_TIMEOUT, + HVSOCKET_CONNECT_TIMEOUT_MAX as HVSOCKET_CONNECT_TIMEOUT_MAX, + HVSOCKET_CONNECTED_SUSPEND as HVSOCKET_CONNECTED_SUSPEND, INADDR_ALLHOSTS_GROUP as INADDR_ALLHOSTS_GROUP, INADDR_ANY as INADDR_ANY, INADDR_BROADCAST as INADDR_BROADCAST, @@ -383,6 +407,7 @@ NETLINK_USERSOCK as NETLINK_USERSOCK, NETLINK_XFRM as NETLINK_XFRM, NI_DGRAM as NI_DGRAM, + NI_IDN as NI_IDN, NI_MAXHOST as NI_MAXHOST, NI_MAXSERV as NI_MAXSERV, NI_NAMEREQD as NI_NAMEREQD, @@ -431,6 +456,7 @@ SIOCGIFNAME as SIOCGIFNAME, SO_ACCEPTCONN as SO_ACCEPTCONN, SO_BINDTODEVICE as SO_BINDTODEVICE, + SO_BINDTOIFINDEX as SO_BINDTOIFINDEX, SO_BROADCAST as SO_BROADCAST, SO_DEBUG as SO_DEBUG, SO_DOMAIN as SO_DOMAIN, @@ -487,6 +513,7 @@ SYSPROTO_CONTROL as SYSPROTO_CONTROL, TCP_CC_INFO as TCP_CC_INFO, TCP_CONGESTION as TCP_CONGESTION, + TCP_CONNECTION_INFO as TCP_CONNECTION_INFO, TCP_CORK as TCP_CORK, TCP_DEFER_ACCEPT as TCP_DEFER_ACCEPT, TCP_FASTOPEN as TCP_FASTOPEN, diff --git a/trio/testing/__init__.py b/src/trio/testing/__init__.py similarity index 85% rename from trio/testing/__init__.py rename to src/trio/testing/__init__.py index fa683e1145..d93d33aab7 100644 --- a/trio/testing/__init__.py +++ b/src/trio/testing/__init__.py @@ -4,6 +4,10 @@ MockClock as MockClock, wait_all_tasks_blocked as wait_all_tasks_blocked, ) +from .._threads import ( + active_thread_count as active_thread_count, + wait_all_threads_completed as wait_all_threads_completed, +) from .._util import fixup_module_metadata from ._check_streams import ( check_half_closeable_stream as check_half_closeable_stream, @@ -24,6 +28,7 @@ memory_stream_pump as memory_stream_pump, ) from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener +from ._raises_group import Matcher as Matcher, RaisesGroup as RaisesGroup from ._sequencer import Sequencer as Sequencer from ._trio_test import trio_test as trio_test diff --git a/trio/testing/_check_streams.py b/src/trio/testing/_check_streams.py similarity index 93% rename from trio/testing/_check_streams.py rename to src/trio/testing/_check_streams.py index fd91c43209..c54c99c1fe 100644 --- a/trio/testing/_check_streams.py +++ b/src/trio/testing/_check_streams.py @@ -2,9 +2,17 @@ from __future__ import annotations import random -from collections.abc import Generator +import sys from contextlib import contextmanager, suppress -from typing import TYPE_CHECKING, Awaitable, Callable, Generic, Tuple, TypeVar +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Generator, + Generic, + Tuple, + TypeVar, +) from .. 
import CancelScope, _core from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream @@ -18,6 +26,9 @@ ArgsT = ParamSpec("ArgsT") +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + Res1 = TypeVar("Res1", bound=AsyncResource) Res2 = TypeVar("Res2", bound=AsyncResource) StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]] @@ -42,15 +53,24 @@ async def __aexit__( await aclose_forcefully(self._second) +# This is used in this file instead of pytest.raises in order to avoid a dependency +# on pytest, as the check_* functions are publicly exported. @contextmanager -def _assert_raises(exc: type[BaseException]) -> Generator[None, None, None]: +def _assert_raises( + expected_exc: type[BaseException], wrapped: bool = False +) -> Generator[None, None, None]: __tracebackhide__ = True try: yield - except exc: - pass + except BaseExceptionGroup as exc: + assert wrapped, "caught exceptiongroup, but expected an unwrapped exception" + # assert in except block ignored below + assert len(exc.exceptions) == 1 # noqa: PT017 + assert isinstance(exc.exceptions[0], expected_exc) # noqa: PT017 + except expected_exc: + assert not wrapped, "caught exception, but expected an exceptiongroup" else: - raise AssertionError(f"expected exception: {exc}") + raise AssertionError(f"expected exception: {expected_exc}") async def check_one_way_stream( @@ -135,7 +155,7 @@ async def send_empty_then_y() -> None: nursery.start_soon(do_send_all, b"x") assert await do_receive_some(None) == b"x" - with _assert_raises(_core.BusyResourceError): + with _assert_raises(_core.BusyResourceError, wrapped=True): async with _core.open_nursery() as nursery: nursery.start_soon(do_receive_some, 1) nursery.start_soon(do_receive_some, 1) @@ -333,7 +353,7 @@ async def receiver() -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): # simultaneous wait_send_all_might_not_block fails - with _assert_raises(_core.BusyResourceError): + with _assert_raises(_core.BusyResourceError, wrapped=True): async with _core.open_nursery() as nursery: nursery.start_soon(s.wait_send_all_might_not_block) nursery.start_soon(s.wait_send_all_might_not_block) @@ -342,7 +362,7 @@ async def receiver() -> None: # this test might destroy the stream b/c we end up cancelling # send_all and e.g. 
SSLStream can't handle that, so we have to # recreate afterwards) - with _assert_raises(_core.BusyResourceError): + with _assert_raises(_core.BusyResourceError, wrapped=True): async with _core.open_nursery() as nursery: nursery.start_soon(s.wait_send_all_might_not_block) nursery.start_soon(s.send_all, b"123") @@ -350,7 +370,7 @@ async def receiver() -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r): # send_all and send_all blocked simultaneously should also raise # (but again this might destroy the stream) - with _assert_raises(_core.BusyResourceError): + with _assert_raises(_core.BusyResourceError, wrapped=True): async with _core.open_nursery() as nursery: nursery.start_soon(s.send_all, b"123") nursery.start_soon(s.send_all, b"123") @@ -532,7 +552,7 @@ async def expect_x_then_eof(r: HalfCloseableStream) -> None: if clogged_stream_maker is not None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2): # send_all and send_eof simultaneously is not ok - with _assert_raises(_core.BusyResourceError): + with _assert_raises(_core.BusyResourceError, wrapped=True): async with _core.open_nursery() as nursery: nursery.start_soon(s1.send_all, b"x") await _core.wait_all_tasks_blocked() @@ -541,7 +561,7 @@ async def expect_x_then_eof(r: HalfCloseableStream) -> None: async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2): # wait_send_all_might_not_block and send_eof simultaneously is not # ok either - with _assert_raises(_core.BusyResourceError): + with _assert_raises(_core.BusyResourceError, wrapped=True): async with _core.open_nursery() as nursery: nursery.start_soon(s1.wait_send_all_might_not_block) await _core.wait_all_tasks_blocked() diff --git a/trio/testing/_checkpoints.py b/src/trio/testing/_checkpoints.py similarity index 95% rename from trio/testing/_checkpoints.py rename to src/trio/testing/_checkpoints.py index 4a4047813b..e51463f071 100644 --- a/trio/testing/_checkpoints.py +++ b/src/trio/testing/_checkpoints.py @@ -1,10 +1,13 @@ from __future__ import annotations -from collections.abc import Generator from contextlib import AbstractContextManager, contextmanager +from typing import TYPE_CHECKING from .. 
import _core +if TYPE_CHECKING: + from collections.abc import Generator + @contextmanager def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]: diff --git a/trio/testing/_fake_net.py b/src/trio/testing/_fake_net.py similarity index 79% rename from trio/testing/_fake_net.py rename to src/trio/testing/_fake_net.py index fc9c0b361b..f8589f3a9c 100644 --- a/trio/testing/_fake_net.py +++ b/src/trio/testing/_fake_net.py @@ -8,11 +8,12 @@ from __future__ import annotations -import builtins import contextlib import errno import ipaddress import os +import socket +import sys from typing import ( TYPE_CHECKING, Any, @@ -23,12 +24,13 @@ overload, ) -import attr +import attrs import trio from trio._util import NoPublicConstructor, final if TYPE_CHECKING: + import builtins from socket import AddressFamily, SocketKind from types import TracebackType @@ -53,12 +55,13 @@ def _wildcard_ip_for(family: int) -> IPAddress: raise NotImplementedError("Unhandled ip address family") # pragma: no cover -def _localhost_ip_for(family: int) -> IPAddress: +# not used anywhere +def _localhost_ip_for(family: int) -> IPAddress: # pragma: no cover if family == trio.socket.AF_INET: return ipaddress.ip_address("127.0.0.1") elif family == trio.socket.AF_INET6: return ipaddress.ip_address("::1") - raise NotImplementedError("Unhandled ip address family") # pragma: no cover + raise NotImplementedError("Unhandled ip address family") def _fake_err(code: int) -> NoReturn: @@ -67,12 +70,12 @@ def _fake_err(code: int) -> NoReturn: def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int: written = 0 - for buf in buffers: + for buf in buffers: # pragma: no branch next_piece = data[written : written + memoryview(buf).nbytes] with memoryview(buf) as mbuf: mbuf[: len(next_piece)] = next_piece written += len(next_piece) - if written == len(data): + if written == len(data): # pragma: no branch break return written @@ -80,7 +83,7 @@ def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int: T_UDPEndpoint = TypeVar("T_UDPEndpoint", bound="UDPEndpoint") -@attr.frozen +@attrs.frozen class UDPEndpoint: ip: IPAddress port: int @@ -102,39 +105,40 @@ def from_python_sockaddr( return cls(ip=ipaddress.ip_address(ip), port=port) -@attr.frozen +@attrs.frozen class UDPBinding: local: UDPEndpoint # remote: UDPEndpoint # ?? 
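These small record types are easiest to understand from the consumer side. A minimal usage sketch, modeled on Trio's own tests (it assumes the ``FakeNet.enable()`` and ``FakeSocket`` behavior defined in this file, and is not a verbatim excerpt)::

    import trio
    from trio.testing._fake_net import FakeNet

    async def main() -> None:
        fake = FakeNet()
        fake.enable()  # route trio.socket traffic through the in-memory network

        s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
        s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
        await s1.bind(("127.0.0.1", 0))  # port 0 gets an automatically assigned port

        # Each datagram becomes a UDPPacket delivered via FakeNet.deliver_packet():
        await s2.sendto(b"xyz", s1.getsockname())
        data, addr = await s1.recvfrom(10)
        assert data == b"xyz"

    trio.run(main)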
-@attr.frozen +@attrs.frozen class UDPPacket: source: UDPEndpoint destination: UDPEndpoint - payload: bytes = attr.ib(repr=lambda p: p.hex()) + payload: bytes = attrs.field(repr=lambda p: p.hex()) - def reply(self, payload: bytes) -> UDPPacket: + # not used/tested anywhere + def reply(self, payload: bytes) -> UDPPacket: # pragma: no cover return UDPPacket( source=self.destination, destination=self.source, payload=payload ) -@attr.frozen +@attrs.frozen class FakeSocketFactory(trio.abc.SocketFactory): fake_net: FakeNet - def socket(self, family: int, type: int, proto: int) -> FakeSocket: # type: ignore[override] - return FakeSocket._create(self.fake_net, family, type, proto) + def socket(self, family: int, type_: int, proto: int) -> FakeSocket: # type: ignore[override] + return FakeSocket._create(self.fake_net, family, type_, proto) -@attr.frozen +@attrs.frozen class FakeHostnameResolver(trio.abc.HostnameResolver): fake_net: FakeNet async def getaddrinfo( self, - host: bytes | str | None, + host: bytes | None, port: bytes | str | int | None, family: int = 0, type: int = 0, @@ -161,8 +165,8 @@ async def getnameinfo( class FakeNet: def __init__(self) -> None: # When we need to pick an arbitrary unique ip address/port, use these: - self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() - self._auto_ipv4_iter = ipaddress.IPv6Network("1::/16").hosts() # type: ignore[assignment] + self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() # untested + self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts() # untested self._auto_port_iter = iter(range(50000, 65535)) self._bound: dict[UDPBinding, FakeSocket] = {} @@ -196,14 +200,18 @@ def deliver_packet(self, packet: UDPPacket) -> None: @final class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor): def __init__( - self, fake_net: FakeNet, family: AddressFamily, type: SocketKind, proto: int + self, + fake_net: FakeNet, + family: AddressFamily, + type: SocketKind, + proto: int, ): self._fake_net = fake_net - if not family: + if not family: # pragma: no cover family = trio.socket.AF_INET - if not type: - type = trio.socket.SOCK_STREAM + if not type: # pragma: no cover + type = trio.socket.SOCK_STREAM # noqa: A001 # name shadowing builtin if family not in (trio.socket.AF_INET, trio.socket.AF_INET6): raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}") @@ -240,7 +248,6 @@ def _check_closed(self) -> None: _fake_err(errno.EBADF) def close(self) -> None: - # breakpoint() if self._closed: return self._closed = True @@ -274,7 +281,9 @@ async def bind(self, addr: object) -> None: if self._binding is not None: _fake_err(errno.EINVAL) await trio.lowlevel.checkpoint() - ip_str, port = await self._resolve_address_nocp(addr, local=True) + ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True) + assert _ == [], "TODO: handle other values?" 
+ ip = ipaddress.ip_address(ip_str) assert _family_for(ip) == self.family # We convert binds to INET_ANY into binds to localhost @@ -291,25 +300,14 @@ async def bind(self, addr: object) -> None: async def connect(self, peer: object) -> NoReturn: raise NotImplementedError("FakeNet does not (yet) support connected sockets") - async def sendmsg(self, *args: Any) -> int: + async def _sendmsg( + self, + buffers: Iterable[Buffer], + ancdata: Iterable[tuple[int, int, Buffer]] = (), + flags: int = 0, + address: Any | None = None, + ) -> int: self._check_closed() - ancdata = [] - flags = 0 - address = None - - # This does *not* match up with socket.socket.sendmsg (!!!) - # https://docs.python.org/3/library/socket.html#socket.socket.sendmsg - # they always have (buffers, ancdata, flags, address) - if len(args) == 1: - (buffers,) = args - elif len(args) == 2: - buffers, address = args - elif len(args) == 3: - buffers, flags, address = args - elif len(args) == 4: - buffers, ancdata, flags, address = args - else: - raise TypeError("wrong number of arguments") await trio.lowlevel.checkpoint() @@ -341,7 +339,12 @@ async def sendmsg(self, *args: Any) -> int: return len(payload) - async def recvmsg_into( + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + sendmsg = _sendmsg + + async def _recvmsg_into( self, buffers: Iterable[Buffer], ancbufsize: int = 0, @@ -351,6 +354,14 @@ async def recvmsg_into( raise NotImplementedError("FakeNet doesn't support ancillary data") if flags != 0: raise NotImplementedError("FakeNet doesn't support any recv flags") + if self._binding is None: + # I messed this up a few times when writing tests ... but it also never happens + # in any of the existing tests, so maybe it could be intentional... + raise NotImplementedError( + "The code will most likely hang if you try to receive on a fakesocket " + "without a binding. If that is not the case, or you explicitly want to " + "test that, remove this warning." + ) self._check_closed() @@ -364,6 +375,11 @@ async def recvmsg_into( msg_flags |= trio.socket.MSG_TRUNC return written, ancdata, msg_flags, address + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + recvmsg_into = _recvmsg_into + ################################################################ # Simple state query stuff ################################################################ @@ -385,7 +401,7 @@ def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: assert hasattr( self._binding, "remote" ), "This method seems to assume that self._binding has a remote UDPEndpoint" - if self._binding.remote is not None: + if self._binding.remote is not None: # pragma: no cover assert isinstance( self._binding.remote, UDPEndpoint ), "Self._binding.remote should be a UDPEndpoint" @@ -393,12 +409,10 @@ def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]: _fake_err(errno.ENOTCONN) @overload - def getsockopt(self, /, level: int, optname: int) -> int: - ... + def getsockopt(self, /, level: int, optname: int) -> int: ... @overload - def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: - ... + def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ... 
def getsockopt( self, /, level: int, optname: int, buflen: int | None = None @@ -407,12 +421,12 @@ def getsockopt( raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})") @overload - def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: - ... + def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ... @overload - def setsockopt(self, /, level: int, optname: int, value: None, optlen: int) -> None: - ... + def setsockopt( + self, /, level: int, optname: int, value: None, optlen: int + ) -> None: ... def setsockopt( self, @@ -450,7 +464,23 @@ def __exit__( async def send(self, data: Buffer, flags: int = 0) -> int: return await self.sendto(data, flags, None) + @overload + async def sendto( + self, __data: Buffer, __address: tuple[object, ...] | str | Buffer + ) -> int: ... + + @overload + async def sendto( + self, + __data: Buffer, + __flags: int, + __address: tuple[object, ...] | str | None | Buffer, + ) -> int: ... + async def sendto(self, *args: Any) -> int: + data: Buffer + flags: int + address: tuple[object, ...] | str | Buffer if len(args) == 2: data, address = args flags = 0 @@ -458,7 +488,7 @@ async def sendto(self, *args: Any) -> int: data, flags, address = args else: raise TypeError("wrong number of arguments") - return await self.sendmsg([data], [], flags, address) + return await self._sendmsg([data], [], flags, address) async def recv(self, bufsize: int, flags: int = 0) -> bytes: data, address = await self.recvfrom(bufsize, flags) @@ -469,7 +499,7 @@ async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int: return got_bytes async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]: - data, ancdata, msg_flags, address = await self.recvmsg(bufsize, flags) + data, ancdata, msg_flags, address = await self._recvmsg(bufsize, flags) return data, address async def recvfrom_into( @@ -477,20 +507,25 @@ async def recvfrom_into( ) -> tuple[int, Any]: if nbytes != 0 and nbytes != memoryview(buf).nbytes: raise NotImplementedError("partial recvfrom_into") - got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( [buf], 0, flags ) return got_nbytes, address - async def recvmsg( + async def _recvmsg( self, bufsize: int, ancbufsize: int = 0, flags: int = 0 ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]: buf = bytearray(bufsize) - got_nbytes, ancdata, msg_flags, address = await self.recvmsg_into( + got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into( [buf], ancbufsize, flags ) return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address) + if sys.platform != "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "sendmsg") + ): + recvmsg = _recvmsg + def fileno(self) -> int: raise NotImplementedError("can't get fileno() for FakeNet sockets") @@ -504,5 +539,9 @@ def set_inheritable(self, inheritable: bool) -> None: if inheritable: raise NotImplementedError("FakeNet can't make inheritable sockets") - def share(self, process_id: int) -> bytes: - raise NotImplementedError("FakeNet can't share sockets") + if sys.platform == "win32" or ( + not TYPE_CHECKING and hasattr(socket.socket, "share") + ): + + def share(self, process_id: int) -> bytes: + raise NotImplementedError("FakeNet can't share sockets") diff --git a/trio/testing/_memory_streams.py b/src/trio/testing/_memory_streams.py similarity index 98% rename from trio/testing/_memory_streams.py rename to 
src/trio/testing/_memory_streams.py index 04098c8c60..c9d430a9e6 100644 --- a/trio/testing/_memory_streams.py +++ b/src/trio/testing/_memory_streams.py @@ -368,12 +368,10 @@ def _make_stapled_pair( return stream1, stream2 -def memory_stream_pair() -> ( - tuple[ - StapledStream[MemorySendStream, MemoryReceiveStream], - StapledStream[MemorySendStream, MemoryReceiveStream], - ] -): +def memory_stream_pair() -> tuple[ + StapledStream[MemorySendStream, MemoryReceiveStream], + StapledStream[MemorySendStream, MemoryReceiveStream], +]: """Create a connected, pure-Python, bidirectional stream with infinite buffering and flexible configuration options. @@ -609,12 +607,10 @@ def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]: return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq) -def lockstep_stream_pair() -> ( - tuple[ - StapledStream[SendStream, ReceiveStream], - StapledStream[SendStream, ReceiveStream], - ] -): +def lockstep_stream_pair() -> tuple[ + StapledStream[SendStream, ReceiveStream], + StapledStream[SendStream, ReceiveStream], +]: """Create a connected, pure-Python, bidirectional stream where data flows in lockstep. diff --git a/trio/testing/_network.py b/src/trio/testing/_network.py similarity index 100% rename from trio/testing/_network.py rename to src/trio/testing/_network.py diff --git a/src/trio/testing/_raises_group.py b/src/trio/testing/_raises_group.py new file mode 100644 index 0000000000..f96dcb2351 --- /dev/null +++ b/src/trio/testing/_raises_group.py @@ -0,0 +1,569 @@ +from __future__ import annotations + +import re +import sys +from typing import ( + TYPE_CHECKING, + Callable, + ContextManager, + Generic, + Literal, + Pattern, + Sequence, + cast, + overload, +) + +from trio._deprecate import warn_deprecated +from trio._util import final + +if TYPE_CHECKING: + import builtins + + # sphinx will *only* work if we use types.TracebackType, and import + # *inside* TYPE_CHECKING. No other combination works..... + import types + + from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback + from typing_extensions import TypeGuard, TypeVar + + MatchE = TypeVar( + "MatchE", bound=BaseException, default=BaseException, covariant=True + ) +else: + from typing import TypeVar + + MatchE = TypeVar("MatchE", bound=BaseException, covariant=True) +# RaisesGroup doesn't work with a default. +E = TypeVar("E", bound=BaseException, covariant=True) +# These two typevars are special cased in sphinx config to workaround lookup bugs. + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + + +@final +class _ExceptionInfo(Generic[MatchE]): + """Minimal re-implementation of pytest.ExceptionInfo, only used if pytest is not available. Supports a subset of its features necessary for functionality of :class:`trio.testing.RaisesGroup` and :class:`trio.testing.Matcher`.""" + + _excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None + + def __init__( + self, excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None + ): + self._excinfo = excinfo + + def fill_unfilled( + self, exc_info: tuple[type[MatchE], MatchE, types.TracebackType] + ) -> None: + """Fill an unfilled ExceptionInfo created with ``for_later()``.""" + assert self._excinfo is None, "ExceptionInfo was already filled" + self._excinfo = exc_info + + @classmethod + def for_later(cls) -> _ExceptionInfo[MatchE]: + """Return an unfilled ExceptionInfo.""" + return cls(None) + + # Note, special cased in sphinx config, since "type" conflicts. 
+ @property + def type(self) -> type[MatchE]: + """The exception class.""" + assert ( + self._excinfo is not None + ), ".type can only be used after the context manager exits" + return self._excinfo[0] + + @property + def value(self) -> MatchE: + """The exception value.""" + assert ( + self._excinfo is not None + ), ".value can only be used after the context manager exits" + return self._excinfo[1] + + @property + def tb(self) -> types.TracebackType: + """The exception raw traceback.""" + assert ( + self._excinfo is not None + ), ".tb can only be used after the context manager exits" + return self._excinfo[2] + + def exconly(self, tryshort: bool = False) -> str: + raise NotImplementedError( + "This is a helper method only available if you use RaisesGroup with the pytest package installed" + ) + + def errisinstance( + self, + exc: builtins.type[BaseException] | tuple[builtins.type[BaseException], ...], + ) -> bool: + raise NotImplementedError( + "This is a helper method only available if you use RaisesGroup with the pytest package installed" + ) + + def getrepr( + self, + showlocals: bool = False, + style: str = "long", + abspath: bool = False, + tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True, + funcargs: bool = False, + truncate_locals: bool = True, + chain: bool = True, + ) -> ReprExceptionInfo | ExceptionChainRepr: + raise NotImplementedError( + "This is a helper method only available if you use RaisesGroup with the pytest package installed" + ) + + +# Type checkers are not able to do conditional types depending on installed packages, so +# we've added signatures for all helpers to _ExceptionInfo, and then always use that. +# If this ends up leading to problems, we can resort to always using _ExceptionInfo and +# users that want to use getrepr/errisinstance/exconly can write helpers on their own, or +# we reimplement them ourselves...or get this merged in upstream pytest. +if TYPE_CHECKING: + ExceptionInfo = _ExceptionInfo + +else: + try: + from pytest import ExceptionInfo # noqa: PT013 + except ImportError: # pragma: no cover + ExceptionInfo = _ExceptionInfo + + +# copied from pytest.ExceptionInfo +def _stringify_exception(exc: BaseException) -> str: + return "\n".join( + [ + getattr(exc, "message", str(exc)), + *getattr(exc, "__notes__", []), + ] + ) + + +# String patterns default to including the unicode flag. +_regex_no_flags = re.compile("").flags + + +@final +class Matcher(Generic[MatchE]): + """Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments. + The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter. + :meth:`trio.testing.Matcher.matches` can also be used standalone to check individual exceptions. + + Examples:: + + with RaisesGroups(Matcher(ValueError, match="string")) + ... + with RaisesGroups(Matcher(check=lambda x: x.args == (3, "hello"))): + ... + with RaisesGroups(Matcher(check=lambda x: type(x) is ValueError)): + ... + + """ + + # At least one of the three parameters must be passed. + @overload + def __init__( + self: Matcher[MatchE], + exception_type: type[MatchE], + match: str | Pattern[str] = ..., + check: Callable[[MatchE], bool] = ..., + ): ... + + @overload + def __init__( + self: Matcher[BaseException], # Give E a value. 
+ *, + match: str | Pattern[str], + # If exception_type is not provided, check() must do any typechecks itself. + check: Callable[[BaseException], bool] = ..., + ): ... + + @overload + def __init__(self, *, check: Callable[[BaseException], bool]): ... + + def __init__( + self, + exception_type: type[MatchE] | None = None, + match: str | Pattern[str] | None = None, + check: Callable[[MatchE], bool] | None = None, + ): + if exception_type is None and match is None and check is None: + raise ValueError("You must specify at least one parameter to match on.") + if exception_type is not None and not issubclass(exception_type, BaseException): + raise ValueError( + f"exception_type {exception_type} must be a subclass of BaseException" + ) + self.exception_type = exception_type + self.match: Pattern[str] | None + if isinstance(match, str): + self.match = re.compile(match) + else: + self.match = match + self.check = check + + def matches(self, exception: BaseException) -> TypeGuard[MatchE]: + """Check if an exception matches the requirements of this Matcher. + + Examples:: + + assert Matcher(ValueError).matches(my_exception): + # is equivalent to + assert isinstance(my_exception, ValueError) + + # this can be useful when checking e.g. the ``__cause__`` of an exception. + with pytest.raises(ValueError) as excinfo: + ... + assert Matcher(SyntaxError, match="foo").matches(excinfo.value.__cause__) + # above line is equivalent to + assert isinstance(excinfo.value.__cause__, SyntaxError) + assert re.search("foo", str(excinfo.value.__cause__) + + """ + if self.exception_type is not None and not isinstance( + exception, self.exception_type + ): + return False + if self.match is not None and not re.search( + self.match, _stringify_exception(exception) + ): + return False + # If exception_type is None check() accepts BaseException. + # If non-none, we have done an isinstance check above. + return self.check is None or self.check(cast(MatchE, exception)) + + def __str__(self) -> str: + reqs = [] + if self.exception_type is not None: + reqs.append(self.exception_type.__name__) + if (match := self.match) is not None: + # If no flags were specified, discard the redundant re.compile() here. + reqs.append( + f"match={match.pattern if match.flags == _regex_no_flags else match!r}" + ) + if self.check is not None: + reqs.append(f"check={self.check!r}") + return f'Matcher({", ".join(reqs)})' + + +# typing this has been somewhat of a nightmare, with the primary difficulty making +# the return type of __enter__ correct. Ideally it would function like this +# with RaisesGroup(RaisesGroup(ValueError)) as excinfo: +# ... +# assert_type(excinfo.value, ExceptionGroup[ExceptionGroup[ValueError]]) +# in addition to all the simple cases, but getting all the way to the above seems maybe +# impossible. The type being RaisesGroup[RaisesGroup[ValueError]] is probably also fine, +# as long as I add fake properties corresponding to the properties of exceptiongroup. But +# I had trouble with it handling recursive cases properly. + +# Current solution settles on the above giving BaseExceptionGroup[RaisesGroup[ValueError]], and it not +# being a type error to do `with RaisesGroup(ValueError()): ...` - but that will error on runtime. + +# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups +if TYPE_CHECKING: + SuperClass = BaseExceptionGroup +else: + # At runtime, use a redundant Generic base class which effectively gets ignored. 
+ SuperClass = Generic + + +@final +class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]): + """Contextmanager for checking for an expected `ExceptionGroup`. + This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538 + + + The catching behaviour differs from :ref:`except* ` in multiple different ways, being much stricter by default. By using ``allow_unwrapped=True`` and ``flatten_subgroups=True`` you can match ``except*`` fully when expecting a single exception. + + #. All specified exceptions must be present, *and no others*. + + * If you expect a variable number of exceptions you need to use ``pytest.raises(ExceptionGroup)`` and manually check the contained exceptions. Consider making use of :func:`Matcher.matches`. + + #. It will only catch exceptions wrapped in an exceptiongroup by default. + + * With ``allow_unwrapped=True`` you can specify a single expected exception or `Matcher` and it will match the exception even if it is not inside an `ExceptionGroup`. If you expect one of several different exception types you need to use a `Matcher` object. + + #. By default it cares about the full structure with nested `ExceptionGroup`'s. You can specify nested `ExceptionGroup`'s by passing `RaisesGroup` objects as expected exceptions. + + * With ``flatten_subgroups=True`` it will "flatten" the raised `ExceptionGroup`, extracting all exceptions inside any nested :class:`ExceptionGroup`, before matching. + + It currently does not care about the order of the exceptions, so ``RaisesGroups(ValueError, TypeError)`` is equivalent to ``RaisesGroups(TypeError, ValueError)``. + + This class is not as polished as ``pytest.raises``, and is currently not as helpful in e.g. printing diffs when strings don't match, suggesting you use ``re.escape``, etc. + + Examples:: + + with RaisesGroups(ValueError): + raise ExceptionGroup("", (ValueError(),)) + with RaisesGroups(ValueError, ValueError, Matcher(TypeError, match="expected int")): + ... + with RaisesGroups(KeyboardInterrupt, match="hello", check=lambda x: type(x) is BaseExceptionGroup): + ... + with RaisesGroups(RaisesGroups(ValueError)): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # flatten_subgroups + with RaisesGroups(ValueError, flatten_subgroups=True): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # allow_unwrapped + with RaisesGroups(ValueError, allow_unwrapped=True): + raise ValueError + + + `RaisesGroup.matches` can also be used directly to check a standalone exception group. + + + The matching algorithm is greedy, which means cases such as this may fail:: + + with RaisesGroups(ValueError, Matcher(ValueError, match="hello")): + raise ExceptionGroup("", (ValueError("hello"), ValueError("goodbye"))) + + even though it generally does not care about the order of the exceptions in the group. + To avoid the above you should specify the first ValueError with a Matcher as well. + + It is also not typechecked perfectly, and that's likely not possible with the current approach. Most common usage should work without issue though. + """ + + # needed for pyright, since BaseExceptionGroup.__new__ takes two arguments + if TYPE_CHECKING: + + def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: ... 
+ + # allow_unwrapped=True requires: singular exception, exception not being + # RaisesGroup instance, match is None, check is None + @overload + def __init__( + self, + exception: type[E] | Matcher[E], + *, + allow_unwrapped: Literal[True], + flatten_subgroups: bool = False, + match: None = None, + check: None = None, + ): ... + + # flatten_subgroups = True also requires no nested RaisesGroup + @overload + def __init__( + self, + exception: type[E] | Matcher[E], + *other_exceptions: type[E] | Matcher[E], + allow_unwrapped: Literal[False] = False, + flatten_subgroups: Literal[True], + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + ): ... + + @overload + def __init__( + self, + exception: type[E] | Matcher[E] | E, + *other_exceptions: type[E] | Matcher[E] | E, + allow_unwrapped: Literal[False] = False, + flatten_subgroups: Literal[False] = False, + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + ): ... + + def __init__( + self, + exception: type[E] | Matcher[E] | E, + *other_exceptions: type[E] | Matcher[E] | E, + allow_unwrapped: bool = False, + flatten_subgroups: bool = False, + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + strict: None = None, + ): + self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = ( + exception, + *other_exceptions, + ) + self.flatten_subgroups: bool = flatten_subgroups + self.allow_unwrapped = allow_unwrapped + self.match_expr = match + self.check = check + self.is_baseexceptiongroup = False + + if strict is not None: + warn_deprecated( + "The `strict` parameter", + "0.25.1", + issue=2989, + instead="flatten_subgroups=True (for strict=False}", + ) + self.flatten_subgroups = not strict + + if allow_unwrapped and other_exceptions: + raise ValueError( + "You cannot specify multiple exceptions with `allow_unwrapped=True.`" + " If you want to match one of multiple possible exceptions you should" + " use a `Matcher`." + " E.g. `Matcher(check=lambda e: isinstance(e, (...)))`" + ) + if allow_unwrapped and isinstance(exception, RaisesGroup): + raise ValueError( + "`allow_unwrapped=True` has no effect when expecting a `RaisesGroup`." + " You might want it in the expected `RaisesGroup`, or" + " `flatten_subgroups=True` if you don't care about the structure." + ) + if allow_unwrapped and (match is not None or check is not None): + raise ValueError( + "`allow_unwrapped=True` bypasses the `match` and `check` parameters" + " if the exception is unwrapped. If you intended to match/check the" + " exception you should use a `Matcher` object. If you want to match/check" + " the exceptiongroup when the exception *is* wrapped you need to" + " do e.g. `if isinstance(exc.value, ExceptionGroup):" + " assert RaisesGroup(...).matches(exc.value)` afterwards." + ) + + # verify `expected_exceptions` and set `self.is_baseexceptiongroup` + for exc in self.expected_exceptions: + if isinstance(exc, RaisesGroup): + if self.flatten_subgroups: + raise ValueError( + "You cannot specify a nested structure inside a RaisesGroup with" + " `flatten_subgroups=True`. The parameter will flatten subgroups" + " in the raised exceptiongroup before matching, which would never" + " match a nested structure." 
+ ) + self.is_baseexceptiongroup |= exc.is_baseexceptiongroup + elif isinstance(exc, Matcher): + # The Matcher could match BaseExceptions through the other arguments + # but `self.is_baseexceptiongroup` is only used for printing. + if exc.exception_type is None: + continue + # Matcher __init__ assures it's a subclass of BaseException + self.is_baseexceptiongroup |= not issubclass( + exc.exception_type, Exception + ) + elif isinstance(exc, type) and issubclass(exc, BaseException): + self.is_baseexceptiongroup |= not issubclass(exc, Exception) + else: + raise ValueError( + f'Invalid argument "{exc!r}" must be exception type, Matcher, or' + " RaisesGroup." + ) + + def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]: + self.excinfo: ExceptionInfo[BaseExceptionGroup[E]] = ExceptionInfo.for_later() + return self.excinfo + + def _unroll_exceptions( + self, exceptions: Sequence[BaseException] + ) -> Sequence[BaseException]: + """Used if `flatten_subgroups=True`.""" + res: list[BaseException] = [] + for exc in exceptions: + if isinstance(exc, BaseExceptionGroup): + res.extend(self._unroll_exceptions(exc.exceptions)) + + else: + res.append(exc) + return res + + def matches( + self, + exc_val: BaseException | None, + ) -> TypeGuard[BaseExceptionGroup[E]]: + """Check if an exception matches the requirements of this RaisesGroup. + + Example:: + + with pytest.raises(TypeError) as excinfo: + ... + assert RaisesGroups(ValueError).matches(excinfo.value.__cause__) + # the above line is equivalent to + myexc = excinfo.value.__cause + assert isinstance(myexc, BaseExceptionGroup) + assert len(myexc.exceptions) == 1 + assert isinstance(myexc.exceptions[0], ValueError) + """ + if exc_val is None: + return False + # TODO: print/raise why a match fails, in a way that works properly in nested cases + # maybe have a list of strings logging failed matches, that __exit__ can + # recursively step through and print on a failing match. + if not isinstance(exc_val, BaseExceptionGroup): + if self.allow_unwrapped: + exp_exc = self.expected_exceptions[0] + if isinstance(exp_exc, Matcher) and exp_exc.matches(exc_val): + return True + if isinstance(exp_exc, type) and isinstance(exc_val, exp_exc): + return True + return False + + if self.match_expr is not None and not re.search( + self.match_expr, _stringify_exception(exc_val) + ): + return False + if self.check is not None and not self.check(exc_val): + return False + + remaining_exceptions = list(self.expected_exceptions) + actual_exceptions: Sequence[BaseException] = exc_val.exceptions + if self.flatten_subgroups: + actual_exceptions = self._unroll_exceptions(actual_exceptions) + + # important to check the length *after* flattening subgroups + if len(actual_exceptions) != len(self.expected_exceptions): + return False + + # it should be possible to get RaisesGroup.matches typed so as not to + # need type: ignore, but I'm not sure that's possible while also having it + # transparent for the end user. 
+ for e in actual_exceptions: + for rem_e in remaining_exceptions: + if ( + (isinstance(rem_e, type) and isinstance(e, rem_e)) + or (isinstance(rem_e, RaisesGroup) and rem_e.matches(e)) + or (isinstance(rem_e, Matcher) and rem_e.matches(e)) + ): + remaining_exceptions.remove(rem_e) # type: ignore[arg-type] + break + else: + return False + return True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool: + __tracebackhide__ = True + assert ( + exc_type is not None + ), f"DID NOT RAISE any exception, expected {self.expected_type()}" + assert ( + self.excinfo is not None + ), "Internal error - should have been constructed in __enter__" + + if not self.matches(exc_val): + return False + + # Cast to narrow the exception type now that it's verified. + exc_info = cast( + "tuple[type[BaseExceptionGroup[E]], BaseExceptionGroup[E], types.TracebackType]", + (exc_type, exc_val, exc_tb), + ) + self.excinfo.fill_unfilled(exc_info) + return True + + def expected_type(self) -> str: + subexcs = [] + for e in self.expected_exceptions: + if isinstance(e, Matcher): + subexcs.append(str(e)) + elif isinstance(e, RaisesGroup): + subexcs.append(e.expected_type()) + elif isinstance(e, type): + subexcs.append(e.__name__) + else: # pragma: no cover + raise AssertionError("unknown type") + group_type = "Base" if self.is_baseexceptiongroup else "" + return f"{group_type}ExceptionGroup({', '.join(subexcs)})" diff --git a/trio/testing/_sequencer.py b/src/trio/testing/_sequencer.py similarity index 91% rename from trio/testing/_sequencer.py rename to src/trio/testing/_sequencer.py index 4726742995..2bade1b315 100644 --- a/trio/testing/_sequencer.py +++ b/src/trio/testing/_sequencer.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING -import attr +import attrs from .. import Event, _core, _util @@ -13,7 +13,7 @@ @_util.final -@attr.s(eq=False, hash=False) +@attrs.define(eq=False, hash=False, slots=False) class Sequencer: """A convenience class for forcing code in different tasks to run in an explicit linear order. 
@@ -54,11 +54,11 @@ async def main(): """ - _sequence_points: defaultdict[int, Event] = attr.ib( + _sequence_points: defaultdict[int, Event] = attrs.field( factory=lambda: defaultdict(Event), init=False ) - _claimed: set[int] = attr.ib(factory=set, init=False) - _broken: bool = attr.ib(default=False, init=False) + _claimed: set[int] = attrs.field(factory=set, init=False) + _broken: bool = attrs.field(default=False, init=False) @asynccontextmanager async def __call__(self, position: int) -> AsyncIterator[None]: diff --git a/trio/testing/_trio_test.py b/src/trio/testing/_trio_test.py similarity index 96% rename from trio/testing/_trio_test.py rename to src/trio/testing/_trio_test.py index 5619352846..a57c0ee4c7 100644 --- a/trio/testing/_trio_test.py +++ b/src/trio/testing/_trio_test.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Awaitable, Callable from functools import partial, wraps from typing import TYPE_CHECKING, TypeVar @@ -8,6 +7,8 @@ from ..abc import Clock, Instrument if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from typing_extensions import ParamSpec ArgsT = ParamSpec("ArgsT") diff --git a/trio/to_thread.py b/src/trio/to_thread.py similarity index 100% rename from trio/to_thread.py rename to src/trio/to_thread.py diff --git a/test-requirements.in b/test-requirements.in index 37fb6b5157..af8a751b13 100644 --- a/test-requirements.in +++ b/test-requirements.in @@ -13,21 +13,25 @@ cryptography>=41.0.0 # cryptography<41 segfaults on pypy3.10 black; implementation_name == "cpython" mypy; implementation_name == "cpython" types-pyOpenSSL; implementation_name == "cpython" # and annotations -ruff >= 0.0.291 +ruff >= 0.4.3 astor # code generation -pip-tools >= 6.13.0 +uv >= 0.2.24 codespell # https://github.com/python-trio/trio/pull/654#issuecomment-420518745 mypy-extensions; implementation_name == "cpython" typing-extensions types-cffi; implementation_name == "cpython" +# annotations in doc files +types-docutils +sphinx # Trio's own dependencies cffi; os_name == "nt" -attrs >= 20.1.0 +attrs >= 23.2.0 sortedcontainers idna outcome sniffio -exceptiongroup >= 1.0.0rc9; python_version < "3.11" +# 1.2.1 fixes types +exceptiongroup >= 1.2.1; python_version < "3.11" diff --git a/test-requirements.txt b/test-requirements.txt index b5d2c0a4d5..3dec7a570b 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,136 +1,170 @@ -# -# This file is autogenerated by pip-compile with Python 3.8 -# by the following command: -# -# pip-compile test-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --universal --python-version=3.8 test-requirements.in -o test-requirements.txt +alabaster==0.7.13 + # via sphinx astor==0.8.1 # via -r test-requirements.in -astroid==3.0.1 +astroid==3.2.2 # via pylint async-generator==1.10 # via -r test-requirements.in -attrs==23.1.0 +attrs==23.2.0 # via # -r test-requirements.in # outcome -black==23.10.0 ; implementation_name == "cpython" +babel==2.15.0 + # via sphinx +black==24.4.2 ; implementation_name == 'cpython' # via -r test-requirements.in -build==1.0.3 - # via pip-tools -cffi==1.16.0 - # via cryptography -click==8.1.7 +certifi==2024.7.4 + # via requests +cffi==1.17.0rc1 ; os_name == 'nt' or platform_python_implementation != 'PyPy' # via - # black - # pip-tools -codespell==2.2.6 + # -r test-requirements.in + # cryptography +charset-normalizer==3.3.2 + # via requests +click==8.1.7 ; implementation_name == 'cpython' + # via black +codespell==2.3.0 # via -r 
test-requirements.in -coverage==7.3.2 +colorama==0.4.6 ; sys_platform == 'win32' or (implementation_name == 'cpython' and platform_system == 'Windows') + # via + # click + # pylint + # pytest + # sphinx +coverage==7.5.4 # via -r test-requirements.in -cryptography==41.0.4 +cryptography==42.0.8 # via # -r test-requirements.in # pyopenssl # trustme # types-pyopenssl -dill==0.3.7 +dill==0.3.8 # via pylint -exceptiongroup==1.1.3 ; python_version < "3.11" +docutils==0.20.1 + # via sphinx +exceptiongroup==1.2.1 ; python_version < '3.11' # via # -r test-requirements.in # pytest -idna==3.4 +idna==3.7 # via # -r test-requirements.in + # requests # trustme -importlib-metadata==6.8.0 - # via build +imagesize==1.4.1 + # via sphinx +importlib-metadata==8.0.0 ; python_version < '3.10' + # via sphinx iniconfig==2.0.0 # via pytest -isort==5.12.0 +isort==5.13.2 # via pylint jedi==0.19.1 # via -r test-requirements.in +jinja2==3.1.4 + # via sphinx +markupsafe==2.1.5 + # via jinja2 mccabe==0.7.0 # via pylint -mypy==1.6.0 ; implementation_name == "cpython" +mypy==1.11.0 ; implementation_name == 'cpython' # via -r test-requirements.in -mypy-extensions==1.0.0 ; implementation_name == "cpython" +mypy-extensions==1.0.0 ; implementation_name == 'cpython' # via # -r test-requirements.in # black # mypy -nodeenv==1.8.0 +nodeenv==1.9.1 # via pyright -outcome==1.3.0 +outcome==1.3.0.post0 # via -r test-requirements.in -packaging==23.2 +packaging==24.1 # via # black - # build # pytest -parso==0.8.3 + # sphinx +parso==0.8.4 # via jedi -pathspec==0.11.2 +pathspec==0.12.1 ; implementation_name == 'cpython' # via black -pip-tools==7.3.0 - # via -r test-requirements.in -platformdirs==3.11.0 +platformdirs==4.2.2 # via # black # pylint -pluggy==1.3.0 +pluggy==1.5.0 # via pytest -pycparser==2.21 +pycparser==2.22 ; os_name == 'nt' or platform_python_implementation != 'PyPy' # via cffi -pylint==3.0.1 +pygments==2.18.0 + # via sphinx +pylint==3.2.5 # via -r test-requirements.in -pyopenssl==23.2.0 +pyopenssl==24.1.0 # via -r test-requirements.in -pyproject-hooks==1.0.0 - # via build -pyright==1.1.331 +pyright==1.1.370 # via -r test-requirements.in -pytest==7.4.2 +pytest==8.2.2 # via -r test-requirements.in -ruff==0.1.0 +pytz==2024.1 ; python_version < '3.9' + # via babel +requests==2.32.3 + # via sphinx +ruff==0.5.1 # via -r test-requirements.in -sniffio==1.3.0 +sniffio==1.3.1 # via -r test-requirements.in +snowballstemmer==2.2.0 + # via sphinx sortedcontainers==2.4.0 # via -r test-requirements.in -tomli==2.0.1 +sphinx==7.1.2 + # via -r test-requirements.in +sphinxcontrib-applehelp==1.0.4 + # via sphinx +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==2.0.1 + # via sphinx +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.5 + # via sphinx +tomli==2.0.1 ; python_version < '3.11' # via # black - # build # mypy - # pip-tools # pylint - # pyproject-hooks # pytest -tomlkit==0.12.1 +tomlkit==0.12.5 # via pylint trustme==1.1.0 # via -r test-requirements.in -types-cffi==1.16.0.0 ; implementation_name == "cpython" +types-cffi==1.16.0.20240331 ; implementation_name == 'cpython' + # via + # -r test-requirements.in + # types-pyopenssl +types-docutils==0.21.0.20240704 # via -r test-requirements.in -types-pyopenssl==23.2.0.2 ; implementation_name == "cpython" +types-pyopenssl==24.1.0.20240425 ; implementation_name == 'cpython' # via -r test-requirements.in -types-setuptools==68.2.0.0 +types-setuptools==70.2.0.20240704 ; implementation_name == 'cpython' # 
via types-cffi -typing-extensions==4.8.0 +typing-extensions==4.12.2 # via # -r test-requirements.in # astroid # black # mypy # pylint -wheel==0.41.2 - # via pip-tools -zipp==3.17.0 +urllib3==2.2.2 + # via requests +uv==0.2.26 + # via -r test-requirements.in +zipp==3.19.2 ; python_version < '3.10' # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools diff --git a/tests/cython/test_cython.pyx b/tests/cython/test_cython.pyx new file mode 100644 index 0000000000..b836caf90c --- /dev/null +++ b/tests/cython/test_cython.pyx @@ -0,0 +1,22 @@ +# cython: language_level=3 +import trio + +# the output of the prints are not currently checked, we only check +# if the program can be compiled and doesn't crash when run. + +# The content of the program can easily be extended if there's other behaviour +# that might be likely to be problematic for cython. +async def foo() -> None: + print('.') + +async def trio_main() -> None: + print('hello...') + await trio.sleep(1) + print(' world !') + + async with trio.open_nursery() as nursery: + nursery.start_soon(foo) + nursery.start_soon(foo) + nursery.start_soon(foo) + +trio.run(trio_main) diff --git a/trio/_core/_generated_io_epoll.py b/trio/_core/_generated_io_epoll.py deleted file mode 100644 index f80e988f38..0000000000 --- a/trio/_core/_generated_io_epoll.py +++ /dev/null @@ -1,39 +0,0 @@ -# *********************************************************** -# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** -# ************************************************************* -from __future__ import annotations - -from typing import TYPE_CHECKING - -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import GLOBAL_RUN_CONTEXT - -if TYPE_CHECKING: - from .._file_io import _HasFileNo -import sys - -assert not TYPE_CHECKING or sys.platform == "linux" - - -async def wait_readable(fd: (int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def wait_writable(fd: (int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def notify_closing(fd: (int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) - except AttributeError: - raise RuntimeError("must be called from async context") from None diff --git a/trio/_core/_generated_io_kqueue.py b/trio/_core/_generated_io_kqueue.py deleted file mode 100644 index b572831076..0000000000 --- a/trio/_core/_generated_io_kqueue.py +++ /dev/null @@ -1,73 +0,0 @@ -# *********************************************************** -# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** -# ************************************************************* -from __future__ import annotations - -from typing import TYPE_CHECKING, Callable, ContextManager - -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import GLOBAL_RUN_CONTEXT - -if TYPE_CHECKING: - import select - - from .. 
import _core - from .._file_io import _HasFileNo - from ._traps import Abort, RaiseCancelT -import sys - -assert not TYPE_CHECKING or sys.platform == "darwin" - - -def current_kqueue() -> select.kqueue: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue() - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def monitor_kevent( - ident: int, filter: int -) -> ContextManager[_core.UnboundedQueue[select.kevent]]: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def wait_kevent( - ident: int, filter: int, abort_func: Callable[[RaiseCancelT], Abort] -) -> Abort: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent( - ident, filter, abort_func - ) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def wait_readable(fd: (int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def wait_writable(fd: (int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def notify_closing(fd: (int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd) - except AttributeError: - raise RuntimeError("must be called from async context") from None diff --git a/trio/_core/_generated_io_windows.py b/trio/_core/_generated_io_windows.py deleted file mode 100644 index e859829e2b..0000000000 --- a/trio/_core/_generated_io_windows.py +++ /dev/null @@ -1,103 +0,0 @@ -# *********************************************************** -# ******* WARNING: AUTOGENERATED! 
ALL EDITS WILL BE LOST ****** -# ************************************************************* -from __future__ import annotations - -from typing import TYPE_CHECKING, ContextManager - -from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED -from ._run import GLOBAL_RUN_CONTEXT - -if TYPE_CHECKING: - from typing_extensions import Buffer - - from .._file_io import _HasFileNo - from ._unbounded_queue import UnboundedQueue - from ._windows_cffi import CData, Handle -import sys - -assert not TYPE_CHECKING or sys.platform == "win32" - - -async def wait_readable(sock: (_HasFileNo | int)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def wait_writable(sock: (_HasFileNo | int)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def notify_closing(handle: (Handle | int | _HasFileNo)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def register_with_iocp(handle: (int | CData)) -> None: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def wait_overlapped( - handle_: (int | CData), lpOverlapped: (CData | int) -) -> object: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped( - handle_, lpOverlapped - ) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def write_overlapped( - handle: (int | CData), data: Buffer, file_offset: int = 0 -) -> int: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped( - handle, data, file_offset - ) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -async def readinto_overlapped( - handle: (int | CData), buffer: Buffer, file_offset: int = 0 -) -> int: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped( - handle, buffer, file_offset - ) - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def current_iocp() -> int: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp() - except AttributeError: - raise RuntimeError("must be called from async context") from None - - -def monitor_completion_key() -> ContextManager[tuple[int, UnboundedQueue[object]]]: - locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True - try: - return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key() - except AttributeError: - raise RuntimeError("must be called from async context") from None diff --git a/trio/_core/_multierror.py b/trio/_core/_multierror.py deleted file mode 100644 index cd839e0b4a..0000000000 --- a/trio/_core/_multierror.py +++ /dev/null @@ -1,500 +0,0 @@ -from __future__ import annotations - -import sys -from 
collections.abc import Callable, Sequence -from types import TracebackType -from typing import TYPE_CHECKING, Any, ClassVar, cast, overload - -import attr - -from trio._deprecate import warn_deprecated - -if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup, ExceptionGroup - -if TYPE_CHECKING: - from typing_extensions import Self -################################################################ -# MultiError -################################################################ - - -def _filter_impl( - handler: Callable[[BaseException], BaseException | None], root_exc: BaseException -) -> BaseException | None: - # We have a tree of MultiError's, like: - # - # MultiError([ - # ValueError, - # MultiError([ - # KeyError, - # ValueError, - # ]), - # ]) - # - # or similar. - # - # We want to - # 1) apply the filter to each of the leaf exceptions -- each leaf - # might stay the same, be replaced (with the original exception - # potentially sticking around as __context__ or __cause__), or - # disappear altogether. - # 2) simplify the resulting tree -- remove empty nodes, and replace - # singleton MultiError's with their contents, e.g.: - # MultiError([KeyError]) -> KeyError - # (This can happen recursively, e.g. if the two ValueErrors above - # get caught then we'll just be left with a bare KeyError.) - # 3) preserve sensible tracebacks - # - # It's the tracebacks that are most confusing. As a MultiError - # propagates through the stack, it accumulates traceback frames, but - # the exceptions inside it don't. Semantically, the traceback for a - # leaf exception is the concatenation the tracebacks of all the - # exceptions you see when traversing the exception tree from the root - # to that leaf. Our correctness invariant is that this concatenated - # traceback should be the same before and after. - # - # The easy way to do that would be to, at the beginning of this - # function, "push" all tracebacks down to the leafs, so all the - # MultiErrors have __traceback__=None, and all the leafs have complete - # tracebacks. But whenever possible, we'd actually prefer to keep - # tracebacks as high up in the tree as possible, because this lets us - # keep only a single copy of the common parts of these exception's - # tracebacks. This is cheaper (in memory + time -- tracebacks are - # unpleasantly quadratic-ish to work with, and this might matter if - # you have thousands of exceptions, which can happen e.g. after - # cancelling a large task pool, and no-one will ever look at their - # tracebacks!), and more importantly, factoring out redundant parts of - # the tracebacks makes them more readable if/when users do see them. - # - # So instead our strategy is: - # - first go through and construct the new tree, preserving any - # unchanged subtrees - # - then go through the original tree (!) and push tracebacks down - # until either we hit a leaf, or we hit a subtree which was - # preserved in the new tree. - - # This used to also support async handler functions. But that runs into: - # https://bugs.python.org/issue29600 - # which is difficult to fix on our end. 
- - # Filters a subtree, ignoring tracebacks, while keeping a record of - # which MultiErrors were preserved unchanged - def filter_tree( - exc: MultiError | BaseException, preserved: set[int] - ) -> MultiError | BaseException | None: - if isinstance(exc, MultiError): - new_exceptions = [] - changed = False - for child_exc in exc.exceptions: - new_child_exc = filter_tree( # noqa: F821 # Deleted in local scope below, causes ruff to think it's not defined (astral-sh/ruff#7733) - child_exc, preserved - ) - if new_child_exc is not child_exc: - changed = True - if new_child_exc is not None: - new_exceptions.append(new_child_exc) - if not new_exceptions: - return None - elif changed: - return MultiError(new_exceptions) - else: - preserved.add(id(exc)) - return exc - else: - new_exc = handler(exc) - # Our version of implicit exception chaining - if new_exc is not None and new_exc is not exc: - new_exc.__context__ = exc - return new_exc - - def push_tb_down( - tb: TracebackType | None, exc: BaseException, preserved: set[int] - ) -> None: - if id(exc) in preserved: - return - new_tb = concat_tb(tb, exc.__traceback__) - if isinstance(exc, MultiError): - for child_exc in exc.exceptions: - push_tb_down( # noqa: F821 # Deleted in local scope below, causes ruff to think it's not defined (astral-sh/ruff#7733) - new_tb, child_exc, preserved - ) - exc.__traceback__ = None - else: - exc.__traceback__ = new_tb - - preserved: set[int] = set() - new_root_exc = filter_tree(root_exc, preserved) - push_tb_down(None, root_exc, preserved) - # Delete the local functions to avoid a reference cycle (see - # test_simple_cancel_scope_usage_doesnt_create_cyclic_garbage) - del filter_tree, push_tb_down - return new_root_exc - - -# Normally I'm a big fan of (a)contextmanager, but in this case I found it -# easier to use the raw context manager protocol, because it makes it a lot -# easier to reason about how we're mutating the traceback as we go. (End -# result: if the exception gets modified, then the 'raise' here makes this -# frame show up in the traceback; otherwise, we leave no trace.) -@attr.s(frozen=True) -class MultiErrorCatcher: - _handler: Callable[[BaseException], BaseException | None] = attr.ib() - - def __enter__(self) -> None: - pass - - def __exit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> bool | None: - if exc_value is not None: - filtered_exc = _filter_impl(self._handler, exc_value) - - if filtered_exc is exc_value: - # Let the interpreter re-raise it - return False - if filtered_exc is None: - # Swallow the exception - return True - # When we raise filtered_exc, Python will unconditionally blow - # away its __context__ attribute and replace it with the original - # exc we caught. So after we raise it, we have to pause it while - # it's in flight to put the correct __context__ back. - old_context = filtered_exc.__context__ - try: - raise filtered_exc - finally: - _, value, _ = sys.exc_info() - assert value is filtered_exc - value.__context__ = old_context - # delete references from locals to avoid creating cycles - # see test_MultiError_catch_doesnt_create_cyclic_garbage - del _, filtered_exc, value - return False - - -if TYPE_CHECKING: - _BaseExceptionGroup = BaseExceptionGroup[BaseException] -else: - _BaseExceptionGroup = BaseExceptionGroup - - -class MultiError(_BaseExceptionGroup): - """An exception that contains other exceptions; also known as an - "inception". 
- - It's main use is to represent the situation when multiple child tasks all - raise errors "in parallel". - - Args: - exceptions (list): The exceptions - - Returns: - If ``len(exceptions) == 1``, returns that exception. This means that a - call to ``MultiError(...)`` is not guaranteed to return a - :exc:`MultiError` object! - - Otherwise, returns a new :exc:`MultiError` object. - - Raises: - TypeError: if any of the passed in objects are not instances of - :exc:`BaseException`. - - """ - - def __init__( - self, exceptions: Sequence[BaseException], *, _collapse: bool = True - ) -> None: - self.collapse = _collapse - - # Avoid double initialization when _collapse is True and exceptions[0] returned - # by __new__() happens to be a MultiError and subsequently __init__() is called. - if _collapse and getattr(self, "exceptions", None) is not None: - # This exception was already initialized. - return - - super().__init__("multiple tasks failed", exceptions) - - def __new__( # type: ignore[misc] # mypy says __new__ must return a class instance - cls, exceptions: Sequence[BaseException], *, _collapse: bool = True - ) -> NonBaseMultiError | Self | BaseException: - exceptions = list(exceptions) - for exc in exceptions: - if not isinstance(exc, BaseException): - raise TypeError(f"Expected an exception object, not {exc!r}") - if _collapse and len(exceptions) == 1: - # If this lone object happens to itself be a MultiError, then - # Python will implicitly call our __init__ on it again. See - # special handling in __init__. - return exceptions[0] - else: - # The base class __new__() implicitly invokes our __init__, which - # is what we want. - # - # In an earlier version of the code, we didn't define __init__ and - # simply set the `exceptions` attribute directly on the new object. - # However, linters expect attributes to be initialized in __init__. - from_class: type[Self | NonBaseMultiError] = cls - if all(isinstance(exc, Exception) for exc in exceptions): - from_class = NonBaseMultiError - - # Ignoring arg-type: 'Argument 3 to "__new__" of "BaseExceptionGroup" has incompatible type "list[BaseException]"; expected "Sequence[_BaseExceptionT_co]"' - # We have checked that exceptions is indeed a list of BaseException objects, this is fine. - new_obj = super().__new__(from_class, "multiple tasks failed", exceptions) # type: ignore[arg-type] - assert isinstance(new_obj, (cls, NonBaseMultiError)) - return new_obj - - def __reduce__( - self, - ) -> tuple[object, tuple[type[Self], list[BaseException]], dict[str, bool]]: - return ( - self.__new__, - (self.__class__, list(self.exceptions)), - {"collapse": self.collapse}, - ) - - def __str__(self) -> str: - return ", ".join(repr(exc) for exc in self.exceptions) - - def __repr__(self) -> str: - return f"" - - @overload # type: ignore[override] # 'Exception' != '_ExceptionT' - def derive(self, excs: Sequence[Exception], /) -> NonBaseMultiError: - ... - - @overload - def derive(self, excs: Sequence[BaseException], /) -> MultiError: - ... - - def derive( - self, excs: Sequence[Exception | BaseException], / - ) -> NonBaseMultiError | MultiError: - # We use _collapse=False here to get ExceptionGroup semantics, since derive() - # is part of the PEP 654 API - exc = MultiError(excs, _collapse=False) - exc.collapse = self.collapse - return exc - - @classmethod - def filter( - cls, - handler: Callable[[BaseException], BaseException | None], - root_exc: BaseException, - ) -> BaseException | None: - """Apply the given ``handler`` to all the exceptions in ``root_exc``. 
- - Args: - handler: A callable that takes an atomic (non-MultiError) exception - as input, and returns either a new exception object or None. - root_exc: An exception, often (though not necessarily) a - :exc:`MultiError`. - - Returns: - A new exception object in which each component exception ``exc`` has - been replaced by the result of running ``handler(exc)`` – or, if - ``handler`` returned None for all the inputs, returns None. - - """ - warn_deprecated( - "MultiError.filter()", - "0.22.0", - instead="BaseExceptionGroup.split()", - issue=2211, - ) - return _filter_impl(handler, root_exc) - - @classmethod - def catch( - cls, handler: Callable[[BaseException], BaseException | None] - ) -> MultiErrorCatcher: - """Return a context manager that catches and re-throws exceptions - after running :meth:`filter` on them. - - Args: - handler: as for :meth:`filter` - - """ - warn_deprecated( - "MultiError.catch", - "0.22.0", - instead="except* or exceptiongroup.catch()", - issue=2211, - ) - - return MultiErrorCatcher(handler) - - -if TYPE_CHECKING: # noqa: SIM108 - _ExceptionGroup = ExceptionGroup[Exception] -else: - _ExceptionGroup = ExceptionGroup - - -class NonBaseMultiError(MultiError, _ExceptionGroup): - __slots__ = () - - -# Clean up exception printing: -MultiError.__module__ = "trio" -NonBaseMultiError.__module__ = "trio" - -################################################################ -# concat_tb -################################################################ - -# We need to compute a new traceback that is the concatenation of two existing -# tracebacks. This requires copying the entries in 'head' and then pointing -# the final tb_next to 'tail'. -# -# NB: 'tail' might be None, which requires some special handling in the ctypes -# version. -# -# The complication here is that Python doesn't actually support copying or -# modifying traceback objects, so we have to get creative... -# -# On CPython, we use ctypes. On PyPy, we use "transparent proxies". -# -# Jinja2 is a useful source of inspiration: -# https://github.com/pallets/jinja/blob/master/jinja2/debug.py - -try: - import tputil -except ImportError: - # ctypes it is - import ctypes - - # How to handle refcounting? I don't want to use ctypes.py_object because - # I don't understand or trust it, and I don't want to use - # ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code - # that also tries to use them but with different types. So private _ctypes - # APIs it is! - import _ctypes - - class CTraceback(ctypes.Structure): - _fields_: ClassVar = [ - ("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()), - ("tb_next", ctypes.c_void_p), - ("tb_frame", ctypes.c_void_p), - ("tb_lasti", ctypes.c_int), - ("tb_lineno", ctypes.c_int), - ] - - def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: - # TracebackType has no public constructor, so allocate one the hard way - try: - raise ValueError - except ValueError as exc: - new_tb = exc.__traceback__ - assert new_tb is not None - c_new_tb = CTraceback.from_address(id(new_tb)) - - # At the C level, tb_next either pointer to the next traceback or is - # NULL. c_void_p and the .tb_next accessor both convert NULL to None, - # but we shouldn't DECREF None just because we assigned to a NULL - # pointer! Here we know that our new traceback has only 1 frame in it, - # so we can assume the tb_next field is NULL. 
- assert c_new_tb.tb_next is None - # If tb_next is None, then we want to set c_new_tb.tb_next to NULL, - # which it already is, so we're done. Otherwise, we have to actually - # do some work: - if tb_next is not None: - _ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined] - c_new_tb.tb_next = id(tb_next) - - assert c_new_tb.tb_frame is not None - _ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined] - old_tb_frame = new_tb.tb_frame - c_new_tb.tb_frame = id(base_tb.tb_frame) - _ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined] - - c_new_tb.tb_lasti = base_tb.tb_lasti - c_new_tb.tb_lineno = base_tb.tb_lineno - - try: - return new_tb - finally: - # delete references from locals to avoid creating cycles - # see test_MultiError_catch_doesnt_create_cyclic_garbage - del new_tb, old_tb_frame - -else: - # http://doc.pypy.org/en/latest/objspace-proxies.html - def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType: - # tputil.ProxyOperation is PyPy-only, but we run mypy on CPython - def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported] - # Rationale for pragma: I looked fairly carefully and tried a few - # things, and AFAICT it's not actually possible to get any - # 'opname' that isn't __getattr__ or __getattribute__. So there's - # no missing test we could add, and no value in coverage nagging - # us about adding one. - if ( - operation.opname - in { - "__getattribute__", - "__getattr__", - } - and operation.args[0] == "tb_next" - ): # pragma: no cover - return tb_next - return operation.delegate() # Deligate is reverting to original behaviour - - return cast( - TracebackType, tputil.make_proxy(controller, type(base_tb), base_tb) - ) # Returns proxy to traceback - - -def concat_tb( - head: TracebackType | None, tail: TracebackType | None -) -> TracebackType | None: - # We have to use an iterative algorithm here, because in the worst case - # this might be a RecursionError stack that is by definition too deep to - # process by recursion! - head_tbs = [] - pointer = head - while pointer is not None: - head_tbs.append(pointer) - pointer = pointer.tb_next - current_head = tail - for head_tb in reversed(head_tbs): - current_head = copy_tb(head_tb, tb_next=current_head) - return current_head - - -# Ubuntu's system Python has a sitecustomize.py file that import -# apport_python_hook and replaces sys.excepthook. -# -# The custom hook captures the error for crash reporting, and then calls -# sys.__excepthook__ to actually print the error. -# -# We don't mind it capturing the error for crash reporting, but we want to -# take over printing the error. So we monkeypatch the apport_python_hook -# module so that instead of calling sys.__excepthook__, it calls our custom -# hook. 
-# -# More details: https://github.com/python-trio/trio/issues/1065 -if sys.version_info < (3, 11) and getattr(sys.excepthook, "__name__", None) in ( - "apport_excepthook", - "partial_apport_excepthook", -): - from types import ModuleType - - import apport_python_hook - from exceptiongroup import format_exception - - assert sys.excepthook is apport_python_hook.apport_excepthook - - def replacement_excepthook( - etype: type[BaseException], value: BaseException, tb: TracebackType | None - ) -> None: - # This does work, it's an overloaded function - sys.stderr.write("".join(format_exception(etype, value, tb))) # type: ignore[arg-type] - - fake_sys = ModuleType("trio_fake_sys") - fake_sys.__dict__.update(sys.__dict__) - # Fake does have __excepthook__ after __dict__ update, but type checkers don't recognize this - fake_sys.__excepthook__ = replacement_excepthook # type: ignore[attr-defined] - apport_python_hook.sys = fake_sys diff --git a/trio/_core/_tests/test_multierror.py b/trio/_core/_tests/test_multierror.py deleted file mode 100644 index d1cd0bd4a1..0000000000 --- a/trio/_core/_tests/test_multierror.py +++ /dev/null @@ -1,517 +0,0 @@ -from __future__ import annotations - -import gc -import os -import pickle -import re -import subprocess -import sys -import warnings -from pathlib import Path -from traceback import extract_tb, print_exception -from types import TracebackType -from typing import Callable, NoReturn - -import pytest - -from ... import TrioDeprecationWarning -from ..._core import open_nursery -from .._multierror import MultiError, NonBaseMultiError, concat_tb -from .tutil import slow - -if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup - - -class NotHashableException(Exception): - code: int | None = None - - def __init__(self, code: int) -> None: - super().__init__() - self.code = code - - def __eq__(self, other: object) -> bool: - if not isinstance(other, NotHashableException): - return False - return self.code == other.code - - -async def raise_nothashable(code: int) -> NoReturn: - raise NotHashableException(code) - - -def raiser1() -> NoReturn: - raiser1_2() - - -def raiser1_2() -> NoReturn: - raiser1_3() - - -def raiser1_3() -> NoReturn: - raise ValueError("raiser1_string") - - -def raiser2() -> NoReturn: - raiser2_2() - - -def raiser2_2() -> NoReturn: - raise KeyError("raiser2_string") - - -def raiser3() -> NoReturn: - raise NameError - - -def get_exc(raiser: Callable[[], NoReturn]) -> BaseException: - try: - raiser() - except Exception as exc: - return exc - raise AssertionError("raiser should always raise") # pragma: no cover - - -def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None: - return get_exc(raiser).__traceback__ - - -def test_concat_tb() -> None: - tb1 = get_tb(raiser1) - tb2 = get_tb(raiser2) - - # These return a list of (filename, lineno, fn name, text) tuples - # https://docs.python.org/3/library/traceback.html#traceback.extract_tb - entries1 = extract_tb(tb1) - entries2 = extract_tb(tb2) - - tb12 = concat_tb(tb1, tb2) - assert extract_tb(tb12) == entries1 + entries2 - - tb21 = concat_tb(tb2, tb1) - assert extract_tb(tb21) == entries2 + entries1 - - # Check degenerate cases - assert extract_tb(concat_tb(None, tb1)) == entries1 - assert extract_tb(concat_tb(tb1, None)) == entries1 - assert concat_tb(None, None) is None - - # Make sure the original tracebacks didn't get mutated by mistake - assert extract_tb(get_tb(raiser1)) == entries1 - assert extract_tb(get_tb(raiser2)) == entries2 - - -def test_MultiError() -> None: 
- exc1 = get_exc(raiser1) - exc2 = get_exc(raiser2) - - assert MultiError([exc1]) is exc1 - m = MultiError([exc1, exc2]) - assert m.exceptions == (exc1, exc2) - assert "ValueError" in str(m) - assert "ValueError" in repr(m) - - with pytest.raises(TypeError): - MultiError(object()) # type: ignore[arg-type] - with pytest.raises(TypeError): - MultiError([KeyError(), ValueError]) # type: ignore[list-item] - - -def test_MultiErrorOfSingleMultiError() -> None: - # For MultiError([MultiError]), ensure there is no bad recursion by the - # constructor where __init__ is called if __new__ returns a bare MultiError. - exceptions = (KeyError(), ValueError()) - a = MultiError(exceptions) - b = MultiError([a]) - assert b == a - assert b.exceptions == exceptions - - -async def test_MultiErrorNotHashable() -> None: - exc1 = NotHashableException(42) - exc2 = NotHashableException(4242) - exc3 = ValueError() - assert exc1 != exc2 - assert exc1 != exc3 - - with pytest.raises(MultiError): - async with open_nursery() as nursery: - nursery.start_soon(raise_nothashable, 42) - nursery.start_soon(raise_nothashable, 4242) - - -def test_MultiError_filter_NotHashable() -> None: - excs = MultiError([NotHashableException(42), ValueError()]) - - def handle_ValueError(exc: BaseException) -> BaseException | None: - if isinstance(exc, ValueError): - return None - else: - return exc - - with pytest.warns(TrioDeprecationWarning): - filtered_excs = MultiError.filter(handle_ValueError, excs) - - assert isinstance(filtered_excs, NotHashableException) - - -def make_tree() -> MultiError: - # Returns an object like: - # MultiError([ - # MultiError([ - # ValueError, - # KeyError, - # ]), - # NameError, - # ]) - # where all exceptions except the root have a non-trivial traceback. - exc1 = get_exc(raiser1) - exc2 = get_exc(raiser2) - exc3 = get_exc(raiser3) - - # Give m12 a non-trivial traceback - try: - raise MultiError([exc1, exc2]) - except BaseException as m12: - return MultiError([m12, exc3]) - - -def assert_tree_eq( - m1: BaseException | MultiError | None, m2: BaseException | MultiError | None -) -> None: - if m1 is None or m2 is None: - assert m1 is m2 - return - assert type(m1) is type(m2) - assert extract_tb(m1.__traceback__) == extract_tb(m2.__traceback__) - assert_tree_eq(m1.__cause__, m2.__cause__) - assert_tree_eq(m1.__context__, m2.__context__) - if isinstance(m1, MultiError): - assert isinstance(m2, MultiError) - assert len(m1.exceptions) == len(m2.exceptions) - for e1, e2 in zip(m1.exceptions, m2.exceptions): - assert_tree_eq(e1, e2) - - -def test_MultiError_filter() -> None: - def null_handler(exc: BaseException) -> BaseException: - return exc - - m = make_tree() - assert_tree_eq(m, m) - with pytest.warns(TrioDeprecationWarning): - assert MultiError.filter(null_handler, m) is m - - assert_tree_eq(m, make_tree()) - - # Make sure we don't pick up any detritus if run in a context where - # implicit exception chaining would like to kick in - m = make_tree() - try: - raise ValueError - except ValueError: - with pytest.warns(TrioDeprecationWarning): - assert MultiError.filter(null_handler, m) is m - assert_tree_eq(m, make_tree()) - - def simple_filter(exc: BaseException) -> BaseException | None: - if isinstance(exc, ValueError): - return None - if isinstance(exc, KeyError): - return RuntimeError() - return exc - - with pytest.warns(TrioDeprecationWarning): - new_m = MultiError.filter(simple_filter, make_tree()) - - assert isinstance(new_m, MultiError) - assert len(new_m.exceptions) == 2 - # was: [[ValueError, KeyError], 
NameError] - # ValueError disappeared & KeyError became RuntimeError, so now: - assert isinstance(new_m.exceptions[0], RuntimeError) - assert isinstance(new_m.exceptions[1], NameError) - - # implicit chaining: - assert isinstance(new_m.exceptions[0].__context__, KeyError) - - # also, the traceback on the KeyError incorporates what used to be the - # traceback on its parent MultiError - orig = make_tree() - # make sure we have the right path - assert isinstance(orig.exceptions[0], MultiError) - assert isinstance(orig.exceptions[0].exceptions[1], KeyError) - # get original traceback summary - orig_extracted = ( - extract_tb(orig.__traceback__) - + extract_tb(orig.exceptions[0].__traceback__) - + extract_tb(orig.exceptions[0].exceptions[1].__traceback__) - ) - - def p(exc: BaseException) -> None: - print_exception(type(exc), exc, exc.__traceback__) - - p(orig) - p(orig.exceptions[0]) - p(orig.exceptions[0].exceptions[1]) - p(new_m.exceptions[0].__context__) - # compare to the new path - assert new_m.__traceback__ is None - new_extracted = extract_tb(new_m.exceptions[0].__context__.__traceback__) - assert orig_extracted == new_extracted - - # check preserving partial tree - def filter_NameError(exc: BaseException) -> BaseException | None: - if isinstance(exc, NameError): - return None - return exc - - m = make_tree() - with pytest.warns(TrioDeprecationWarning): - new_m = MultiError.filter(filter_NameError, m) - # with the NameError gone, the other branch gets promoted - assert new_m is m.exceptions[0] - - # check fully handling everything - def filter_all(exc: BaseException) -> None: - return None - - with pytest.warns(TrioDeprecationWarning): - assert MultiError.filter(filter_all, make_tree()) is None - - -def test_MultiError_catch() -> None: - # No exception to catch - - def noop(_: object) -> None: - pass # pragma: no cover - - with pytest.warns(TrioDeprecationWarning), MultiError.catch(noop): - pass - - # Simple pass-through of all exceptions - m = make_tree() - with pytest.raises(MultiError) as excinfo: - with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc): - raise m - assert excinfo.value is m - # Should be unchanged, except that we added a traceback frame by raising - # it here - assert m.__traceback__ is not None - assert m.__traceback__.tb_frame.f_code.co_name == "test_MultiError_catch" - assert m.__traceback__.tb_next is None - m.__traceback__ = None - assert_tree_eq(m, make_tree()) - - # Swallows everything - with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda _: None): - raise make_tree() - - def simple_filter(exc): - if isinstance(exc, ValueError): - return None - if isinstance(exc, KeyError): - return RuntimeError() - return exc - - with pytest.raises(MultiError) as excinfo: - with pytest.warns(TrioDeprecationWarning), MultiError.catch(simple_filter): - raise make_tree() - new_m = excinfo.value - assert isinstance(new_m, MultiError) - assert len(new_m.exceptions) == 2 - # was: [[ValueError, KeyError], NameError] - # ValueError disappeared & KeyError became RuntimeError, so now: - assert isinstance(new_m.exceptions[0], RuntimeError) - assert isinstance(new_m.exceptions[1], NameError) - # Make sure that Python did not successfully attach the old MultiError to - # our new MultiError's __context__ - assert not new_m.__suppress_context__ - assert new_m.__context__ is None - - # check preservation of __cause__ and __context__ - v = ValueError() - v.__cause__ = KeyError() - with pytest.raises(ValueError) as excinfo: - with 
pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc): - raise v - assert isinstance(excinfo.value.__cause__, KeyError) - - v = ValueError() - context = KeyError() - v.__context__ = context - with pytest.raises(ValueError) as excinfo: - with pytest.warns(TrioDeprecationWarning), MultiError.catch(lambda exc: exc): - raise v - assert excinfo.value.__context__ is context - assert not excinfo.value.__suppress_context__ - - for suppress_context in [True, False]: - v = ValueError() - context = KeyError() - v.__context__ = context - v.__suppress_context__ = suppress_context - distractor = RuntimeError() - with pytest.raises(ValueError) as excinfo: - - def catch_RuntimeError(exc): - if isinstance(exc, RuntimeError): - return None - else: - return exc - - with pytest.warns(TrioDeprecationWarning): - with MultiError.catch(catch_RuntimeError): - raise MultiError([v, distractor]) - assert excinfo.value.__context__ is context - assert excinfo.value.__suppress_context__ == suppress_context - - -@pytest.mark.skipif( - sys.implementation.name != "cpython", reason="Only makes sense with refcounting GC" -) -def test_MultiError_catch_doesnt_create_cyclic_garbage() -> None: - # https://github.com/python-trio/trio/pull/2063 - gc.collect() - old_flags = gc.get_debug() - - def make_multi() -> NoReturn: - # make_tree creates cycles itself, so a simple - raise MultiError([get_exc(raiser1), get_exc(raiser2)]) - - def simple_filter(exc: BaseException) -> Exception | RuntimeError: - if isinstance(exc, ValueError): - return Exception() - if isinstance(exc, KeyError): - return RuntimeError() - raise AssertionError( - "only ValueError and KeyError should exist" - ) # pragma: no cover - - try: - gc.set_debug(gc.DEBUG_SAVEALL) - with pytest.raises(MultiError): - # covers MultiErrorCatcher.__exit__ and _multierror.copy_tb - with pytest.warns(TrioDeprecationWarning), MultiError.catch(simple_filter): - raise make_multi() - gc.collect() - assert not gc.garbage - finally: - gc.set_debug(old_flags) - gc.garbage.clear() - - -def assert_match_in_seq(pattern_list: list[str], string: str) -> None: - offset = 0 - print("looking for pattern matches...") - for pattern in pattern_list: - print("checking pattern:", pattern) - reobj = re.compile(pattern) - match = reobj.search(string, offset) - assert match is not None - offset = match.end() - - -def test_assert_match_in_seq() -> None: - assert_match_in_seq(["a", "b"], "xx a xx b xx") - assert_match_in_seq(["b", "a"], "xx b xx a xx") - with pytest.raises(AssertionError): - assert_match_in_seq(["a", "b"], "xx b xx a xx") - - -def test_base_multierror() -> None: - """ - Test that MultiError() with at least one base exception will return a MultiError - object. - """ - - exc = MultiError([ZeroDivisionError(), KeyboardInterrupt()]) - assert type(exc) is MultiError - - -def test_non_base_multierror() -> None: - """ - Test that MultiError() without base exceptions will return a NonBaseMultiError - object. 
- """ - - exc = MultiError([ZeroDivisionError(), ValueError()]) - assert type(exc) is NonBaseMultiError - assert isinstance(exc, ExceptionGroup) - - -def run_script(name: str) -> subprocess.CompletedProcess[bytes]: - import trio - - trio_path = Path(trio.__file__).parent.parent - script_path = Path(__file__).parent / "test_multierror_scripts" / name - - env = dict(os.environ) - print("parent PYTHONPATH:", env.get("PYTHONPATH")) - pp = [] - if "PYTHONPATH" in env: # pragma: no cover - pp = env["PYTHONPATH"].split(os.pathsep) - pp.insert(0, str(trio_path)) - pp.insert(0, str(script_path.parent)) - env["PYTHONPATH"] = os.pathsep.join(pp) - print("subprocess PYTHONPATH:", env.get("PYTHONPATH")) - - cmd = [sys.executable, "-u", str(script_path)] - print("running:", cmd) - completed = subprocess.run( - cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT - ) - print("process output:") - print(completed.stdout.decode("utf-8")) - return completed - - -@slow -@pytest.mark.skipif( - not Path("/usr/lib/python3/dist-packages/apport_python_hook.py").exists(), - reason="need Ubuntu with python3-apport installed", -) -def test_apport_excepthook_monkeypatch_interaction() -> None: - completed = run_script("apport_excepthook.py") - stdout = completed.stdout.decode("utf-8") - - # No warning - assert "custom sys.excepthook" not in stdout - - # Proper traceback - assert_match_in_seq( - ["--- 1 ---", "KeyError", "--- 2 ---", "ValueError"], - stdout, - ) - - -@pytest.mark.parametrize("protocol", range(0, pickle.HIGHEST_PROTOCOL + 1)) -def test_pickle_multierror(protocol: int) -> None: - # use trio.MultiError to make sure that pickle works through the deprecation layer - import trio - - my_except = ZeroDivisionError() - - try: - 1 / 0 # noqa: B018 # "useless statement" - except ZeroDivisionError as exc: - my_except = exc - - # MultiError will collapse into different classes depending on the errors - for cls, errors in ( - (ZeroDivisionError, [my_except]), - (NonBaseMultiError, [my_except, ValueError()]), - (MultiError, [BaseException(), my_except]), - ): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", TrioDeprecationWarning) - me = trio.MultiError(errors) # type: ignore[attr-defined] - dump = pickle.dumps(me, protocol=protocol) - load = pickle.loads(dump) - assert repr(me) == repr(load) - assert me.__class__ == load.__class__ == cls - - assert me.__dict__.keys() == load.__dict__.keys() - for me_val, load_val in zip(me.__dict__.values(), load.__dict__.values()): - # tracebacks etc are not preserved through pickling for the default - # exceptions, so we only check that the repr stays the same - assert repr(me_val) == repr(load_val) diff --git a/trio/_core/_tests/test_multierror_scripts/__init__.py b/trio/_core/_tests/test_multierror_scripts/__init__.py deleted file mode 100644 index a1f6cb598d..0000000000 --- a/trio/_core/_tests/test_multierror_scripts/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# This isn't really a package, everything in here is a standalone script. This -# __init__.py is just to fool setup.py into actually installing the things. 
diff --git a/trio/_core/_tests/test_multierror_scripts/_common.py b/trio/_core/_tests/test_multierror_scripts/_common.py deleted file mode 100644 index 0c70df1840..0000000000 --- a/trio/_core/_tests/test_multierror_scripts/_common.py +++ /dev/null @@ -1,7 +0,0 @@ -# https://coverage.readthedocs.io/en/latest/subprocess.html -try: - import coverage -except ImportError: # pragma: no cover - pass -else: - coverage.process_startup() diff --git a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py b/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py deleted file mode 100644 index 0e46f37e17..0000000000 --- a/trio/_core/_tests/test_multierror_scripts/apport_excepthook.py +++ /dev/null @@ -1,15 +0,0 @@ -# The apport_python_hook package is only installed as part of Ubuntu's system -# python, and not available in venvs. So before we can import it we have to -# make sure it's on sys.path. -import sys - -import _common # isort: split - -sys.path.append("/usr/lib/python3/dist-packages") -import apport_python_hook - -apport_python_hook.install() - -from trio._core._multierror import MultiError # Bypass deprecation warnings - -raise MultiError([KeyError("key_error"), ValueError("value_error")]) diff --git a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py b/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py deleted file mode 100644 index 236d34e9ba..0000000000 --- a/trio/_core/_tests/test_multierror_scripts/simple_excepthook.py +++ /dev/null @@ -1,21 +0,0 @@ -import _common # isort: split - -from trio._core._multierror import MultiError # Bypass deprecation warnings - - -def exc1_fn() -> Exception: - try: - raise ValueError - except Exception as exc: - return exc - - -def exc2_fn() -> Exception: - try: - raise KeyError - except Exception as exc: - return exc - - -# This should be printed nicely, because Trio overrode sys.excepthook -raise MultiError([exc1_fn(), exc2_fn()]) diff --git a/trio/_path.py b/trio/_path.py deleted file mode 100644 index 508ad5d04d..0000000000 --- a/trio/_path.py +++ /dev/null @@ -1,490 +0,0 @@ -from __future__ import annotations - -import inspect -import os -import pathlib -import sys -import types -from collections.abc import Awaitable, Callable, Iterable, Sequence -from functools import partial -from io import BufferedRandom, BufferedReader, BufferedWriter, FileIO, TextIOWrapper -from typing import ( - IO, - TYPE_CHECKING, - Any, - BinaryIO, - ClassVar, - TypeVar, - Union, - cast, - overload, -) - -import trio -from trio._file_io import AsyncIOWrapper as _AsyncIOWrapper -from trio._util import async_wraps, final, wraps - -if TYPE_CHECKING: - from _typeshed import ( - OpenBinaryMode, - OpenBinaryModeReading, - OpenBinaryModeUpdating, - OpenBinaryModeWriting, - OpenTextMode, - ) - from typing_extensions import Concatenate, Literal, ParamSpec, TypeAlias - - P = ParamSpec("P") - -T = TypeVar("T") -StrPath: TypeAlias = Union[str, "os.PathLike[str]"] # Only subscriptable in 3.9+ - - -# re-wrap return value from methods that return new instances of pathlib.Path -def rewrap_path(value: T) -> T | Path: - if isinstance(value, pathlib.Path): - return Path(value) - else: - return value - - -def _forward_factory( - cls: AsyncAutoWrapperType, - attr_name: str, - attr: Callable[Concatenate[pathlib.Path, P], T], -) -> Callable[Concatenate[Path, P], T | Path]: - @wraps(attr) - def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> T | Path: - attr = getattr(self._wrapped, attr_name) - value = attr(*args, **kwargs) - return 
rewrap_path(value) - - # Assigning this makes inspect and therefore Sphinx show the original parameters. - # It's not defined on functions normally though, this is a custom attribute. - assert isinstance(wrapper, types.FunctionType) - wrapper.__signature__ = inspect.signature(attr) - - return wrapper - - -def _forward_magic( - cls: AsyncAutoWrapperType, attr: Callable[..., T] -) -> Callable[..., Path | T]: - sentinel = object() - - @wraps(attr) - def wrapper(self: Path, other: object = sentinel) -> Path | T: - if other is sentinel: - return attr(self._wrapped) - if isinstance(other, cls): - other = cast(Path, other)._wrapped - value = attr(self._wrapped, other) - return rewrap_path(value) - - assert isinstance(wrapper, types.FunctionType) - wrapper.__signature__ = inspect.signature(attr) - return wrapper - - -def iter_wrapper_factory( - cls: AsyncAutoWrapperType, meth_name: str -) -> Callable[Concatenate[Path, P], Awaitable[Iterable[Path]]]: - @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Iterable[Path]: - meth = getattr(self._wrapped, meth_name) - func = partial(meth, *args, **kwargs) - # Make sure that the full iteration is performed in the thread - # by converting the generator produced by pathlib into a list - items = await trio.to_thread.run_sync(lambda: list(func())) - return (rewrap_path(item) for item in items) - - return wrapper - - -def thread_wrapper_factory( - cls: AsyncAutoWrapperType, meth_name: str -) -> Callable[Concatenate[Path, P], Awaitable[Path]]: - @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(self: Path, *args: P.args, **kwargs: P.kwargs) -> Path: - meth = getattr(self._wrapped, meth_name) - func = partial(meth, *args, **kwargs) - value = await trio.to_thread.run_sync(func) - return rewrap_path(value) - - return wrapper - - -def classmethod_wrapper_factory( - cls: AsyncAutoWrapperType, meth_name: str -) -> classmethod: # type: ignore[type-arg] - @async_wraps(cls, cls._wraps, meth_name) - async def wrapper(cls: type[Path], *args: Any, **kwargs: Any) -> Path: # type: ignore[misc] # contains Any - meth = getattr(cls._wraps, meth_name) - func = partial(meth, *args, **kwargs) - value = await trio.to_thread.run_sync(func) - return rewrap_path(value) - - assert isinstance(wrapper, types.FunctionType) - wrapper.__signature__ = inspect.signature(getattr(cls._wraps, meth_name)) - return classmethod(wrapper) - - -class AsyncAutoWrapperType(type): - _forwards: type - _wraps: type - _forward_magic: list[str] - _wrap_iter: list[str] - _forward: list[str] - - def __init__( - cls, name: str, bases: tuple[type, ...], attrs: dict[str, object] - ) -> None: - super().__init__(name, bases, attrs) - - cls._forward = [] - type(cls).generate_forwards(cls, attrs) - type(cls).generate_wraps(cls, attrs) - type(cls).generate_magic(cls, attrs) - type(cls).generate_iter(cls, attrs) - - def generate_forwards(cls, attrs: dict[str, object]) -> None: - # forward functions of _forwards - for attr_name, attr in cls._forwards.__dict__.items(): - if attr_name.startswith("_") or attr_name in attrs: - continue - - if isinstance(attr, property): - cls._forward.append(attr_name) - elif isinstance(attr, types.FunctionType): - wrapper = _forward_factory(cls, attr_name, attr) - setattr(cls, attr_name, wrapper) - else: - raise TypeError(attr_name, type(attr)) - - def generate_wraps(cls, attrs: dict[str, object]) -> None: - # generate wrappers for functions of _wraps - wrapper: classmethod | Callable[..., object] # type: 
ignore[type-arg] - for attr_name, attr in cls._wraps.__dict__.items(): - # .z. exclude cls._wrap_iter - if attr_name.startswith("_") or attr_name in attrs: - continue - if isinstance(attr, classmethod): - wrapper = classmethod_wrapper_factory(cls, attr_name) - setattr(cls, attr_name, wrapper) - elif isinstance(attr, types.FunctionType): - wrapper = thread_wrapper_factory(cls, attr_name) - assert isinstance(wrapper, types.FunctionType) - wrapper.__signature__ = inspect.signature(attr) - setattr(cls, attr_name, wrapper) - else: - raise TypeError(attr_name, type(attr)) - - def generate_magic(cls, attrs: dict[str, object]) -> None: - # generate wrappers for magic - for attr_name in cls._forward_magic: - attr = getattr(cls._forwards, attr_name) - wrapper = _forward_magic(cls, attr) - setattr(cls, attr_name, wrapper) - - def generate_iter(cls, attrs: dict[str, object]) -> None: - # generate wrappers for methods that return iterators - wrapper: Callable[..., object] - for attr_name, attr in cls._wraps.__dict__.items(): - if attr_name in cls._wrap_iter: - wrapper = iter_wrapper_factory(cls, attr_name) - assert isinstance(wrapper, types.FunctionType) - wrapper.__signature__ = inspect.signature(attr) - setattr(cls, attr_name, wrapper) - - -@final -class Path(metaclass=AsyncAutoWrapperType): - """A :class:`pathlib.Path` wrapper that executes blocking methods in - :meth:`trio.to_thread.run_sync`. - - """ - - _forward: ClassVar[list[str]] - _wraps: ClassVar[type] = pathlib.Path - _forwards: ClassVar[type] = pathlib.PurePath - _forward_magic: ClassVar[list[str]] = [ - "__str__", - "__bytes__", - "__truediv__", - "__rtruediv__", - "__eq__", - "__lt__", - "__le__", - "__gt__", - "__ge__", - "__hash__", - ] - _wrap_iter: ClassVar[list[str]] = ["glob", "rglob", "iterdir"] - - def __init__(self, *args: StrPath) -> None: - self._wrapped = pathlib.Path(*args) - - # type checkers allow accessing any attributes on class instances with `__getattr__` - # so we hide it behind a type guard forcing it to rely on the hardcoded attribute - # list below. - if not TYPE_CHECKING: - - def __getattr__(self, name): - if name in self._forward: - value = getattr(self._wrapped, name) - return rewrap_path(value) - raise AttributeError(name) - - def __dir__(self) -> list[str]: - return [*super().__dir__(), *self._forward] - - def __repr__(self) -> str: - return f"trio.Path({str(self)!r})" - - def __fspath__(self) -> str: - return os.fspath(self._wrapped) - - @overload - async def open( - self, - mode: OpenTextMode = "r", - buffering: int = -1, - encoding: str | None = None, - errors: str | None = None, - newline: str | None = None, - ) -> _AsyncIOWrapper[TextIOWrapper]: - ... - - @overload - async def open( - self, - mode: OpenBinaryMode, - buffering: Literal[0], - encoding: None = None, - errors: None = None, - newline: None = None, - ) -> _AsyncIOWrapper[FileIO]: - ... - - @overload - async def open( - self, - mode: OpenBinaryModeUpdating, - buffering: Literal[-1, 1] = -1, - encoding: None = None, - errors: None = None, - newline: None = None, - ) -> _AsyncIOWrapper[BufferedRandom]: - ... - - @overload - async def open( - self, - mode: OpenBinaryModeWriting, - buffering: Literal[-1, 1] = -1, - encoding: None = None, - errors: None = None, - newline: None = None, - ) -> _AsyncIOWrapper[BufferedWriter]: - ... - - @overload - async def open( - self, - mode: OpenBinaryModeReading, - buffering: Literal[-1, 1] = -1, - encoding: None = None, - errors: None = None, - newline: None = None, - ) -> _AsyncIOWrapper[BufferedReader]: - ... 
- - @overload - async def open( - self, - mode: OpenBinaryMode, - buffering: int = -1, - encoding: None = None, - errors: None = None, - newline: None = None, - ) -> _AsyncIOWrapper[BinaryIO]: - ... - - @overload - async def open( # type: ignore[misc] # Any usage matches builtins.open(). - self, - mode: str, - buffering: int = -1, - encoding: str | None = None, - errors: str | None = None, - newline: str | None = None, - ) -> _AsyncIOWrapper[IO[Any]]: - ... - - @wraps(pathlib.Path.open) # type: ignore[misc] # Overload return mismatch. - async def open(self, *args: Any, **kwargs: Any) -> _AsyncIOWrapper[IO[Any]]: - """Open the file pointed to by the path, like the :func:`trio.open_file` - function does. - - """ - - func = partial(self._wrapped.open, *args, **kwargs) - value = await trio.to_thread.run_sync(func) - return trio.wrap_file(value) - - if TYPE_CHECKING: - # the dunders listed in _forward_magic that aren't seen otherwise - # fmt: off - def __bytes__(self) -> bytes: ... - def __truediv__(self, other: StrPath) -> Path: ... - def __rtruediv__(self, other: StrPath) -> Path: ... - def __lt__(self, other: Path | pathlib.PurePath) -> bool: ... - def __le__(self, other: Path | pathlib.PurePath) -> bool: ... - def __gt__(self, other: Path | pathlib.PurePath) -> bool: ... - def __ge__(self, other: Path | pathlib.PurePath) -> bool: ... - - # The following are ordered the same as in typeshed. - - # Properties produced by __getattr__() - all synchronous. - @property - def parts(self) -> tuple[str, ...]: ... - @property - def drive(self) -> str: ... - @property - def root(self) -> str: ... - @property - def anchor(self) -> str: ... - @property - def name(self) -> str: ... - @property - def suffix(self) -> str: ... - @property - def suffixes(self) -> list[str]: ... - @property - def stem(self) -> str: ... - @property - def parents(self) -> Sequence[pathlib.Path]: ... # TODO: Convert these to trio Paths? - @property - def parent(self) -> Path: ... - - # PurePath methods - synchronous. - def as_posix(self) -> str: ... - def as_uri(self) -> str: ... - def is_absolute(self) -> bool: ... - def is_reserved(self) -> bool: ... - def match(self, path_pattern: str) -> bool: ... - def relative_to(self, *other: StrPath) -> Path: ... - def with_name(self, name: str) -> Path: ... - def with_suffix(self, suffix: str) -> Path: ... - def joinpath(self, *other: StrPath) -> Path: ... - - if sys.version_info >= (3, 9): - def is_relative_to(self, *other: StrPath) -> bool: ... - def with_stem(self, stem: str) -> Path: ... - - # pathlib.Path methods and properties - async. - @classmethod - async def cwd(self) -> Path: ... - - if sys.version_info >= (3, 10): - async def stat(self, *, follow_symlinks: bool = True) -> os.stat_result: ... - async def chmod(self, mode: int, *, follow_symlinks: bool = True) -> None: ... - else: - async def stat(self) -> os.stat_result: ... - async def chmod(self, mode: int) -> None: ... - - async def exists(self) -> bool: ... - async def glob(self, pattern: str) -> Iterable[Path]: ... - async def is_dir(self) -> bool: ... - async def is_file(self) -> bool: ... - async def is_symlink(self) -> bool: ... - async def is_socket(self) -> bool: ... - async def is_fifo(self) -> bool: ... - async def is_block_device(self) -> bool: ... - async def is_char_device(self) -> bool: ... - async def iterdir(self) -> Iterable[Path]: ... - async def lchmod(self, mode: int) -> None: ... - async def lstat(self) -> os.stat_result: ... 
- async def mkdir(self, mode: int = 0o777, parents: bool = False, exist_ok: bool = False) -> None: ... - - if sys.platform != "win32": - async def owner(self) -> str: ... - async def group(self) -> str: ... - async def is_mount(self) -> bool: ... - if sys.version_info >= (3, 9): - async def readlink(self) -> Path: ... - async def rename(self, target: StrPath) -> Path: ... - async def replace(self, target: StrPath) -> Path: ... - async def resolve(self, strict: bool = False) -> Path: ... - async def rglob(self, pattern: str) -> Iterable[Path]: ... - async def rmdir(self) -> None: ... - async def symlink_to(self, target: StrPath, target_is_directory: bool = False) -> None: ... - if sys.version_info >= (3, 10): - async def hardlink_to(self, target: str | pathlib.Path) -> None: ... - async def touch(self, mode: int = 0o666, exist_ok: bool = True) -> None: ... - async def unlink(self, missing_ok: bool = False) -> None: ... - @classmethod - async def home(self) -> Path: ... - async def absolute(self) -> Path: ... - async def expanduser(self) -> Path: ... - async def read_bytes(self) -> bytes: ... - async def read_text(self, encoding: str | None = None, errors: str | None = None) -> str: ... - async def samefile(self, other_path: bytes | int | StrPath) -> bool: ... - async def write_bytes(self, data: bytes) -> int: ... - - if sys.version_info >= (3, 10): - async def write_text( - self, data: str, - encoding: str | None = None, - errors: str | None = None, - newline: str | None = None, - ) -> int: ... - else: - async def write_text( - self, data: str, - encoding: str | None = None, - errors: str | None = None, - ) -> int: ... - - if sys.version_info < (3, 12): - async def link_to(self, target: StrPath | bytes) -> None: ... - if sys.version_info >= (3, 12): - async def is_junction(self) -> bool: ... - walk: Any # TODO - async def with_segments(self, *pathsegments: StrPath) -> Path: ... - - -Path.iterdir.__doc__ = """ - Like :meth:`~pathlib.Path.iterdir`, but async. - - This is an async method that returns a synchronous iterator, so you - use it like:: - - for subpath in await mypath.iterdir(): - ... - - Note that it actually loads the whole directory list into memory - immediately, during the initial call. (See `issue #501 - <https://github.com/python-trio/trio/issues/501>`__ for discussion.) - -""" - -if sys.version_info < (3, 12): - # Since we synthesise methods from the stdlib, this automatically will - # have deprecation warnings, and disappear entirely in 3.12+. - Path.link_to.__doc__ = """ - Like Python 3.8-3.11's :meth:`~pathlib.Path.link_to`, but async. - - :deprecated: This method was deprecated in Python 3.10 and entirely \ - removed in 3.12. Use :meth:`hardlink_to` instead which has \ - a more meaningful parameter order. -""" - -# The value of Path.absolute.__doc__ makes a reference to -# :meth:~pathlib.Path.absolute, which does not exist. Removing this makes more -# sense than inventing our own special docstring for this.
-del Path.absolute.__doc__ - -# TODO: This is likely not supported by all the static tools out there, see discussion in -# https://github.com/python-trio/trio/pull/2631#discussion_r1185612528 -os.PathLike.register(Path) diff --git a/trio/_tests/check_type_completeness.py b/trio/_tests/check_type_completeness.py deleted file mode 100755 index 6401b21002..0000000000 --- a/trio/_tests/check_type_completeness.py +++ /dev/null @@ -1,191 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -# this file is not run as part of the tests, instead it's run standalone from check.sh -import argparse -import json -import subprocess -import sys -from collections.abc import Mapping -from pathlib import Path - -# the result file is not marked in MANIFEST.in so it's not included in the package -failed = False - - -def get_result_file_name(platform: str) -> Path: - return Path(__file__).parent / f"verify_types_{platform.lower()}.json" - - -# TODO: consider checking manually without `--ignoreexternal`, and/or -# removing it from the below call later on. -def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]: - return subprocess.run( - [ - "pyright", - # Specify a platform and version to keep imported modules consistent. - f"--pythonplatform={platform}", - "--pythonversion=3.8", - "--verifytypes=trio", - "--outputjson", - "--ignoreexternal", - ], - capture_output=True, - ) - - -def check_less_than( - key: str, - current_dict: Mapping[str, int | float], - last_dict: Mapping[str, int | float], - /, - invert: bool = False, -) -> None: - global failed - current = current_dict[key] - last = last_dict[key] - assert isinstance(current, (float, int)) - assert isinstance(last, (float, int)) - if current == last: - return - if (current > last) ^ invert: - failed = True - print("ERROR: ", end="") - strcurrent = f"{current:.4}" if isinstance(current, float) else str(current) - strlast = f"{last:.4}" if isinstance(last, float) else str(last) - print( - f"{key} has gone {'down' if current < last else 'up'} from {strlast} to {strcurrent}" - ) - - -def check_zero(key: str, current_dict: Mapping[str, int | float]) -> None: - global failed - if current_dict[key] != 0: - failed = True - print(f"ERROR: {key} is {current_dict[key]}") - - -def check_type(args: argparse.Namespace, platform: str) -> int: - print("*" * 20, "\nChecking type completeness hasn't gone down...") - - res = run_pyright(platform) - current_result = json.loads(res.stdout) - py_typed_file: Path | None = None - - # check if py.typed file was missing - if ( - current_result["generalDiagnostics"] - and current_result["generalDiagnostics"][0]["message"] - == "No py.typed file found" - ): - print("creating py.typed") - py_typed_file = ( - Path(current_result["typeCompleteness"]["packageRootDirectory"]) - / "py.typed" - ) - py_typed_file.write_text("") - - res = run_pyright(platform) - current_result = json.loads(res.stdout) - - if res.stderr: - print(res.stderr) - - last_result = json.loads(get_result_file_name(platform).read_text()) - - for key in "errorCount", "warningCount", "informationCount": - check_zero(key, current_result["summary"]) - - for key, invert in ( - ("missingFunctionDocStringCount", False), - ("missingClassDocStringCount", False), - ("missingDefaultParamCount", False), - ("completenessScore", True), - ): - check_less_than( - key, - current_result["typeCompleteness"], - last_result["typeCompleteness"], - invert=invert, - ) - - for key, invert in ( - ("withUnknownType", False), - ("withAmbiguousType", False), - ("withKnownType", True), - ): - check_less_than( - key, - current_result["typeCompleteness"]["exportedSymbolCounts"],
last_result["typeCompleteness"]["exportedSymbolCounts"], - invert=invert, - ) - - if args.overwrite_file: - print("Overwriting file") - - # don't care about differences in time taken - del current_result["time"] - del current_result["summary"]["timeInSec"] - - # don't fail on version diff so pyright updates can be automerged - del current_result["version"] - - for key in ( - # don't save path (because that varies between machines) - "moduleRootDirectory", - "packageRootDirectory", - "pyTypedPath", - ): - del current_result["typeCompleteness"][key] - - # prune the symbols to only be the name of the symbols with - # errors, instead of saving a huge file. - new_symbols: list[dict[str, str]] = [] - for symbol in current_result["typeCompleteness"]["symbols"]: - if symbol["diagnostics"]: - # function name + message should be enough context for people! - new_symbols.extend( - {"name": symbol["name"], "message": diagnostic["message"]} - for diagnostic in symbol["diagnostics"] - ) - continue - - # Ensure order of arrays does not affect result. - new_symbols.sort(key=lambda module: module.get("name", "")) - current_result["generalDiagnostics"].sort() - current_result["typeCompleteness"]["modules"].sort( - key=lambda module: module.get("name", "") - ) - - del current_result["typeCompleteness"]["symbols"] - current_result["typeCompleteness"]["diagnostics"] = new_symbols - - with open(get_result_file_name(platform), "w") as file: - json.dump(current_result, file, sort_keys=True, indent=2) - # add newline at end of file so it's easier to manually modify - file.write("\n") - - if py_typed_file is not None: - print("deleting py.typed") - py_typed_file.unlink() - - print("*" * 20) - - return int(failed) - - -def main(args: argparse.Namespace) -> int: - res = 0 - for platform in "Linux", "Windows", "Darwin": - res += check_type(args, platform) - return res - - -parser = argparse.ArgumentParser() -parser.add_argument("--overwrite-file", action="store_true", default=False) -parser.add_argument("--full-diagnostics-file", type=Path, default=None) -args = parser.parse_args() - -assert __name__ == "__main__", "This script should be run standalone" -sys.exit(main(args)) diff --git a/trio/_tests/test_fakenet.py b/trio/_tests/test_fakenet.py deleted file mode 100644 index d250a105a3..0000000000 --- a/trio/_tests/test_fakenet.py +++ /dev/null @@ -1,51 +0,0 @@ -import errno - -import pytest - -import trio -from trio.testing._fake_net import FakeNet - - -def fn() -> FakeNet: - fn = FakeNet() - fn.enable() - return fn - - -async def test_basic_udp() -> None: - fn() - s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - - await s1.bind(("127.0.0.1", 0)) - ip, port = s1.getsockname() - assert ip == "127.0.0.1" - assert port != 0 - - with pytest.raises(OSError) as exc: # Cannot rebind. 
- await s1.bind(("192.0.2.1", 0)) - assert exc.value.errno == errno.EINVAL - - await s2.sendto(b"xyz", s1.getsockname()) - data, addr = await s1.recvfrom(10) - assert data == b"xyz" - assert addr == s2.getsockname() - await s1.sendto(b"abc", s2.getsockname()) - data, addr = await s2.recvfrom(10) - assert data == b"abc" - assert addr == s1.getsockname() - - -async def test_msg_trunc() -> None: - fn() - s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM) - await s1.bind(("127.0.0.1", 0)) - await s2.sendto(b"xyz", s1.getsockname()) - data, addr = await s1.recvfrom(10) - - -async def test_basic_tcp() -> None: - fn() - with pytest.raises(NotImplementedError): - trio.socket.socket() diff --git a/trio/_tests/verify_types_darwin.json b/trio/_tests/verify_types_darwin.json deleted file mode 100644 index 26697a9512..0000000000 --- a/trio/_tests/verify_types_darwin.json +++ /dev/null @@ -1,84 +0,0 @@ -{ - "generalDiagnostics": [], - "summary": { - "errorCount": 0, - "filesAnalyzed": 8, - "informationCount": 0, - "warningCount": 0 - }, - "typeCompleteness": { - "completenessScore": 1, - "diagnostics": [ - { - "message": "No docstring found for function \"trio.lowlevel.current_kqueue\"", - "name": "trio.lowlevel.current_kqueue" - }, - { - "message": "No docstring found for function \"trio.lowlevel.monitor_kevent\"", - "name": "trio.lowlevel.monitor_kevent" - }, - { - "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", - "name": "trio.lowlevel.notify_closing" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_kevent\"", - "name": "trio.lowlevel.wait_kevent" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", - "name": "trio.lowlevel.wait_readable" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", - "name": "trio.lowlevel.wait_writable" - }, - { - "message": "No docstring found for class \"trio.tests.TestsDeprecationWrapper\"", - "name": "trio.tests.TestsDeprecationWrapper" - } - ], - "exportedSymbolCounts": { - "withAmbiguousType": 0, - "withKnownType": 632, - "withUnknownType": 0 - }, - "ignoreUnknownTypesFromImports": true, - "missingClassDocStringCount": 1, - "missingDefaultParamCount": 0, - "missingFunctionDocStringCount": 6, - "moduleName": "trio", - "modules": [ - { - "name": "trio" - }, - { - "name": "trio.abc" - }, - { - "name": "trio.from_thread" - }, - { - "name": "trio.lowlevel" - }, - { - "name": "trio.socket" - }, - { - "name": "trio.testing" - }, - { - "name": "trio.tests" - }, - { - "name": "trio.to_thread" - } - ], - "otherSymbolCounts": { - "withAmbiguousType": 0, - "withKnownType": 699, - "withUnknownType": 0 - }, - "packageName": "trio" - } -} diff --git a/trio/_tests/verify_types_linux.json b/trio/_tests/verify_types_linux.json deleted file mode 100644 index 66edc50b6b..0000000000 --- a/trio/_tests/verify_types_linux.json +++ /dev/null @@ -1,72 +0,0 @@ -{ - "generalDiagnostics": [], - "summary": { - "errorCount": 0, - "filesAnalyzed": 8, - "informationCount": 0, - "warningCount": 0 - }, - "typeCompleteness": { - "completenessScore": 1, - "diagnostics": [ - { - "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", - "name": "trio.lowlevel.notify_closing" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", - "name": "trio.lowlevel.wait_readable" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", - "name": 
"trio.lowlevel.wait_writable" - }, - { - "message": "No docstring found for class \"trio.tests.TestsDeprecationWrapper\"", - "name": "trio.tests.TestsDeprecationWrapper" - } - ], - "exportedSymbolCounts": { - "withAmbiguousType": 0, - "withKnownType": 629, - "withUnknownType": 0 - }, - "ignoreUnknownTypesFromImports": true, - "missingClassDocStringCount": 1, - "missingDefaultParamCount": 0, - "missingFunctionDocStringCount": 3, - "moduleName": "trio", - "modules": [ - { - "name": "trio" - }, - { - "name": "trio.abc" - }, - { - "name": "trio.from_thread" - }, - { - "name": "trio.lowlevel" - }, - { - "name": "trio.socket" - }, - { - "name": "trio.testing" - }, - { - "name": "trio.tests" - }, - { - "name": "trio.to_thread" - } - ], - "otherSymbolCounts": { - "withAmbiguousType": 0, - "withKnownType": 699, - "withUnknownType": 0 - }, - "packageName": "trio" - } -} diff --git a/trio/_tests/verify_types_windows.json b/trio/_tests/verify_types_windows.json deleted file mode 100644 index 4868da895c..0000000000 --- a/trio/_tests/verify_types_windows.json +++ /dev/null @@ -1,108 +0,0 @@ -{ - "generalDiagnostics": [], - "summary": { - "errorCount": 0, - "filesAnalyzed": 8, - "informationCount": 0, - "warningCount": 0 - }, - "typeCompleteness": { - "completenessScore": 1, - "diagnostics": [ - { - "message": "No docstring found for function \"trio.lowlevel.current_iocp\"", - "name": "trio.lowlevel.current_iocp" - }, - { - "message": "No docstring found for function \"trio.lowlevel.monitor_completion_key\"", - "name": "trio.lowlevel.monitor_completion_key" - }, - { - "message": "No docstring found for function \"trio.lowlevel.notify_closing\"", - "name": "trio.lowlevel.notify_closing" - }, - { - "message": "No docstring found for function \"trio.lowlevel.open_process\"", - "name": "trio.lowlevel.open_process" - }, - { - "message": "No docstring found for function \"trio.lowlevel.readinto_overlapped\"", - "name": "trio.lowlevel.readinto_overlapped" - }, - { - "message": "No docstring found for function \"trio.lowlevel.register_with_iocp\"", - "name": "trio.lowlevel.register_with_iocp" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_overlapped\"", - "name": "trio.lowlevel.wait_overlapped" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_readable\"", - "name": "trio.lowlevel.wait_readable" - }, - { - "message": "No docstring found for function \"trio.lowlevel.wait_writable\"", - "name": "trio.lowlevel.wait_writable" - }, - { - "message": "No docstring found for function \"trio.lowlevel.write_overlapped\"", - "name": "trio.lowlevel.write_overlapped" - }, - { - "message": "No docstring found for function \"trio.run_process\"", - "name": "trio.run_process" - }, - { - "message": "No docstring found for function \"trio.socket.fromshare\"", - "name": "trio.socket.fromshare" - }, - { - "message": "No docstring found for class \"trio.tests.TestsDeprecationWrapper\"", - "name": "trio.tests.TestsDeprecationWrapper" - } - ], - "exportedSymbolCounts": { - "withAmbiguousType": 0, - "withKnownType": 632, - "withUnknownType": 0 - }, - "ignoreUnknownTypesFromImports": true, - "missingClassDocStringCount": 1, - "missingDefaultParamCount": 0, - "missingFunctionDocStringCount": 12, - "moduleName": "trio", - "modules": [ - { - "name": "trio" - }, - { - "name": "trio.abc" - }, - { - "name": "trio.from_thread" - }, - { - "name": "trio.lowlevel" - }, - { - "name": "trio.socket" - }, - { - "name": "trio.testing" - }, - { - "name": "trio.tests" - }, - { - "name": 
"trio.to_thread" - } - ], - "otherSymbolCounts": { - "withAmbiguousType": 0, - "withKnownType": 691, - "withUnknownType": 0 - }, - "packageName": "trio" - } -} diff --git a/trio/_version.py b/trio/_version.py deleted file mode 100644 index 65242863a9..0000000000 --- a/trio/_version.py +++ /dev/null @@ -1,3 +0,0 @@ -# This file is imported from __init__.py and exec'd from setup.py - -__version__ = "0.22.2+dev" diff --git a/trio/tests.py b/trio/tests.py deleted file mode 100644 index 4ffb583a3a..0000000000 --- a/trio/tests.py +++ /dev/null @@ -1,38 +0,0 @@ -import importlib -import sys -from typing import Any - -from . import _tests -from ._deprecate import warn_deprecated - -warn_deprecated( - "trio.tests", - "0.22.1", - instead="trio._tests", - issue=274, -) - - -# This won't give deprecation warning on import, but will give a warning on use of any -# attribute in tests, and static analysis tools will also not see any content inside. -class TestsDeprecationWrapper: - __name__ = "trio.tests" - - def __getattr__(self, attr: str) -> Any: - warn_deprecated( - f"trio.tests.{attr}", - "0.22.1", - instead=f"trio._tests.{attr}", - issue=274, - ) - - # needed to access e.g. trio._tests.tools, although pytest doesn't need it - if not hasattr(_tests, attr): # pragma: no cover - importlib.import_module(f"trio._tests.{attr}", "trio._tests") - return attr - - return getattr(_tests, attr) - - -# https://stackoverflow.com/questions/2447353/getattr-on-a-module -sys.modules[__name__] = TestsDeprecationWrapper() # type: ignore[assignment]