diff --git a/.circleci/config.yml b/.circleci/config.yml index befa1f3d..ade95aa4 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,30 +2,46 @@ version: 2 jobs: build_docs: docker: - - image: circleci/python:3.8.5-buster + - image: cimg/base:current-22.04 steps: # Get our data and merge with upstream - checkout - - run: echo $(git log -1 --pretty=%B) | tee gitlog.txt - - run: echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt - - run: sudo apt update - - run: sudo apt install libglu1-mesa ffmpeg - run: - command: | - if [[ $(cat merge.txt) != "" ]]; then - echo "Merging $(cat merge.txt)"; - git pull --ff-only origin "refs/pull/$(cat merge.txt)/merge"; - fi - - run: echo "export DISPLAY=:99" >> $BASH_ENV - - run: echo "export _EXPYFUN_SILENT=true" >> $BASH_ENV - - run: echo "export PATH=~/.local/bin:$PATH" >> $BASH_ENV - - run: echo "export SOUND_CARD_BACKEND=pyglet >> $BASH_ENV" # rtmixer needs pulse, which is a huge pain to get running on CircleCI - - run: /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset; - - run: pip install --quiet --upgrade --user pip - - run: pip install --quiet --upgrade --user numpy scipy matplotlib sphinx pillow pandas h5py mne pyglet psutil sphinx_bootstrap_theme sphinx_fontawesome numpydoc https://api.github.com/repos/sphinx-gallery/sphinx-gallery/zipball/master + name: Merge + command: | + set -eo pipefail + echo $(git log -1 --pretty=%B) | tee gitlog.txt + echo ${CI_PULL_REQUEST//*pull\//} | tee merge.txt + if [[ $(cat merge.txt) != "" ]]; then + echo "Merging $(cat merge.txt)"; + git pull --ff-only origin "refs/pull/$(cat merge.txt)/merge"; + fi + - run: + name: Prep env + command: | + set -eo pipefail + echo "set -eo pipefail" >> $BASH_ENV + sudo apt update + sudo apt install libglu1-mesa python3.10-venv python3-venv libxft2 ffmpeg ffmpeg xvfb + /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset + python3.10 -m venv ~/python_env + echo "export PATH=~/.local/bin:$PATH" >> $BASH_ENV + echo "export SOUND_CARD_BACKEND=pyglet >> $BASH_ENV" # rtmixer needs pulse, which is a huge pain to get running on CircleCI + echo "export OPENBLAS_NUM_THREADS=4" >> $BASH_ENV + echo "export XDG_RUNTIME_DIR=/tmp/runtime-circleci" >> $BASH_ENV + echo "export PATH=~/.local/bin/:$PATH" >> $BASH_ENV + echo "export DISPLAY=:99" >> $BASH_ENV + echo "export _EXPYFUN_SILENT=true" >> $BASH_ENV + echo "source ~/python_env/bin/activate" >> $BASH_ENV + mkdir -p ~/.local/bin + ln -s ~/python_env/bin/python ~/.local/bin/python + echo "BASH_ENV:" + cat $BASH_ENV + - run: pip install --quiet --upgrade pip setuptools wheel + - run: pip install --quiet --upgrade numpy scipy matplotlib sphinx pandas h5py mne "pyglet<2.0" psutil pydata-sphinx-theme numpydoc git+https://github.com/sphinx-gallery/sphinx-gallery + - run: python -m pip install -ve . - run: python -c "import mne; mne.sys_info()" - run: python -c "import pyglet; print(pyglet.version)" - - run: python setup.py develop --user - run: cd doc && make html - store_artifacts: @@ -44,7 +60,7 @@ jobs: steps: - add_ssh_keys: fingerprints: - - d4:4f:25:af:ed:5f:61:01:dc:b6:3a:9e:b5:d6:8d:d1 + - "25:b7:f2:bf:d7:38:6d:b6:c7:78:41:05:01:f8:41:7b" - attach_workspace: at: /tmp/_build - run: diff --git a/.coveragerc b/.coveragerc index 7978105a..9a1bf336 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,4 +5,3 @@ include = */expyfun/* omit = */setup.py */expyfun/codeblocks/* - */expyfun/_externals/* diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 00000000..32db017b --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1 @@ +c7c1b18440968e2def388dff25118e13fe3c3b9a # ruff format diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..d57929b9 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" + groups: + actions: + patterns: + - "*" diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 00000000..9d1e0987 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,5 @@ +changelog: + exclude: + authors: + - dependabot + - pre-commit-ci diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index e0294320..1026bc29 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -1,12 +1,15 @@ -on: [status] +on: [status] # yamllint disable-line rule:truthy jobs: circleci_artifacts_redirector_job: + if: "${{ startsWith(github.event.context, 'ci/circleci: build_docs') }}" runs-on: ubuntu-20.04 name: Run CircleCI artifacts redirector steps: - name: GitHub Action step - uses: larsoner/circleci-artifacts-redirector-action@master + uses: scientific-python/circleci-artifacts-redirector-action@master with: repo-token: ${{ secrets.GITHUB_TOKEN }} + api-token: ${{ secrets.CIRCLECI_TOKEN }} artifact-path: 0/html/index.html circleci-jobs: build_docs + job-title: Check the rendered docs here! diff --git a/.github/workflows/codespell_and_flake.yml b/.github/workflows/codespell_and_flake.yml deleted file mode 100644 index a96ab6a0..00000000 --- a/.github/workflows/codespell_and_flake.yml +++ /dev/null @@ -1,41 +0,0 @@ -name: 'codespell_and_flake' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - style: - runs-on: ubuntu-20.04 - env: - CODESPELL_DIRS: 'expyfun/ doc/ examples/' - CODESPELL_SKIPS: '*.log,*.doctree,*.pickle,*.png,*.js,*.html,*.orig' - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 - with: - python-version: '3.9' - architecture: 'x64' - - run: | - python -m pip install --upgrade pip setuptools wheel - python -m pip install flake8 pydocstyle check-manifest numpy - name: 'Install dependencies' - - uses: rbialon/flake8-annotations@v1 - name: 'Setup flake8 annotations' - - run: make flake - - run: make docstyle - - run: make check-manifest - - uses: GuillaumeFavelier/actions-codespell@feat/quiet_level - with: - path: ${{ env.CODESPELL_DIRS }} - skip: ${{ env.CODESPELL_SKIPS }} - quiet_level: '3' - builtin: 'clear,rare,informal,names' - ignore_words_file: 'ignore_words.txt' - name: 'make codespell-error' diff --git a/.github/workflows/compat_old.yml b/.github/workflows/compat_old.yml deleted file mode 100644 index 66817b24..00000000 --- a/.github/workflows/compat_old.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: 'compat' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - job: - name: conda ${{ matrix.python }} - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash -el {0} - strategy: - matrix: - python: ['3.7'] - env: - DISPLAY: ':99.0' - steps: - - uses: actions/checkout@v2 - name: Checkout - - run: /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset - name: Start Xvfb - - run: sudo apt update -q && sudo apt install -q libglu1-mesa - name: Install system dependencies - - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: 'test' - python-version: ${{ matrix.python }} - environment-file: 'environment_test.yml' - name: 'Setup conda' - - run: | - set -e - conda remove -n test pandas h5py - pip install sounddevice rtmixer "pyglet<1.4" - name: Dependencies - - run: git clone --depth=1 git://github.com/LABSN/sound-ci-helpers.git && sound-ci-helpers/auto.sh - name: Get sound working - - run: python -m sounddevice - name: List sound devices - - run: python -c "import pyglet; print(pyglet.version)" - name: Print Pyglet version - - run: python -c "import matplotlib.pyplot as plt" - name: Make sure matplotlib works - - run: python setup.py develop - name: Install - - run: pytest --tb=short --cov=expyfun --cov-report=xml expyfun - name: Pytest - - uses: codecov/codecov-action@v1 - if: success() - name: 'Codecov' diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml deleted file mode 100644 index d5cbb604..00000000 --- a/.github/workflows/linux.yml +++ /dev/null @@ -1,58 +0,0 @@ -name: 'linux' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - job: - name: pip ${{ matrix.python }} - runs-on: ubuntu-20.04 - defaults: - run: - shell: bash -el {0} - strategy: - matrix: - python: ['3.10'] - env: - DISPLAY: ':99.0' - _EXPYFUN_SILENT: 'true' - SOUND_CARD_BACKEND: 'pyglet' - steps: - - uses: actions/checkout@v2 - name: Checkout - - run: /sbin/start-stop-daemon --start --quiet --pidfile /tmp/custom_xvfb_99.pid --make-pidfile --background --exec /usr/bin/Xvfb -- :99 -screen 0 1400x900x24 -ac +extension GLX +render -noreset - name: Start Xvfb - - run: sudo apt update -q && sudo apt install -q libavutil56 libavcodec58 libavformat58 libswscale5 libglu1-mesa gstreamer1.0-alsa gstreamer1.0-libav python3-gst-1.0 - name: Install system dependencies - - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.PYTHON_VERSION }} - name: 'Setup python' - - run: pip install --upgrade pip setuptools wheel - name: Upgrade pip - - run: pip install --upgrade sounddevice rtmixer "pyglet<1.6" pyglet_ffmpeg scipy matplotlib pandas h5py coverage mne numpydoc pytest pytest-cov pytest-timeout pillow joblib codecov - name: Dependencies - - run: git clone --depth=1 git://github.com/LABSN/sound-ci-helpers.git && sound-ci-helpers/auto.sh - name: Get sound working - - run: python -m sounddevice - name: List sound devices - - run: python -c "import pyglet; print(pyglet.version)" - name: Print Pyglet version - - run: python -c "import matplotlib.pyplot as plt" - name: Make sure matplotlib works - - run: pip install -ve . - name: Install - - run: python -c "import expyfun; expyfun._utils._has_video(raise_error=True)" - name: Check video - - run: pytest --tb=short --cov=expyfun --cov-report=xml expyfun - name: Pytest - - uses: codecov/codecov-action@v1 - if: success() - name: Codecov diff --git a/.github/workflows/macos_conda.yml b/.github/workflows/macos_conda.yml deleted file mode 100644 index 16a8039a..00000000 --- a/.github/workflows/macos_conda.yml +++ /dev/null @@ -1,42 +0,0 @@ -name: 'macos' -concurrency: - group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} - cancel-in-progress: true -on: - push: - branches: - - '*' - pull_request: - branches: - - '*' - -jobs: - job: - name: conda ${{ matrix.python }} - runs-on: macos-latest - strategy: - matrix: - python: ['3.10'] - defaults: - run: - shell: bash -el {0} - steps: - - uses: actions/checkout@v2 - - uses: conda-incubator/setup-miniconda@v2 - with: - activate-environment: 'test' - python-version: ${{ matrix.python }} - environment-file: 'environment_test.yml' - name: 'Setup conda' - - run: pip install sounddevice rtmixer "pyglet<1.6" - - run: git clone --depth=1 git://github.com/LABSN/sound-ci-helpers.git && sound-ci-helpers/auto.sh - name: Get sound working - - run: python -m sounddevice - - run: python -c "import pyglet; print(pyglet.version)" - - run: python -c "import matplotlib.pyplot as plt" - - run: pip install -ve . - - run: python -c "import expyfun; expyfun._utils._has_video(raise_error=True)" - - run: pytest --tb=short --cov=expyfun --cov-report=xml expyfun - - uses: codecov/codecov-action@v1 - if: success() - name: 'Codecov' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..a684dc71 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,107 @@ +name: 'tests' +concurrency: + group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} + cancel-in-progress: true +on: # yamllint disable-line rule:truthy + push: + branches: + - '*' + pull_request: + branches: + - '*' + +jobs: + job: + name: ${{ matrix.os }} ${{ matrix.kind }} + continue-on-error: true + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash -el {0} + strategy: + matrix: + include: + # 24.04 works except for the video test, even though it works locally on 24.10 + - os: ubuntu-22.04 + kind: pip + python: '3.12' + # ARM64 will probably need to wait until + # - os: 'macos-latest' # arm64 + # kind: 'conda' + # python: '3.12' + - os: 'macos-13' # intel + kind: 'conda' + python: '3.12' + # TODO: There is a bug on Python 3.12 on Windows :( + - os: 'windows-latest' + kind: 'pip' + python: '3.11' + - os: 'ubuntu-20.04' + kind: 'old' + python: '3.8' + steps: + - uses: actions/checkout@v4 + - uses: LABSN/sound-ci-helpers@v1 + - uses: pyvista/setup-headless-display-action@main + with: + qt: true + pyvista: false + # Use -dev here just to get whichever version is right (e.g., 22.04 has a different version from 24.04) + - run: sudo apt install -q libavutil-dev libavcodec-dev libavformat-dev libswscale-dev libglu1-mesa gstreamer1.0-alsa gstreamer1.0-libav + if: ${{ startsWith(matrix.os, 'ubuntu') }} + - run: powershell tools/get_video.ps1 + if: ${{ startsWith(matrix.os, 'windows') }} + - run: | + set -xeo pipefail + if [[ "${{ runner.os }}" == "Windows" ]]; then + echo "Setting env vars for Windows" + echo "AZURE_CI_WINDOWS=true" >> $GITHUB_ENV + echo "SOUND_CARD_BACKEND=rtmixer" >> $GITHUB_ENV + echo "SOUND_CARD_NAME=Speakers" >> $GITHUB_ENV + echo "SOUND_CARD_FS=48000" >> $GITHUB_ENV + echo "SOUND_CARD_API=Windows WDM-KS" >> $GITHUB_ENV + elif [[ "${{ runner.os }}" == "Linux" ]]; then + echo "Setting env vars for Linux" + echo "_EXPYFUN_SILENT=true" >> $GITHUB_ENV + echo "SOUND_CARD_BACKEND=pyglet" >> $GITHUB_ENV + elif [[ "${{ runner.os }}" == "macOS" ]]; then + echo "Setting env vars for macOS" + fi + name: Set env vars + - uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python }} + if: matrix.kind != 'conda' + - uses: mamba-org/setup-micromamba@v1 + with: + environment-file: 'environment_test.yml' + create-args: python=${{ matrix.python }} + init-shell: bash + if: matrix.kind == 'conda' + # Pyglet pin: https://github.com/pyglet/pyglet/issues/1089 (and need OpenGL2 compat for Pyglet>=2, too) + - run: python -m pip install --upgrade pip setuptools wheel sounddevice "pyglet<1.5.28" + - run: python -m pip install --upgrade --only-binary="rtmixer,scipy,matplotlib,pandas,numpy" rtmixer pyglet-ffmpeg scipy matplotlib pandas h5py mne numpydoc pillow joblib + if: matrix.kind == 'pip' + # arm64 has issues with rtmixer / PortAudio + - run: python -m pip install --only-binary="rtmixer" rtmixer + if: matrix.kind == 'conda' && matrix.os != 'macos-latest' + - run: python -m pip install --only-binary="rtmixer,numpy,scipy,matplotlib" rtmixer "pyglet<1.4" numpy scipy matplotlib "pillow<8" + if: matrix.kind == 'old' + - run: python -m pip install tdtpy + if: startsWith(matrix.os, 'windows') + - run: python -m sounddevice + - run: | + set -o pipefail + python -m sounddevice | grep "[82] out" + name: Check that there is some output device + - run: python -c "import pyglet; print(pyglet.version)" + - run: python -c "import matplotlib.pyplot as plt" + - run: pip install -ve .[test] + # Video hangs on macOS arm64, not sure why + - run: python -c "import expyfun; expyfun._utils._has_video(raise_error=True)" + if: matrix.kind != 'old' && matrix.os != 'macos-latest' + - run: pytest expyfun --cov-report=xml --cov=expyfun + - uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + if: always() diff --git a/.gitignore b/.gitignore index 128a90ef..0961ef97 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ *.orig .vscode doc/generated +doc/sg_execution_times.rst .DS_Store # C extensions diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..3c5c3232 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,43 @@ +repos: + # Ruff mne + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.7 + hooks: + - id: ruff + name: ruff lint expyfun + args: ["--fix"] + files: ^expyfun/ + - id: ruff + name: ruff lint doc and examples + # D103: missing docstring in public function + # D400: docstring first line must end with period + args: ["--ignore=D103,D400", "--fix"] + files: ^doc/|^examples/ + - id: ruff-format + files: ^expyfun/|^doc/|^examples/ + + # Codespell + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + additional_dependencies: + - tomli + files: ^expyfun/|^doc/|^examples/ + types_or: [python, bib, rst, inc] + + # yamllint + - repo: https://github.com/adrienverge/yamllint.git + rev: v1.35.1 + hooks: + - id: yamllint + args: [--strict, -c, .yamllint.yml] + + # rstcheck + - repo: https://github.com/rstcheck/rstcheck.git + rev: v6.2.0 + hooks: + - id: rstcheck + additional_dependencies: + - tomli + files: ^doc/.*\.(rst|inc)$ diff --git a/.yamllint.yml b/.yamllint.yml new file mode 100644 index 00000000..f54915d4 --- /dev/null +++ b/.yamllint.yml @@ -0,0 +1,8 @@ +extends: default + +ignore: | + .github/workflows/codeql-analysis.yml + +rules: + line-length: disable + document-start: disable diff --git a/MANIFEST.in b/MANIFEST.in index e0e8ff96..7f7f4c2f 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -10,7 +10,7 @@ recursive-include expyfun/data * ### Exclude -exclude make +exclude tools exclude doc exclude .circleci exclude Makefile @@ -21,6 +21,6 @@ exclude .mailmap recursive-exclude expyfun *.pyc recursive-exclude doc * -recursive-exclude make * +recursive-exclude tools * recursive-exclude examples *.tab recursive-exclude .circleci * diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 4cabffdf..00000000 --- a/appveyor.yml +++ /dev/null @@ -1,39 +0,0 @@ -environment: - matrix: - - PYTHON: "C:\\Python37-x64" - PYTHON_ARCH: "64" - -platform: - -x64 - -install: - - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" - - "python --version" - - "pip install -q numpy scipy matplotlib coverage setuptools h5py pandas pytest pytest-cov pytest-timeout pytest-xdist codecov \"pyglet!=1.5.16\" mne tdtpy joblib numpydoc pillow" - - "python -c \"import mne; mne.sys_info()\"" - - "python -c \"import pyglet; print(pyglet.version)\"" - # Get a virtual sound card / VBAudioVACWDM device - - "git clone --depth 1 git://github.com/LABSN/sound-ci-helpers.git" - - "powershell sound-ci-helpers/windows/setup_sound.ps1" - - "pip install rtmixer" - - "python -m sounddevice" - # OpenGL (should provide a Gallium driver) - - "git clone --depth 1 git://github.com/pyvista/gl-ci-helpers.git" - - "powershell gl-ci-helpers/appveyor/install_opengl.ps1" - - "python -c \"import pyglet; r = pyglet.gl.gl_info.get_renderer(); print(r); assert 'gallium' in r.lower()\"" - # expyfun - - "powershell make/get_video.ps1" - - "python setup.py develop" - -build: false # Not a C# project, build stuff at the test step instead. - -test_script: - # Ensure that video works - - "python -c \"from ctypes import cdll; print(cdll.LoadLibrary('avcodec-58'))\"" - - "python -c \"from ctypes import cdll; print(cdll.LoadLibrary('avformat-58'))\"" - - "python -c \"import expyfun; assert expyfun._utils._has_video(raise_error=True)\"" - # Run the project tests - - "pytest -n 1 --tb=short --cov=expyfun expyfun" - -on_success: - - "codecov" diff --git a/azure-pipelines.yml b/azure-pipelines.yml deleted file mode 100644 index 0182c3ed..00000000 --- a/azure-pipelines.yml +++ /dev/null @@ -1,115 +0,0 @@ -trigger: - # start a new build for every push - batch: False - branches: - include: - - main - -stages: - -- stage: Check - jobs: - - job: Skip - pool: - vmImage: 'ubuntu-18.04' - variables: - DECODE_PERCENTS: 'false' - RET: 'true' - steps: - - bash: | - git_log=`git log --max-count=1 --skip=1 --pretty=format:"%s"` - echo "##vso[task.setvariable variable=log]$git_log" - - bash: echo "##vso[task.setvariable variable=RET]false" - condition: or(contains(variables.log, '[skip azp]'), contains(variables.log, '[azp skip]'), contains(variables.log, '[skip ci]'), contains(variables.log, '[ci skip]')) - - bash: echo "##vso[task.setvariable variable=start_main;isOutput=true]$RET" - name: result - -- stage: Main - condition: and(succeeded(), eq(dependencies.Check.outputs['Skip.result.start_main'], 'true')) - dependsOn: Check - variables: - AZURE_CI: 'true' - jobs: - - job: Windows - pool: - vmIMage: 'windows-latest' - variables: - MNE_LOGGING_LEVEL: 'warning' - MNE_FORCE_SERIAL: 'true' - PIP_DEPENDENCIES: 'codecov' - OPENBLAS_NUM_THREADS: 1 - AZURE_CI_WINDOWS: 'true' - SOUND_CARD_NAME: 'Speakers' - SOUND_CARD_FS: '44100' - SOUND_CARD_API: 'MME' - strategy: - maxParallel: 4 - matrix: - Python37: - PYTHON_VERSION: '3.7' - Python39: - PYTHON_VERSION: '3.9' - steps: - - task: UsePythonVersion@0 - inputs: - versionSpec: $(PYTHON_VERSION) - architecture: 'x64' - addToPath: true - - script: echo "##vso[task.prependpath]C:\Users\VssAdministrator\AppData\Roaming\Python\Python39\site-packages\pywin32_system32;" - displayName: Add local bin to PATH - condition: in(variables['PYTHON_VERSION'], '3.9') - - bash: | - set -e - pip install --user --upgrade pip setuptools wheel - pip install --user --upgrade numpy scipy matplotlib - if [[ "$PYTHON_VERSION" == "3.9" ]]; then - # Until https://github.com/pyglet/pyglet/pull/516 is reverted or fixed, we need to use an older one - pip install --user --upgrade https://github.com/pyglet/pyglet/zipball/pyglet-1.5-maintenance - else - pip install --user --upgrade "pyglet!=1.5.16" - fi - pip install --user -q coverage setuptools h5py pandas pytest pytest-cov pytest-timeout codecov pyglet-ffmpeg mne tdtpy joblib numpydoc pillow - python -c "import mne; mne.sys_info()" - python -c "import matplotlib.pyplot as plt" - python -c "import pyglet; print(pyglet.version)" - python -c "import tdt; print(tdt.__version__)" - displayName: 'Install pip dependencies' - - bash: | - set -e - git clone --depth 1 git://github.com/LABSN/sound-ci-helpers.git - sound-ci-helpers/auto.sh - pip install -q --user rtmixer - python -m sounddevice - displayName: 'Install rtmixer' - - bash: | - set -e - git clone --depth 1 git://github.com/pyvista/gl-ci-helpers.git - powershell gl-ci-helpers/appveyor/install_opengl.ps1 - python -c "import pyglet; r = pyglet.gl.gl_info.get_renderer(); print(r); assert 'gallium' in r.lower()" - displayName: 'Get OpenGL' - - powershell: | - powershell make/get_video.ps1 - displayName: 'Get video support' - - powershell: | - python -c "from ctypes import cdll; print(cdll.LoadLibrary('avcodec-58'))" - displayName: 'Check avcodec' - - powershell: | - python -c "import expyfun; expyfun._utils._has_video(raise_error=True)" - displayName: 'Check video support' - - bash: | - python setup.py develop - displayName: 'Install' - - bash: | - pytest --tb=short --cov=expyfun expyfun - displayName: 'Run tests' - - bash: | - codecov --root %BUILD_REPOSITORY_LOCALPATH% -t %CODECOV_TOKEN% - displayName: 'Codecov' - env: - CODECOV_TOKEN: $(CODECOV_TOKEN) - condition: always() - - task: PublishTestResults@2 - inputs: - testResultsFiles: 'junit-*.xml' - testRunTitle: 'Publish test results for Python $(python.version)' - condition: always() diff --git a/codecov.yml b/codecov.yml index a011f80f..408f379c 100644 --- a/codecov.yml +++ b/codecov.yml @@ -4,7 +4,7 @@ github_checks: # too noisy, even though "a" interactively disables them codecov: notify: - require_ci_to_pass: no + require_ci_to_pass: false coverage: status: diff --git a/doc/_static/font-awesome.css b/doc/_static/font-awesome.css deleted file mode 100644 index c1ecf734..00000000 --- a/doc/_static/font-awesome.css +++ /dev/null @@ -1,2337 +0,0 @@ -/*! - * Font Awesome 4.7.0 by @davegandy - http://fontawesome.io - @fontawesome - * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License) - */ -/* FONT PATH - * -------------------------- */ -@font-face { - font-family: 'FontAwesome'; - src: url('./fontawesome-webfont.eot?v=4.7.0'); - src: url('./fontawesome-webfont.eot?#iefix&v=4.7.0') format('embedded-opentype'), url('./fontawesome-webfont.woff2?v=4.7.0') format('woff2'), url('./fontawesome-webfont.woff?v=4.7.0') format('woff'), url('./fontawesome-webfont.ttf?v=4.7.0') format('truetype'); - font-weight: normal; - font-style: normal; -} -.fa { - display: inline-block; - font: normal normal normal 14px/1 FontAwesome; - font-size: inherit; - text-rendering: auto; - -webkit-font-smoothing: antialiased; - -moz-osx-font-smoothing: grayscale; -} -/* makes the font 33% larger relative to the icon container */ -.fa-lg { - font-size: 1.33333333em; - line-height: 0.75em; - vertical-align: -15%; -} -.fa-2x { - font-size: 2em; -} -.fa-3x { - font-size: 3em; -} -.fa-4x { - font-size: 4em; -} -.fa-5x { - font-size: 5em; -} -.fa-fw { - width: 1.28571429em; - text-align: center; -} -.fa-ul { - padding-left: 0; - margin-left: 2.14285714em; - list-style-type: none; -} -.fa-ul > li { - position: relative; -} -.fa-li { - position: absolute; - left: -2.14285714em; - width: 2.14285714em; - top: 0.14285714em; - text-align: center; -} -.fa-li.fa-lg { - left: -1.85714286em; -} -.fa-border { - padding: .2em .25em .15em; - border: solid 0.08em #eeeeee; - border-radius: .1em; -} -.fa-pull-left { - float: left; -} -.fa-pull-right { - float: right; -} -.fa.fa-pull-left { - margin-right: .3em; -} -.fa.fa-pull-right { - margin-left: .3em; -} -/* Deprecated as of 4.4.0 */ -.pull-right { - float: right; -} -.pull-left { - float: left; -} -.fa.pull-left { - margin-right: .3em; -} -.fa.pull-right { - margin-left: .3em; -} -.fa-spin { - -webkit-animation: fa-spin 2s infinite linear; - animation: fa-spin 2s infinite linear; -} -.fa-pulse { - -webkit-animation: fa-spin 1s infinite steps(8); - animation: fa-spin 1s infinite steps(8); -} -@-webkit-keyframes fa-spin { - 0% { - -webkit-transform: rotate(0deg); - transform: rotate(0deg); - } - 100% { - -webkit-transform: rotate(359deg); - transform: rotate(359deg); - } -} -@keyframes fa-spin { - 0% { - -webkit-transform: rotate(0deg); - transform: rotate(0deg); - } - 100% { - -webkit-transform: rotate(359deg); - transform: rotate(359deg); - } -} -.fa-rotate-90 { - -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=1)"; - -webkit-transform: rotate(90deg); - -ms-transform: rotate(90deg); - transform: rotate(90deg); -} -.fa-rotate-180 { - -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=2)"; - -webkit-transform: rotate(180deg); - -ms-transform: rotate(180deg); - transform: rotate(180deg); -} -.fa-rotate-270 { - -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=3)"; - -webkit-transform: rotate(270deg); - -ms-transform: rotate(270deg); - transform: rotate(270deg); -} -.fa-flip-horizontal { - -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=0, mirror=1)"; - -webkit-transform: scale(-1, 1); - -ms-transform: scale(-1, 1); - transform: scale(-1, 1); -} -.fa-flip-vertical { - -ms-filter: "progid:DXImageTransform.Microsoft.BasicImage(rotation=2, mirror=1)"; - -webkit-transform: scale(1, -1); - -ms-transform: scale(1, -1); - transform: scale(1, -1); -} -:root .fa-rotate-90, -:root .fa-rotate-180, -:root .fa-rotate-270, -:root .fa-flip-horizontal, -:root .fa-flip-vertical { - filter: none; -} -.fa-stack { - position: relative; - display: inline-block; - width: 2em; - height: 2em; - line-height: 2em; - vertical-align: middle; -} -.fa-stack-1x, -.fa-stack-2x { - position: absolute; - left: 0; - width: 100%; - text-align: center; -} -.fa-stack-1x { - line-height: inherit; -} -.fa-stack-2x { - font-size: 2em; -} -.fa-inverse { - color: #ffffff; -} -/* Font Awesome uses the Unicode Private Use Area (PUA) to ensure screen - readers do not read off random characters that represent icons */ -.fa-glass:before { - content: "\f000"; -} -.fa-music:before { - content: "\f001"; -} -.fa-search:before { - content: "\f002"; -} -.fa-envelope-o:before { - content: "\f003"; -} -.fa-heart:before { - content: "\f004"; -} -.fa-star:before { - content: "\f005"; -} -.fa-star-o:before { - content: "\f006"; -} -.fa-user:before { - content: "\f007"; -} -.fa-film:before { - content: "\f008"; -} -.fa-th-large:before { - content: "\f009"; -} -.fa-th:before { - content: "\f00a"; -} -.fa-th-list:before { - content: "\f00b"; -} -.fa-check:before { - content: "\f00c"; -} -.fa-remove:before, -.fa-close:before, -.fa-times:before { - content: "\f00d"; -} -.fa-search-plus:before { - content: "\f00e"; -} -.fa-search-minus:before { - content: "\f010"; -} -.fa-power-off:before { - content: "\f011"; -} -.fa-signal:before { - content: "\f012"; -} -.fa-gear:before, -.fa-cog:before { - content: "\f013"; -} -.fa-trash-o:before { - content: "\f014"; -} -.fa-home:before { - content: "\f015"; -} -.fa-file-o:before { - content: "\f016"; -} -.fa-clock-o:before { - content: "\f017"; -} -.fa-road:before { - content: "\f018"; -} -.fa-download:before { - content: "\f019"; -} -.fa-arrow-circle-o-down:before { - content: "\f01a"; -} -.fa-arrow-circle-o-up:before { - content: "\f01b"; -} -.fa-inbox:before { - content: "\f01c"; -} -.fa-play-circle-o:before { - content: "\f01d"; -} -.fa-rotate-right:before, -.fa-repeat:before { - content: "\f01e"; -} -.fa-refresh:before { - content: "\f021"; -} -.fa-list-alt:before { - content: "\f022"; -} -.fa-lock:before { - content: "\f023"; -} -.fa-flag:before { - content: "\f024"; -} -.fa-headphones:before { - content: "\f025"; -} -.fa-volume-off:before { - content: "\f026"; -} -.fa-volume-down:before { - content: "\f027"; -} -.fa-volume-up:before { - content: "\f028"; -} -.fa-qrcode:before { - content: "\f029"; -} -.fa-barcode:before { - content: "\f02a"; -} -.fa-tag:before { - content: "\f02b"; -} -.fa-tags:before { - content: "\f02c"; -} -.fa-book:before { - content: "\f02d"; -} -.fa-bookmark:before { - content: "\f02e"; -} -.fa-print:before { - content: "\f02f"; -} -.fa-camera:before { - content: "\f030"; -} -.fa-font:before { - content: "\f031"; -} -.fa-bold:before { - content: "\f032"; -} -.fa-italic:before { - content: "\f033"; -} -.fa-text-height:before { - content: "\f034"; -} -.fa-text-width:before { - content: "\f035"; -} -.fa-align-left:before { - content: "\f036"; -} -.fa-align-center:before { - content: "\f037"; -} -.fa-align-right:before { - content: "\f038"; -} -.fa-align-justify:before { - content: "\f039"; -} -.fa-list:before { - content: "\f03a"; -} -.fa-dedent:before, -.fa-outdent:before { - content: "\f03b"; -} -.fa-indent:before { - content: "\f03c"; -} -.fa-video-camera:before { - content: "\f03d"; -} -.fa-photo:before, -.fa-image:before, -.fa-picture-o:before { - content: "\f03e"; -} -.fa-pencil:before { - content: "\f040"; -} -.fa-map-marker:before { - content: "\f041"; -} -.fa-adjust:before { - content: "\f042"; -} -.fa-tint:before { - content: "\f043"; -} -.fa-edit:before, -.fa-pencil-square-o:before { - content: "\f044"; -} -.fa-share-square-o:before { - content: "\f045"; -} -.fa-check-square-o:before { - content: "\f046"; -} -.fa-arrows:before { - content: "\f047"; -} -.fa-step-backward:before { - content: "\f048"; -} -.fa-fast-backward:before { - content: "\f049"; -} -.fa-backward:before { - content: "\f04a"; -} -.fa-play:before { - content: "\f04b"; -} -.fa-pause:before { - content: "\f04c"; -} -.fa-stop:before { - content: "\f04d"; -} -.fa-forward:before { - content: "\f04e"; -} -.fa-fast-forward:before { - content: "\f050"; -} -.fa-step-forward:before { - content: "\f051"; -} -.fa-eject:before { - content: "\f052"; -} -.fa-chevron-left:before { - content: "\f053"; -} -.fa-chevron-right:before { - content: "\f054"; -} -.fa-plus-circle:before { - content: "\f055"; -} -.fa-minus-circle:before { - content: "\f056"; -} -.fa-times-circle:before { - content: "\f057"; -} -.fa-check-circle:before { - content: "\f058"; -} -.fa-question-circle:before { - content: "\f059"; -} -.fa-info-circle:before { - content: "\f05a"; -} -.fa-crosshairs:before { - content: "\f05b"; -} -.fa-times-circle-o:before { - content: "\f05c"; -} -.fa-check-circle-o:before { - content: "\f05d"; -} -.fa-ban:before { - content: "\f05e"; -} -.fa-arrow-left:before { - content: "\f060"; -} -.fa-arrow-right:before { - content: "\f061"; -} -.fa-arrow-up:before { - content: "\f062"; -} -.fa-arrow-down:before { - content: "\f063"; -} -.fa-mail-forward:before, -.fa-share:before { - content: "\f064"; -} -.fa-expand:before { - content: "\f065"; -} -.fa-compress:before { - content: "\f066"; -} -.fa-plus:before { - content: "\f067"; -} -.fa-minus:before { - content: "\f068"; -} -.fa-asterisk:before { - content: "\f069"; -} -.fa-exclamation-circle:before { - content: "\f06a"; -} -.fa-gift:before { - content: "\f06b"; -} -.fa-leaf:before { - content: "\f06c"; -} -.fa-fire:before { - content: "\f06d"; -} -.fa-eye:before { - content: "\f06e"; -} -.fa-eye-slash:before { - content: "\f070"; -} -.fa-warning:before, -.fa-exclamation-triangle:before { - content: "\f071"; -} -.fa-plane:before { - content: "\f072"; -} -.fa-calendar:before { - content: "\f073"; -} -.fa-random:before { - content: "\f074"; -} -.fa-comment:before { - content: "\f075"; -} -.fa-magnet:before { - content: "\f076"; -} -.fa-chevron-up:before { - content: "\f077"; -} -.fa-chevron-down:before { - content: "\f078"; -} -.fa-retweet:before { - content: "\f079"; -} -.fa-shopping-cart:before { - content: "\f07a"; -} -.fa-folder:before { - content: "\f07b"; -} -.fa-folder-open:before { - content: "\f07c"; -} -.fa-arrows-v:before { - content: "\f07d"; -} -.fa-arrows-h:before { - content: "\f07e"; -} -.fa-bar-chart-o:before, -.fa-bar-chart:before { - content: "\f080"; -} -.fa-twitter-square:before { - content: "\f081"; -} -.fa-facebook-square:before { - content: "\f082"; -} -.fa-camera-retro:before { - content: "\f083"; -} -.fa-key:before { - content: "\f084"; -} -.fa-gears:before, -.fa-cogs:before { - content: "\f085"; -} -.fa-comments:before { - content: "\f086"; -} -.fa-thumbs-o-up:before { - content: "\f087"; -} -.fa-thumbs-o-down:before { - content: "\f088"; -} -.fa-star-half:before { - content: "\f089"; -} -.fa-heart-o:before { - content: "\f08a"; -} -.fa-sign-out:before { - content: "\f08b"; -} -.fa-linkedin-square:before { - content: "\f08c"; -} -.fa-thumb-tack:before { - content: "\f08d"; -} -.fa-external-link:before { - content: "\f08e"; -} -.fa-sign-in:before { - content: "\f090"; -} -.fa-trophy:before { - content: "\f091"; -} -.fa-github-square:before { - content: "\f092"; -} -.fa-upload:before { - content: "\f093"; -} -.fa-lemon-o:before { - content: "\f094"; -} -.fa-phone:before { - content: "\f095"; -} -.fa-square-o:before { - content: "\f096"; -} -.fa-bookmark-o:before { - content: "\f097"; -} -.fa-phone-square:before { - content: "\f098"; -} -.fa-twitter:before { - content: "\f099"; -} -.fa-facebook-f:before, -.fa-facebook:before { - content: "\f09a"; -} -.fa-github:before { - content: "\f09b"; -} -.fa-unlock:before { - content: "\f09c"; -} -.fa-credit-card:before { - content: "\f09d"; -} -.fa-feed:before, -.fa-rss:before { - content: "\f09e"; -} -.fa-hdd-o:before { - content: "\f0a0"; -} -.fa-bullhorn:before { - content: "\f0a1"; -} -.fa-bell:before { - content: "\f0f3"; -} -.fa-certificate:before { - content: "\f0a3"; -} -.fa-hand-o-right:before { - content: "\f0a4"; -} -.fa-hand-o-left:before { - content: "\f0a5"; -} -.fa-hand-o-up:before { - content: "\f0a6"; -} -.fa-hand-o-down:before { - content: "\f0a7"; -} -.fa-arrow-circle-left:before { - content: "\f0a8"; -} -.fa-arrow-circle-right:before { - content: "\f0a9"; -} -.fa-arrow-circle-up:before { - content: "\f0aa"; -} -.fa-arrow-circle-down:before { - content: "\f0ab"; -} -.fa-globe:before { - content: "\f0ac"; -} -.fa-wrench:before { - content: "\f0ad"; -} -.fa-tasks:before { - content: "\f0ae"; -} -.fa-filter:before { - content: "\f0b0"; -} -.fa-briefcase:before { - content: "\f0b1"; -} -.fa-arrows-alt:before { - content: "\f0b2"; -} -.fa-group:before, -.fa-users:before { - content: "\f0c0"; -} -.fa-chain:before, -.fa-link:before { - content: "\f0c1"; -} -.fa-cloud:before { - content: "\f0c2"; -} -.fa-flask:before { - content: "\f0c3"; -} -.fa-cut:before, -.fa-scissors:before { - content: "\f0c4"; -} -.fa-copy:before, -.fa-files-o:before { - content: "\f0c5"; -} -.fa-paperclip:before { - content: "\f0c6"; -} -.fa-save:before, -.fa-floppy-o:before { - content: "\f0c7"; -} -.fa-square:before { - content: "\f0c8"; -} -.fa-navicon:before, -.fa-reorder:before, -.fa-bars:before { - content: "\f0c9"; -} -.fa-list-ul:before { - content: "\f0ca"; -} -.fa-list-ol:before { - content: "\f0cb"; -} -.fa-strikethrough:before { - content: "\f0cc"; -} -.fa-underline:before { - content: "\f0cd"; -} -.fa-table:before { - content: "\f0ce"; -} -.fa-magic:before { - content: "\f0d0"; -} -.fa-truck:before { - content: "\f0d1"; -} -.fa-pinterest:before { - content: "\f0d2"; -} -.fa-pinterest-square:before { - content: "\f0d3"; -} -.fa-google-plus-square:before { - content: "\f0d4"; -} -.fa-google-plus:before { - content: "\f0d5"; -} -.fa-money:before { - content: "\f0d6"; -} -.fa-caret-down:before { - content: "\f0d7"; -} -.fa-caret-up:before { - content: "\f0d8"; -} -.fa-caret-left:before { - content: "\f0d9"; -} -.fa-caret-right:before { - content: "\f0da"; -} -.fa-columns:before { - content: "\f0db"; -} -.fa-unsorted:before, -.fa-sort:before { - content: "\f0dc"; -} -.fa-sort-down:before, -.fa-sort-desc:before { - content: "\f0dd"; -} -.fa-sort-up:before, -.fa-sort-asc:before { - content: "\f0de"; -} -.fa-envelope:before { - content: "\f0e0"; -} -.fa-linkedin:before { - content: "\f0e1"; -} -.fa-rotate-left:before, -.fa-undo:before { - content: "\f0e2"; -} -.fa-legal:before, -.fa-gavel:before { - content: "\f0e3"; -} -.fa-dashboard:before, -.fa-tachometer:before { - content: "\f0e4"; -} -.fa-comment-o:before { - content: "\f0e5"; -} -.fa-comments-o:before { - content: "\f0e6"; -} -.fa-flash:before, -.fa-bolt:before { - content: "\f0e7"; -} -.fa-sitemap:before { - content: "\f0e8"; -} -.fa-umbrella:before { - content: "\f0e9"; -} -.fa-paste:before, -.fa-clipboard:before { - content: "\f0ea"; -} -.fa-lightbulb-o:before { - content: "\f0eb"; -} -.fa-exchange:before { - content: "\f0ec"; -} -.fa-cloud-download:before { - content: "\f0ed"; -} -.fa-cloud-upload:before { - content: "\f0ee"; -} -.fa-user-md:before { - content: "\f0f0"; -} -.fa-stethoscope:before { - content: "\f0f1"; -} -.fa-suitcase:before { - content: "\f0f2"; -} -.fa-bell-o:before { - content: "\f0a2"; -} -.fa-coffee:before { - content: "\f0f4"; -} -.fa-cutlery:before { - content: "\f0f5"; -} -.fa-file-text-o:before { - content: "\f0f6"; -} -.fa-building-o:before { - content: "\f0f7"; -} -.fa-hospital-o:before { - content: "\f0f8"; -} -.fa-ambulance:before { - content: "\f0f9"; -} -.fa-medkit:before { - content: "\f0fa"; -} -.fa-fighter-jet:before { - content: "\f0fb"; -} -.fa-beer:before { - content: "\f0fc"; -} -.fa-h-square:before { - content: "\f0fd"; -} -.fa-plus-square:before { - content: "\f0fe"; -} -.fa-angle-double-left:before { - content: "\f100"; -} -.fa-angle-double-right:before { - content: "\f101"; -} -.fa-angle-double-up:before { - content: "\f102"; -} -.fa-angle-double-down:before { - content: "\f103"; -} -.fa-angle-left:before { - content: "\f104"; -} -.fa-angle-right:before { - content: "\f105"; -} -.fa-angle-up:before { - content: "\f106"; -} -.fa-angle-down:before { - content: "\f107"; -} -.fa-desktop:before { - content: "\f108"; -} -.fa-laptop:before { - content: "\f109"; -} -.fa-tablet:before { - content: "\f10a"; -} -.fa-mobile-phone:before, -.fa-mobile:before { - content: "\f10b"; -} -.fa-circle-o:before { - content: "\f10c"; -} -.fa-quote-left:before { - content: "\f10d"; -} -.fa-quote-right:before { - content: "\f10e"; -} -.fa-spinner:before { - content: "\f110"; -} -.fa-circle:before { - content: "\f111"; -} -.fa-mail-reply:before, -.fa-reply:before { - content: "\f112"; -} -.fa-github-alt:before { - content: "\f113"; -} -.fa-folder-o:before { - content: "\f114"; -} -.fa-folder-open-o:before { - content: "\f115"; -} -.fa-smile-o:before { - content: "\f118"; -} -.fa-frown-o:before { - content: "\f119"; -} -.fa-meh-o:before { - content: "\f11a"; -} -.fa-gamepad:before { - content: "\f11b"; -} -.fa-keyboard-o:before { - content: "\f11c"; -} -.fa-flag-o:before { - content: "\f11d"; -} -.fa-flag-checkered:before { - content: "\f11e"; -} -.fa-terminal:before { - content: "\f120"; -} -.fa-code:before { - content: "\f121"; -} -.fa-mail-reply-all:before, -.fa-reply-all:before { - content: "\f122"; -} -.fa-star-half-empty:before, -.fa-star-half-full:before, -.fa-star-half-o:before { - content: "\f123"; -} -.fa-location-arrow:before { - content: "\f124"; -} -.fa-crop:before { - content: "\f125"; -} -.fa-code-fork:before { - content: "\f126"; -} -.fa-unlink:before, -.fa-chain-broken:before { - content: "\f127"; -} -.fa-question:before { - content: "\f128"; -} -.fa-info:before { - content: "\f129"; -} -.fa-exclamation:before { - content: "\f12a"; -} -.fa-superscript:before { - content: "\f12b"; -} -.fa-subscript:before { - content: "\f12c"; -} -.fa-eraser:before { - content: "\f12d"; -} -.fa-puzzle-piece:before { - content: "\f12e"; -} -.fa-microphone:before { - content: "\f130"; -} -.fa-microphone-slash:before { - content: "\f131"; -} -.fa-shield:before { - content: "\f132"; -} -.fa-calendar-o:before { - content: "\f133"; -} -.fa-fire-extinguisher:before { - content: "\f134"; -} -.fa-rocket:before { - content: "\f135"; -} -.fa-maxcdn:before { - content: "\f136"; -} -.fa-chevron-circle-left:before { - content: "\f137"; -} -.fa-chevron-circle-right:before { - content: "\f138"; -} -.fa-chevron-circle-up:before { - content: "\f139"; -} -.fa-chevron-circle-down:before { - content: "\f13a"; -} -.fa-html5:before { - content: "\f13b"; -} -.fa-css3:before { - content: "\f13c"; -} -.fa-anchor:before { - content: "\f13d"; -} -.fa-unlock-alt:before { - content: "\f13e"; -} -.fa-bullseye:before { - content: "\f140"; -} -.fa-ellipsis-h:before { - content: "\f141"; -} -.fa-ellipsis-v:before { - content: "\f142"; -} -.fa-rss-square:before { - content: "\f143"; -} -.fa-play-circle:before { - content: "\f144"; -} -.fa-ticket:before { - content: "\f145"; -} -.fa-minus-square:before { - content: "\f146"; -} -.fa-minus-square-o:before { - content: "\f147"; -} -.fa-level-up:before { - content: "\f148"; -} -.fa-level-down:before { - content: "\f149"; -} -.fa-check-square:before { - content: "\f14a"; -} -.fa-pencil-square:before { - content: "\f14b"; -} -.fa-external-link-square:before { - content: "\f14c"; -} -.fa-share-square:before { - content: "\f14d"; -} -.fa-compass:before { - content: "\f14e"; -} -.fa-toggle-down:before, -.fa-caret-square-o-down:before { - content: "\f150"; -} -.fa-toggle-up:before, -.fa-caret-square-o-up:before { - content: "\f151"; -} -.fa-toggle-right:before, -.fa-caret-square-o-right:before { - content: "\f152"; -} -.fa-euro:before, -.fa-eur:before { - content: "\f153"; -} -.fa-gbp:before { - content: "\f154"; -} -.fa-dollar:before, -.fa-usd:before { - content: "\f155"; -} -.fa-rupee:before, -.fa-inr:before { - content: "\f156"; -} -.fa-cny:before, -.fa-rmb:before, -.fa-yen:before, -.fa-jpy:before { - content: "\f157"; -} -.fa-ruble:before, -.fa-rouble:before, -.fa-rub:before { - content: "\f158"; -} -.fa-won:before, -.fa-krw:before { - content: "\f159"; -} -.fa-bitcoin:before, -.fa-btc:before { - content: "\f15a"; -} -.fa-file:before { - content: "\f15b"; -} -.fa-file-text:before { - content: "\f15c"; -} -.fa-sort-alpha-asc:before { - content: "\f15d"; -} -.fa-sort-alpha-desc:before { - content: "\f15e"; -} -.fa-sort-amount-asc:before { - content: "\f160"; -} -.fa-sort-amount-desc:before { - content: "\f161"; -} -.fa-sort-numeric-asc:before { - content: "\f162"; -} -.fa-sort-numeric-desc:before { - content: "\f163"; -} -.fa-thumbs-up:before { - content: "\f164"; -} -.fa-thumbs-down:before { - content: "\f165"; -} -.fa-youtube-square:before { - content: "\f166"; -} -.fa-youtube:before { - content: "\f167"; -} -.fa-xing:before { - content: "\f168"; -} -.fa-xing-square:before { - content: "\f169"; -} -.fa-youtube-play:before { - content: "\f16a"; -} -.fa-dropbox:before { - content: "\f16b"; -} -.fa-stack-overflow:before { - content: "\f16c"; -} -.fa-instagram:before { - content: "\f16d"; -} -.fa-flickr:before { - content: "\f16e"; -} -.fa-adn:before { - content: "\f170"; -} -.fa-bitbucket:before { - content: "\f171"; -} -.fa-bitbucket-square:before { - content: "\f172"; -} -.fa-tumblr:before { - content: "\f173"; -} -.fa-tumblr-square:before { - content: "\f174"; -} -.fa-long-arrow-down:before { - content: "\f175"; -} -.fa-long-arrow-up:before { - content: "\f176"; -} -.fa-long-arrow-left:before { - content: "\f177"; -} -.fa-long-arrow-right:before { - content: "\f178"; -} -.fa-apple:before { - content: "\f179"; -} -.fa-windows:before { - content: "\f17a"; -} -.fa-android:before { - content: "\f17b"; -} -.fa-linux:before { - content: "\f17c"; -} -.fa-dribbble:before { - content: "\f17d"; -} -.fa-skype:before { - content: "\f17e"; -} -.fa-foursquare:before { - content: "\f180"; -} -.fa-trello:before { - content: "\f181"; -} -.fa-female:before { - content: "\f182"; -} -.fa-male:before { - content: "\f183"; -} -.fa-gittip:before, -.fa-gratipay:before { - content: "\f184"; -} -.fa-sun-o:before { - content: "\f185"; -} -.fa-moon-o:before { - content: "\f186"; -} -.fa-archive:before { - content: "\f187"; -} -.fa-bug:before { - content: "\f188"; -} -.fa-vk:before { - content: "\f189"; -} -.fa-weibo:before { - content: "\f18a"; -} -.fa-renren:before { - content: "\f18b"; -} -.fa-pagelines:before { - content: "\f18c"; -} -.fa-stack-exchange:before { - content: "\f18d"; -} -.fa-arrow-circle-o-right:before { - content: "\f18e"; -} -.fa-arrow-circle-o-left:before { - content: "\f190"; -} -.fa-toggle-left:before, -.fa-caret-square-o-left:before { - content: "\f191"; -} -.fa-dot-circle-o:before { - content: "\f192"; -} -.fa-wheelchair:before { - content: "\f193"; -} -.fa-vimeo-square:before { - content: "\f194"; -} -.fa-turkish-lira:before, -.fa-try:before { - content: "\f195"; -} -.fa-plus-square-o:before { - content: "\f196"; -} -.fa-space-shuttle:before { - content: "\f197"; -} -.fa-slack:before { - content: "\f198"; -} -.fa-envelope-square:before { - content: "\f199"; -} -.fa-wordpress:before { - content: "\f19a"; -} -.fa-openid:before { - content: "\f19b"; -} -.fa-institution:before, -.fa-bank:before, -.fa-university:before { - content: "\f19c"; -} -.fa-mortar-board:before, -.fa-graduation-cap:before { - content: "\f19d"; -} -.fa-yahoo:before { - content: "\f19e"; -} -.fa-google:before { - content: "\f1a0"; -} -.fa-reddit:before { - content: "\f1a1"; -} -.fa-reddit-square:before { - content: "\f1a2"; -} -.fa-stumbleupon-circle:before { - content: "\f1a3"; -} -.fa-stumbleupon:before { - content: "\f1a4"; -} -.fa-delicious:before { - content: "\f1a5"; -} -.fa-digg:before { - content: "\f1a6"; -} -.fa-pied-piper-pp:before { - content: "\f1a7"; -} -.fa-pied-piper-alt:before { - content: "\f1a8"; -} -.fa-drupal:before { - content: "\f1a9"; -} -.fa-joomla:before { - content: "\f1aa"; -} -.fa-language:before { - content: "\f1ab"; -} -.fa-fax:before { - content: "\f1ac"; -} -.fa-building:before { - content: "\f1ad"; -} -.fa-child:before { - content: "\f1ae"; -} -.fa-paw:before { - content: "\f1b0"; -} -.fa-spoon:before { - content: "\f1b1"; -} -.fa-cube:before { - content: "\f1b2"; -} -.fa-cubes:before { - content: "\f1b3"; -} -.fa-behance:before { - content: "\f1b4"; -} -.fa-behance-square:before { - content: "\f1b5"; -} -.fa-steam:before { - content: "\f1b6"; -} -.fa-steam-square:before { - content: "\f1b7"; -} -.fa-recycle:before { - content: "\f1b8"; -} -.fa-automobile:before, -.fa-car:before { - content: "\f1b9"; -} -.fa-cab:before, -.fa-taxi:before { - content: "\f1ba"; -} -.fa-tree:before { - content: "\f1bb"; -} -.fa-spotify:before { - content: "\f1bc"; -} -.fa-deviantart:before { - content: "\f1bd"; -} -.fa-soundcloud:before { - content: "\f1be"; -} -.fa-database:before { - content: "\f1c0"; -} -.fa-file-pdf-o:before { - content: "\f1c1"; -} -.fa-file-word-o:before { - content: "\f1c2"; -} -.fa-file-excel-o:before { - content: "\f1c3"; -} -.fa-file-powerpoint-o:before { - content: "\f1c4"; -} -.fa-file-photo-o:before, -.fa-file-picture-o:before, -.fa-file-image-o:before { - content: "\f1c5"; -} -.fa-file-zip-o:before, -.fa-file-archive-o:before { - content: "\f1c6"; -} -.fa-file-sound-o:before, -.fa-file-audio-o:before { - content: "\f1c7"; -} -.fa-file-movie-o:before, -.fa-file-video-o:before { - content: "\f1c8"; -} -.fa-file-code-o:before { - content: "\f1c9"; -} -.fa-vine:before { - content: "\f1ca"; -} -.fa-codepen:before { - content: "\f1cb"; -} -.fa-jsfiddle:before { - content: "\f1cc"; -} -.fa-life-bouy:before, -.fa-life-buoy:before, -.fa-life-saver:before, -.fa-support:before, -.fa-life-ring:before { - content: "\f1cd"; -} -.fa-circle-o-notch:before { - content: "\f1ce"; -} -.fa-ra:before, -.fa-resistance:before, -.fa-rebel:before { - content: "\f1d0"; -} -.fa-ge:before, -.fa-empire:before { - content: "\f1d1"; -} -.fa-git-square:before { - content: "\f1d2"; -} -.fa-git:before { - content: "\f1d3"; -} -.fa-y-combinator-square:before, -.fa-yc-square:before, -.fa-hacker-news:before { - content: "\f1d4"; -} -.fa-tencent-weibo:before { - content: "\f1d5"; -} -.fa-qq:before { - content: "\f1d6"; -} -.fa-wechat:before, -.fa-weixin:before { - content: "\f1d7"; -} -.fa-send:before, -.fa-paper-plane:before { - content: "\f1d8"; -} -.fa-send-o:before, -.fa-paper-plane-o:before { - content: "\f1d9"; -} -.fa-history:before { - content: "\f1da"; -} -.fa-circle-thin:before { - content: "\f1db"; -} -.fa-header:before { - content: "\f1dc"; -} -.fa-paragraph:before { - content: "\f1dd"; -} -.fa-sliders:before { - content: "\f1de"; -} -.fa-share-alt:before { - content: "\f1e0"; -} -.fa-share-alt-square:before { - content: "\f1e1"; -} -.fa-bomb:before { - content: "\f1e2"; -} -.fa-soccer-ball-o:before, -.fa-futbol-o:before { - content: "\f1e3"; -} -.fa-tty:before { - content: "\f1e4"; -} -.fa-binoculars:before { - content: "\f1e5"; -} -.fa-plug:before { - content: "\f1e6"; -} -.fa-slideshare:before { - content: "\f1e7"; -} -.fa-twitch:before { - content: "\f1e8"; -} -.fa-yelp:before { - content: "\f1e9"; -} -.fa-newspaper-o:before { - content: "\f1ea"; -} -.fa-wifi:before { - content: "\f1eb"; -} -.fa-calculator:before { - content: "\f1ec"; -} -.fa-paypal:before { - content: "\f1ed"; -} -.fa-google-wallet:before { - content: "\f1ee"; -} -.fa-cc-visa:before { - content: "\f1f0"; -} -.fa-cc-mastercard:before { - content: "\f1f1"; -} -.fa-cc-discover:before { - content: "\f1f2"; -} -.fa-cc-amex:before { - content: "\f1f3"; -} -.fa-cc-paypal:before { - content: "\f1f4"; -} -.fa-cc-stripe:before { - content: "\f1f5"; -} -.fa-bell-slash:before { - content: "\f1f6"; -} -.fa-bell-slash-o:before { - content: "\f1f7"; -} -.fa-trash:before { - content: "\f1f8"; -} -.fa-copyright:before { - content: "\f1f9"; -} -.fa-at:before { - content: "\f1fa"; -} -.fa-eyedropper:before { - content: "\f1fb"; -} -.fa-paint-brush:before { - content: "\f1fc"; -} -.fa-birthday-cake:before { - content: "\f1fd"; -} -.fa-area-chart:before { - content: "\f1fe"; -} -.fa-pie-chart:before { - content: "\f200"; -} -.fa-line-chart:before { - content: "\f201"; -} -.fa-lastfm:before { - content: "\f202"; -} -.fa-lastfm-square:before { - content: "\f203"; -} -.fa-toggle-off:before { - content: "\f204"; -} -.fa-toggle-on:before { - content: "\f205"; -} -.fa-bicycle:before { - content: "\f206"; -} -.fa-bus:before { - content: "\f207"; -} -.fa-ioxhost:before { - content: "\f208"; -} -.fa-angellist:before { - content: "\f209"; -} -.fa-cc:before { - content: "\f20a"; -} -.fa-shekel:before, -.fa-sheqel:before, -.fa-ils:before { - content: "\f20b"; -} -.fa-meanpath:before { - content: "\f20c"; -} -.fa-buysellads:before { - content: "\f20d"; -} -.fa-connectdevelop:before { - content: "\f20e"; -} -.fa-dashcube:before { - content: "\f210"; -} -.fa-forumbee:before { - content: "\f211"; -} -.fa-leanpub:before { - content: "\f212"; -} -.fa-sellsy:before { - content: "\f213"; -} -.fa-shirtsinbulk:before { - content: "\f214"; -} -.fa-simplybuilt:before { - content: "\f215"; -} -.fa-skyatlas:before { - content: "\f216"; -} -.fa-cart-plus:before { - content: "\f217"; -} -.fa-cart-arrow-down:before { - content: "\f218"; -} -.fa-diamond:before { - content: "\f219"; -} -.fa-ship:before { - content: "\f21a"; -} -.fa-user-secret:before { - content: "\f21b"; -} -.fa-motorcycle:before { - content: "\f21c"; -} -.fa-street-view:before { - content: "\f21d"; -} -.fa-heartbeat:before { - content: "\f21e"; -} -.fa-venus:before { - content: "\f221"; -} -.fa-mars:before { - content: "\f222"; -} -.fa-mercury:before { - content: "\f223"; -} -.fa-intersex:before, -.fa-transgender:before { - content: "\f224"; -} -.fa-transgender-alt:before { - content: "\f225"; -} -.fa-venus-double:before { - content: "\f226"; -} -.fa-mars-double:before { - content: "\f227"; -} -.fa-venus-mars:before { - content: "\f228"; -} -.fa-mars-stroke:before { - content: "\f229"; -} -.fa-mars-stroke-v:before { - content: "\f22a"; -} -.fa-mars-stroke-h:before { - content: "\f22b"; -} -.fa-neuter:before { - content: "\f22c"; -} -.fa-genderless:before { - content: "\f22d"; -} -.fa-facebook-official:before { - content: "\f230"; -} -.fa-pinterest-p:before { - content: "\f231"; -} -.fa-whatsapp:before { - content: "\f232"; -} -.fa-server:before { - content: "\f233"; -} -.fa-user-plus:before { - content: "\f234"; -} -.fa-user-times:before { - content: "\f235"; -} -.fa-hotel:before, -.fa-bed:before { - content: "\f236"; -} -.fa-viacoin:before { - content: "\f237"; -} -.fa-train:before { - content: "\f238"; -} -.fa-subway:before { - content: "\f239"; -} -.fa-medium:before { - content: "\f23a"; -} -.fa-yc:before, -.fa-y-combinator:before { - content: "\f23b"; -} -.fa-optin-monster:before { - content: "\f23c"; -} -.fa-opencart:before { - content: "\f23d"; -} -.fa-expeditedssl:before { - content: "\f23e"; -} -.fa-battery-4:before, -.fa-battery:before, -.fa-battery-full:before { - content: "\f240"; -} -.fa-battery-3:before, -.fa-battery-three-quarters:before { - content: "\f241"; -} -.fa-battery-2:before, -.fa-battery-half:before { - content: "\f242"; -} -.fa-battery-1:before, -.fa-battery-quarter:before { - content: "\f243"; -} -.fa-battery-0:before, -.fa-battery-empty:before { - content: "\f244"; -} -.fa-mouse-pointer:before { - content: "\f245"; -} -.fa-i-cursor:before { - content: "\f246"; -} -.fa-object-group:before { - content: "\f247"; -} -.fa-object-ungroup:before { - content: "\f248"; -} -.fa-sticky-note:before { - content: "\f249"; -} -.fa-sticky-note-o:before { - content: "\f24a"; -} -.fa-cc-jcb:before { - content: "\f24b"; -} -.fa-cc-diners-club:before { - content: "\f24c"; -} -.fa-clone:before { - content: "\f24d"; -} -.fa-balance-scale:before { - content: "\f24e"; -} -.fa-hourglass-o:before { - content: "\f250"; -} -.fa-hourglass-1:before, -.fa-hourglass-start:before { - content: "\f251"; -} -.fa-hourglass-2:before, -.fa-hourglass-half:before { - content: "\f252"; -} -.fa-hourglass-3:before, -.fa-hourglass-end:before { - content: "\f253"; -} -.fa-hourglass:before { - content: "\f254"; -} -.fa-hand-grab-o:before, -.fa-hand-rock-o:before { - content: "\f255"; -} -.fa-hand-stop-o:before, -.fa-hand-paper-o:before { - content: "\f256"; -} -.fa-hand-scissors-o:before { - content: "\f257"; -} -.fa-hand-lizard-o:before { - content: "\f258"; -} -.fa-hand-spock-o:before { - content: "\f259"; -} -.fa-hand-pointer-o:before { - content: "\f25a"; -} -.fa-hand-peace-o:before { - content: "\f25b"; -} -.fa-trademark:before { - content: "\f25c"; -} -.fa-registered:before { - content: "\f25d"; -} -.fa-creative-commons:before { - content: "\f25e"; -} -.fa-gg:before { - content: "\f260"; -} -.fa-gg-circle:before { - content: "\f261"; -} -.fa-tripadvisor:before { - content: "\f262"; -} -.fa-odnoklassniki:before { - content: "\f263"; -} -.fa-odnoklassniki-square:before { - content: "\f264"; -} -.fa-get-pocket:before { - content: "\f265"; -} -.fa-wikipedia-w:before { - content: "\f266"; -} -.fa-safari:before { - content: "\f267"; -} -.fa-chrome:before { - content: "\f268"; -} -.fa-firefox:before { - content: "\f269"; -} -.fa-opera:before { - content: "\f26a"; -} -.fa-internet-explorer:before { - content: "\f26b"; -} -.fa-tv:before, -.fa-television:before { - content: "\f26c"; -} -.fa-contao:before { - content: "\f26d"; -} -.fa-500px:before { - content: "\f26e"; -} -.fa-amazon:before { - content: "\f270"; -} -.fa-calendar-plus-o:before { - content: "\f271"; -} -.fa-calendar-minus-o:before { - content: "\f272"; -} -.fa-calendar-times-o:before { - content: "\f273"; -} -.fa-calendar-check-o:before { - content: "\f274"; -} -.fa-industry:before { - content: "\f275"; -} -.fa-map-pin:before { - content: "\f276"; -} -.fa-map-signs:before { - content: "\f277"; -} -.fa-map-o:before { - content: "\f278"; -} -.fa-map:before { - content: "\f279"; -} -.fa-commenting:before { - content: "\f27a"; -} -.fa-commenting-o:before { - content: "\f27b"; -} -.fa-houzz:before { - content: "\f27c"; -} -.fa-vimeo:before { - content: "\f27d"; -} -.fa-black-tie:before { - content: "\f27e"; -} -.fa-fonticons:before { - content: "\f280"; -} -.fa-reddit-alien:before { - content: "\f281"; -} -.fa-edge:before { - content: "\f282"; -} -.fa-credit-card-alt:before { - content: "\f283"; -} -.fa-codiepie:before { - content: "\f284"; -} -.fa-modx:before { - content: "\f285"; -} -.fa-fort-awesome:before { - content: "\f286"; -} -.fa-usb:before { - content: "\f287"; -} -.fa-product-hunt:before { - content: "\f288"; -} -.fa-mixcloud:before { - content: "\f289"; -} -.fa-scribd:before { - content: "\f28a"; -} -.fa-pause-circle:before { - content: "\f28b"; -} -.fa-pause-circle-o:before { - content: "\f28c"; -} -.fa-stop-circle:before { - content: "\f28d"; -} -.fa-stop-circle-o:before { - content: "\f28e"; -} -.fa-shopping-bag:before { - content: "\f290"; -} -.fa-shopping-basket:before { - content: "\f291"; -} -.fa-hashtag:before { - content: "\f292"; -} -.fa-bluetooth:before { - content: "\f293"; -} -.fa-bluetooth-b:before { - content: "\f294"; -} -.fa-percent:before { - content: "\f295"; -} -.fa-gitlab:before { - content: "\f296"; -} -.fa-wpbeginner:before { - content: "\f297"; -} -.fa-wpforms:before { - content: "\f298"; -} -.fa-envira:before { - content: "\f299"; -} -.fa-universal-access:before { - content: "\f29a"; -} -.fa-wheelchair-alt:before { - content: "\f29b"; -} -.fa-question-circle-o:before { - content: "\f29c"; -} -.fa-blind:before { - content: "\f29d"; -} -.fa-audio-description:before { - content: "\f29e"; -} -.fa-volume-control-phone:before { - content: "\f2a0"; -} -.fa-braille:before { - content: "\f2a1"; -} -.fa-assistive-listening-systems:before { - content: "\f2a2"; -} -.fa-asl-interpreting:before, -.fa-american-sign-language-interpreting:before { - content: "\f2a3"; -} -.fa-deafness:before, -.fa-hard-of-hearing:before, -.fa-deaf:before { - content: "\f2a4"; -} -.fa-glide:before { - content: "\f2a5"; -} -.fa-glide-g:before { - content: "\f2a6"; -} -.fa-signing:before, -.fa-sign-language:before { - content: "\f2a7"; -} -.fa-low-vision:before { - content: "\f2a8"; -} -.fa-viadeo:before { - content: "\f2a9"; -} -.fa-viadeo-square:before { - content: "\f2aa"; -} -.fa-snapchat:before { - content: "\f2ab"; -} -.fa-snapchat-ghost:before { - content: "\f2ac"; -} -.fa-snapchat-square:before { - content: "\f2ad"; -} -.fa-pied-piper:before { - content: "\f2ae"; -} -.fa-first-order:before { - content: "\f2b0"; -} -.fa-yoast:before { - content: "\f2b1"; -} -.fa-themeisle:before { - content: "\f2b2"; -} -.fa-google-plus-circle:before, -.fa-google-plus-official:before { - content: "\f2b3"; -} -.fa-fa:before, -.fa-font-awesome:before { - content: "\f2b4"; -} -.fa-handshake-o:before { - content: "\f2b5"; -} -.fa-envelope-open:before { - content: "\f2b6"; -} -.fa-envelope-open-o:before { - content: "\f2b7"; -} -.fa-linode:before { - content: "\f2b8"; -} -.fa-address-book:before { - content: "\f2b9"; -} -.fa-address-book-o:before { - content: "\f2ba"; -} -.fa-vcard:before, -.fa-address-card:before { - content: "\f2bb"; -} -.fa-vcard-o:before, -.fa-address-card-o:before { - content: "\f2bc"; -} -.fa-user-circle:before { - content: "\f2bd"; -} -.fa-user-circle-o:before { - content: "\f2be"; -} -.fa-user-o:before { - content: "\f2c0"; -} -.fa-id-badge:before { - content: "\f2c1"; -} -.fa-drivers-license:before, -.fa-id-card:before { - content: "\f2c2"; -} -.fa-drivers-license-o:before, -.fa-id-card-o:before { - content: "\f2c3"; -} -.fa-quora:before { - content: "\f2c4"; -} -.fa-free-code-camp:before { - content: "\f2c5"; -} -.fa-telegram:before { - content: "\f2c6"; -} -.fa-thermometer-4:before, -.fa-thermometer:before, -.fa-thermometer-full:before { - content: "\f2c7"; -} -.fa-thermometer-3:before, -.fa-thermometer-three-quarters:before { - content: "\f2c8"; -} -.fa-thermometer-2:before, -.fa-thermometer-half:before { - content: "\f2c9"; -} -.fa-thermometer-1:before, -.fa-thermometer-quarter:before { - content: "\f2ca"; -} -.fa-thermometer-0:before, -.fa-thermometer-empty:before { - content: "\f2cb"; -} -.fa-shower:before { - content: "\f2cc"; -} -.fa-bathtub:before, -.fa-s15:before, -.fa-bath:before { - content: "\f2cd"; -} -.fa-podcast:before { - content: "\f2ce"; -} -.fa-window-maximize:before { - content: "\f2d0"; -} -.fa-window-minimize:before { - content: "\f2d1"; -} -.fa-window-restore:before { - content: "\f2d2"; -} -.fa-times-rectangle:before, -.fa-window-close:before { - content: "\f2d3"; -} -.fa-times-rectangle-o:before, -.fa-window-close-o:before { - content: "\f2d4"; -} -.fa-bandcamp:before { - content: "\f2d5"; -} -.fa-grav:before { - content: "\f2d6"; -} -.fa-etsy:before { - content: "\f2d7"; -} -.fa-imdb:before { - content: "\f2d8"; -} -.fa-ravelry:before { - content: "\f2d9"; -} -.fa-eercast:before { - content: "\f2da"; -} -.fa-microchip:before { - content: "\f2db"; -} -.fa-snowflake-o:before { - content: "\f2dc"; -} -.fa-superpowers:before { - content: "\f2dd"; -} -.fa-wpexplorer:before { - content: "\f2de"; -} -.fa-meetup:before { - content: "\f2e0"; -} -.sr-only { - position: absolute; - width: 1px; - height: 1px; - padding: 0; - margin: -1px; - overflow: hidden; - clip: rect(0, 0, 0, 0); - border: 0; -} -.sr-only-focusable:active, -.sr-only-focusable:focus { - position: static; - width: auto; - height: auto; - margin: 0; - overflow: visible; - clip: auto; -} diff --git a/doc/_static/fontawesome-webfont.eot b/doc/_static/fontawesome-webfont.eot deleted file mode 100644 index e9f60ca9..00000000 Binary files a/doc/_static/fontawesome-webfont.eot and /dev/null differ diff --git a/doc/_static/fontawesome-webfont.ttf b/doc/_static/fontawesome-webfont.ttf deleted file mode 100644 index 35acda2f..00000000 Binary files a/doc/_static/fontawesome-webfont.ttf and /dev/null differ diff --git a/doc/_static/fontawesome-webfont.woff b/doc/_static/fontawesome-webfont.woff deleted file mode 100644 index 400014a4..00000000 Binary files a/doc/_static/fontawesome-webfont.woff and /dev/null differ diff --git a/doc/_static/fontawesome-webfont.woff2 b/doc/_static/fontawesome-webfont.woff2 deleted file mode 100644 index 4d13fc60..00000000 Binary files a/doc/_static/fontawesome-webfont.woff2 and /dev/null differ diff --git a/doc/_templates/autosummary/class.rst b/doc/_templates/autosummary/class.rst index e4adacd5..fe474401 100644 --- a/doc/_templates/autosummary/class.rst +++ b/doc/_templates/autosummary/class.rst @@ -1,12 +1,12 @@ -{{ fullname }} -{{ underline }} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} - :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__,__hash__ + :special-members: __contains__,__getitem__,__iter__,__len__,__add__,__sub__,__mul__,__div__,__neg__ + :members: - {% block methods %} - {% endblock %} +.. _sphx_glr_backreferences_{{ fullname }}: -.. include:: {{module}}.{{objname}}.examples +.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/_templates/autosummary/function.rst b/doc/_templates/autosummary/function.rst index bdde2420..bd78b8e8 100644 --- a/doc/_templates/autosummary/function.rst +++ b/doc/_templates/autosummary/function.rst @@ -1,12 +1,10 @@ -{{ fullname }} -{{ underline }} +{{ fullname | escape | underline }} .. currentmodule:: {{ module }} .. autofunction:: {{ objname }} -.. include:: {{module}}.{{objname}}.examples +.. _sphx_glr_backreferences_{{ fullname }}: -.. raw:: html - -
+.. minigallery:: {{ fullname }} + :add-heading: diff --git a/doc/conf.py b/doc/conf.py index 8abf8f39..4ae798d1 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Expyfun documentation build configuration file, created by # sphinx-quickstart on Fri Jun 11 10:45:48 2010. @@ -14,75 +13,75 @@ import inspect import os -from os.path import relpath, dirname import sys from datetime import date -import sphinx_gallery # noqa -import sphinx_bootstrap_theme # noqa -from numpydoc import numpydoc, docscrape # noqa +from os.path import dirname, relpath + +import sphinx # noqa +from numpydoc import docscrape, numpydoc # noqa # Work around Pyglet annoyingness -assert 'pyglet' not in sys.modules -if 'sphinx' in sys.modules: - s = sys.modules.pop('sphinx') - import pyglet # noqa - sys.modules['sphinx'] = s - del s +assert "pyglet" not in sys.modules +s = sys.modules.pop("sphinx") +import pyglet # noqa + +sys.modules["sphinx"] = s +del s + +import expyfun # noqa: E402 # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. curdir = os.path.dirname(__file__) -sys.path.append(os.path.abspath(os.path.join(curdir, '..', 'expyfun'))) -sys.path.append(os.path.abspath(os.path.join(curdir, 'sphinxext'))) +sys.path.append(os.path.abspath(os.path.join(curdir, "sphinxext"))) + -import expyfun -if not os.path.isdir('_images'): - os.mkdir('_images') +if not os.path.isdir("_images"): + os.mkdir("_images") # -- General configuration ------------------------------------------------ # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '1.8' +needs_sphinx = "1.8" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.coverage', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.linkcode', - 'sphinx.ext.mathjax', - 'sphinx.ext.todo', - 'sphinx_gallery.gen_gallery', - 'sphinx_fontawesome', - 'numpydoc', - 'sphinx_bootstrap_theme', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.intersphinx", + "sphinx.ext.linkcode", + "sphinx.ext.mathjax", + "sphinx.ext.todo", + "sphinx_gallery.gen_gallery", + "numpydoc", ] autosummary_generate = True -autodoc_default_options = {'inherited-members': None} +autodoc_default_options = {"inherited-members": None} # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The encoding of source files. -#source_encoding = 'utf-8-sig' +# source_encoding = 'utf-8-sig' # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = u'expyfun' +project = "expyfun" td = date.today() -copyright = u'2013-%s, expyfun developers. Last updated on %s' % (td.year, - td.isoformat()) +copyright = ( # noqa: A001 + f"2013-{td.year}, expyfun developers. Last updated on {td.isoformat()}" +) nitpicky = True # The version info for the project you're documenting, acts as replacement for @@ -96,81 +95,86 @@ # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. -#language = None +# language = None # There are two options for replacing |today|: either, you set today to some # non-false value, then it is used: -#today = '' +# today = '' # Else, today_fmt is used as the format for a strftime call. -#today_fmt = '%B %d, %Y' +# today_fmt = '%B %d, %Y' # List of documents that shouldn't be included in the build. unused_docs = [] # List of directories, relative to source directory, that shouldn't be searched # for source files. -exclude_trees = ['_build'] +exclude_trees = ["_build"] # The reST default role (used for this markup: `text`) to use for all # documents. default_role = "autolink" # If true, '()' will be appended to :func: etc. cross-reference text. -#add_function_parentheses = True +# add_function_parentheses = True # If true, the current module name will be prepended to all description # unit titles (such as .. function::). -#add_module_names = True +# add_module_names = True # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -modindex_common_prefix = ['expyfun.'] +modindex_common_prefix = ["expyfun."] # If true, keep warnings as "system message" paragraphs in the built documents. -#keep_warnings = False +# keep_warnings = False # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'bootstrap' +html_theme = "pydata_sphinx_theme" # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. html_theme_options = { - 'navbar_title': 'expyfun', - 'source_link_position': "nav", # default - 'bootswatch_theme': "yeti", - 'navbar_sidebarrel': False, # Render the next/prev links in navbar? - 'navbar_pagenav': True, - 'globaltoc_depth': 0, - 'navbar_class': "navbar", - 'bootstrap_version': "3", # default - 'navbar_links': [ - ("Getting started", "getting_started"), - ("Examples", "auto_examples/index"), - ("API reference", "python_reference"), + "logo": { + "text": "expyfun", + }, + "icon_links": [ + dict( + name="GitHub", + url="https://github.com/LABSN/expyfun", + icon="fa-brands fa-square-github", + ), ], - } + "icon_links_label": "External Links", # for screen reader + "use_edit_page_button": False, + "navigation_with_keys": False, + "show_toc_level": 1, + "article_header_start": [], # disable breadcrumbs + "navbar_end": ["theme-switcher", "navbar-icon-links"], + "footer_start": ["copyright"], + "secondary_sidebar_items": ["page-toc", "edit-this-page"], +} # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -html_logo = "_static/favicon.ico" +# html_logo = "_static/favicon.ico" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -180,36 +184,36 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static', '_images'] +html_static_path = ["_static", "_images"] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied # directly to the root of the documentation. -#html_extra_path = [] +# html_extra_path = [] # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -#html_last_updated_fmt = '%b %d, %Y' +# html_last_updated_fmt = '%b %d, %Y' # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. -#html_use_smartypants = True +# html_use_smartypants = True # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +html_sidebars = {"getting_started": []} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +# html_domain_indices = True # If false, no index is generated. -#html_use_index = True +# html_use_index = True # If true, the index is split into individual pages for each letter. -#html_split_index = False +# html_split_index = False # If true, links to the reST sources are added to the pages. html_show_sourcelink = False @@ -219,52 +223,55 @@ html_show_sphinx = False # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. -#html_show_copyright = True +# html_show_copyright = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # variables to pass to HTML templating engine -build_dev_html = bool(int(os.environ.get('BUILD_DEV_HTML', False))) +build_dev_html = bool(int(os.environ.get("BUILD_DEV_HTML", False))) -html_context = {'use_google_analytics': True, 'use_twitter': True, - 'use_media_buttons': True, 'build_dev_html': build_dev_html} +html_context = { + "use_google_analytics": True, + "use_twitter": True, + "use_media_buttons": True, + "build_dev_html": build_dev_html, +} # This is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = None +# html_file_suffix = None # Output file base name for HTML help builder. -htmlhelp_basename = 'expyfun-doc' +htmlhelp_basename = "expyfun-doc" trim_doctests_flags = True # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/devdocs', None), - 'scipy': ('https://scipy.github.io/devdocs', None), - 'matplotlib': ('https://matplotlib.org', None), - 'sklearn': ('https://scikit-learn.org/stable', None), - 'pandas': ('https://pandas.pydata.org/pandas-docs/stable', None), - 'sounddevice': ('https://python-sounddevice.readthedocs.io', None), - 'rtmixer': ('https://python-rtmixer.readthedocs.io/en/latest', None), - 'pyglet': ('https://pyglet.readthedocs.io/en/latest', None), - 'mne': ('https://mne-tools.github.io/dev', None), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/devdocs", None), + "scipy": ("https://scipy.github.io/devdocs", None), + "matplotlib": ("https://matplotlib.org/stable", None), + "pandas": ("https://pandas.pydata.org/pandas-docs/stable", None), + "sounddevice": ("https://python-sounddevice.readthedocs.io", None), + "rtmixer": ("https://python-rtmixer.readthedocs.io/en/latest", None), + "pyglet": ("https://pyglet.readthedocs.io/en/latest", None), + "mne": ("https://mne.tools/dev", None), } -examples_dirs = ['../examples'] -gallery_dirs = ['auto_examples'] +examples_dirs = ["../examples"] +gallery_dirs = ["auto_examples"] sphinx_gallery_conf = { - 'doc_module': ('expyfun',), - 'examples_dirs': examples_dirs, - 'gallery_dirs': gallery_dirs, - 'backreferences_dir': 'generated', - 'plot_gallery': 'True', # Avoid annoying Unicode/bool default warning - 'filename_pattern': r'/.*(?>> expyfun.get_config() @@ -165,8 +163,7 @@ The fixed, hardware-dependent settings for a given system get written to an ``expyfun.json`` file. You can use :func:`expyfun.get_config_path` to get the path to your config file. Some sample configurations: -- A TDT-based M/EEG+pupillometry machine: - +A TDT-based M/EEG+pupillometry machine .. code-block:: JSON { @@ -182,8 +179,7 @@ get the path to your config file. Some sample configurations: "TRIGGER_CONTROLLER": "tdt" } -- A sound-card-based EEG system: - +A sound-card-based EEG system .. code-block:: JSON { diff --git a/doc/git_diagram.py b/doc/git_diagram.py index 003910f4..231f06ca 100644 --- a/doc/git_diagram.py +++ b/doc/git_diagram.py @@ -1,16 +1,13 @@ -# -*- coding: utf-8 -*- - -import os -from os import path as op +import pygraphviz as pgv -title = 'git flow diagram' +title = "git flow diagram" -font_face = 'Arial' +font_face = "Arial" node_size = 12 node_small_size = 9 edge_size = 9 -local_color = '#7bbeca' -remote_color = '#ff6347' +local_color = "#7bbeca" +remote_color = "#ff6347" legend = """ < @@ -20,102 +17,94 @@ Remote repositories >""" % (edge_size, local_color, remote_color) -legend = ''.join(legend.split('\n')) +legend = "".join(legend.split("\n")) nodes = dict( - upstream='LABSN/expyfun\n' - 'master\n' - ' ', - maint='Eric89GXL/expyfun\n' - 'master\n' - 'other_branch', - dev='rkmaddox/expyfun\n' - 'master\n' - 'fix_branch', - maint_clone='/home/larsoner/expyfun\n' - 'master (origin/master)\n' - 'other_branch (origin/other_branch)\n' - 'ross_branch (rkmaddox/fix_branch)', - dev_clone='/home/rkmaddox/expyfun\n' - 'master (origin/master)\n' - 'fix_branch (origin/fix_branch)\n' - ' ', - user_clone='/home/akclee/expyfun\n' - 'master (origin/master)\n' - ' \n' - ' ', + upstream="LABSN/expyfun\n" "master\n" " ", + maint="Eric89GXL/expyfun\n" "master\n" "other_branch", + dev="rkmaddox/expyfun\n" "master\n" "fix_branch", + maint_clone="/home/larsoner/expyfun\n" + "master (origin/master)\n" + "other_branch (origin/other_branch)\n" + "ross_branch (rkmaddox/fix_branch)", + dev_clone="/home/rkmaddox/expyfun\n" + "master (origin/master)\n" + "fix_branch (origin/fix_branch)\n" + " ", + user_clone="/home/akclee/expyfun\n" "master (origin/master)\n" " \n" " ", legend=legend, ) -remote_space = ('maint', 'dev', 'upstream') -local_space = ('maint_clone', 'dev_clone', 'user_clone') +remote_space = ("maint", "dev", "upstream") +local_space = ("maint_clone", "dev_clone", "user_clone") edges = ( - ('maint_clone', 'maint', 'origin'), - ('dev_clone', 'dev', 'origin'), - ('user_clone', 'upstream', 'origin'), - ('maint_clone', 'upstream', 'upstream'), - ('maint_clone', 'dev', 'rkmaddox'), - ('dev_clone', 'upstream', 'upstream'), + ("maint_clone", "maint", "origin"), + ("dev_clone", "dev", "origin"), + ("user_clone", "upstream", "origin"), + ("maint_clone", "upstream", "upstream"), + ("maint_clone", "dev", "rkmaddox"), + ("dev_clone", "upstream", "upstream"), ) subgraphs = ( - [('upstream', 'maint', 'dev'), ('GitHub')], - [('maint_clone'), ('Maintainer')], - [('dev_clone'), ("Developer")], - [('user_clone'), ("User")], + [("upstream", "maint", "dev"), ("GitHub")], + [("maint_clone"), ("Maintainer")], + [("dev_clone"), ("Developer")], + [("user_clone"), ("User")], ) -import pygraphviz as pgv g = pgv.AGraph(name=title, directed=True) for key, label in nodes.items(): - label = label.split('\n') + label = label.split("\n") if len(label) > 1: - label[0] = ('<' % node_size - + label[0] + '') + label[0] = '<' % node_size + label[0] + "" for li in range(1, len(label)): - label[li] = ('' % node_small_size - + label[li] + '') + label[li] = ( + '' % node_small_size + + label[li] + + "" + ) label[-1] = label[-1] + '
>' label = '
'.join(label) else: label = label[0] - g.add_node(key, shape='plaintext', label=label) + g.add_node(key, shape="plaintext", label=label) # Create and customize nodes and edges for edge in edges: g.add_edge(*edge[:2]) e = g.get_edge(*edge[:2]) if len(edge) > 2: - e.attr['label'] = ('<' + - '
'.join(edge[2].split('\n')) + - '
>') - e.attr['fontsize'] = edge_size + e.attr["label"] = ( + "<" + + '
'.join(edge[2].split("\n")) + + '
>' + ) + e.attr["fontsize"] = edge_size g.get_node # Change colors -for these_nodes, color in zip((local_space, remote_space), - (local_color, remote_color)): +for these_nodes, color in zip((local_space, remote_space), (local_color, remote_color)): for node in these_nodes: - g.get_node(node).attr['fillcolor'] = color - g.get_node(node).attr['style'] = 'filled' + g.get_node(node).attr["fillcolor"] = color + g.get_node(node).attr["style"] = "filled" # Create subgraphs for si, subgraph in enumerate(subgraphs): - g.add_subgraph(subgraph[0], 'cluster%s' % si, - label=subgraph[1], color='black') + g.add_subgraph(subgraph[0], "cluster%s" % si, label=subgraph[1], color="black") # Format (sub)graphs for gr in g.subgraphs() + [g]: for x in [gr.node_attr, gr.edge_attr]: - x['fontname'] = font_face -g.node_attr['shape'] = 'box' + x["fontname"] = font_face +g.node_attr["shape"] = "box" -g.get_node('legend').attr.update(shape='plaintext', margin=0, rank='sink') +g.get_node("legend").attr.update(shape="plaintext", margin=0, rank="sink") # put legend in same rank/level as inverse -l = g.add_subgraph(['legend', 'inv'], name='legendy') -l.graph_attr['rank'] = 'same' +ll = g.add_subgraph(["legend", "inv"], name="legendy") +ll.graph_attr["rank"] = "same" -g.layout('dot') -g.draw('git_flow.svg', format='svg') +g.layout("dot") +g.draw("git_flow.svg", format="svg") diff --git a/doc/index.rst b/doc/index.rst index 4332eabc..bc1f5195 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -1,17 +1,11 @@ -.. raw:: html +======= +expyfun +======= -
+.. rst-class:: h4 font-weight-light my-4 -==================== -Expyfun |headphones| -==================== - -A high-precision auditory and visual stimulus delivery library for -psychoacoustics in Python. - -.. raw:: html - -
+ A high-precision auditory and visual stimulus delivery library for + psychoacoustics in Python. Purpose ------- @@ -49,3 +43,10 @@ Hardware support - Mouse responses - Cedrus response boxes - Joystick control / responses + +.. toctree:: + :hidden: + + getting_started.rst + python_reference.rst + auto_examples/index.rst diff --git a/doc/parallel_installation.rst b/doc/parallel_installation.rst index eb94bf7a..8b5be040 100644 --- a/doc/parallel_installation.rst +++ b/doc/parallel_installation.rst @@ -14,7 +14,7 @@ USB protocol itself, which is not designed for low-latency control. Instructions differ between Linux and Windows: -- |linux| Linux +Linux On Linux, you need ``pyparallel``:: $ pip install pyparallel @@ -28,7 +28,7 @@ Instructions differ between Linux and Windows: 5. ``$ ls /dev/parport*`` to get the parallel port address, e.g. ``'/dev/parport0'``, and set this as ``TRIGGER_ADDRESS`` in the config. -- |windows| Windows +Windows If you are on a modern Windows system (i.e., 64-bit), you'll need to: - Download the latest "binaries" archive from the `InpOut32 site`_ diff --git a/environment_test.yml b/environment_test.yml index 8ce8f254..688ad5db 100644 --- a/environment_test.yml +++ b/environment_test.yml @@ -1,20 +1,19 @@ name: test channels: -- conda-forge + - conda-forge dependencies: -- scipy -- matplotlib -- pandas -- h5py -- coverage -- setuptools -- mne-base -- numpydoc -- pytest -- pytest-cov -- pytest-timeout -- pillow -- joblib -- ffmpeg -- codecov + - scipy + - matplotlib + - pandas + - h5py + - coverage + - setuptools + - mne-base + - numpydoc + - pytest + - pytest-cov + - pytest-timeout + - pillow + - joblib + - ffmpeg<6 # Do pip separately diff --git a/examples/analysis/analysis_demo.py b/examples/analysis/analysis_demo.py index 768f8622..bd66f7b5 100644 --- a/examples/analysis/analysis_demo.py +++ b/examples/analysis/analysis_demo.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ============= Analysis demo @@ -12,9 +11,9 @@ # # License: BSD (3-clause) +import matplotlib.pyplot as plt import numpy as np import pandas as pd -import matplotlib.pyplot as plt import expyfun.analyze as ea @@ -26,7 +25,7 @@ a_prob = 0.9 b_prob = 0.6 f_prob = 0.2 -subjs = ['a', 'b', 'c', 'd', 'e'] +subjs = ["a", "b", "c", "d", "e"] a_hit = np.random.binomial(targets, a_prob, len(subjs)) b_hit = np.random.binomial(targets, b_prob, len(subjs)) a_fa = np.random.binomial(foils, f_prob, len(subjs)) @@ -35,25 +34,43 @@ b_miss = targets - b_hit a_cr = foils - a_fa b_cr = foils - b_fa -data = pd.DataFrame(dict(a_hit=a_hit, a_miss=a_miss, a_fa=a_fa, a_cr=a_cr, - b_hit=b_hit, b_miss=b_miss, b_fa=b_fa, b_cr=b_cr), - index=subjs) +data = pd.DataFrame( + dict( + a_hit=a_hit, + a_miss=a_miss, + a_fa=a_fa, + a_cr=a_cr, + b_hit=b_hit, + b_miss=b_miss, + b_fa=b_fa, + b_cr=b_cr, + ), + index=subjs, +) # calculate dprimes -a_dprime = ea.dprime(data[['a_hit', 'a_miss', 'a_fa', 'a_cr']]) -b_dprime = ea.dprime(data[['b_hit', 'b_miss', 'b_fa', 'b_cr']]) +a_dprime = ea.dprime(data[["a_hit", "a_miss", "a_fa", "a_cr"]]) +b_dprime = ea.dprime(data[["b_hit", "b_miss", "b_fa", "b_cr"]]) results = pd.DataFrame(dict(ctrl=a_dprime, test=b_dprime)) # plot -subplt, barplt = ea.barplot(results, axis=0, err_bars='sd', lines=True, - brackets=[(0, 1)], bracket_text=[r'$p < 10^{-9}$']) -subplt.yaxis.set_label_text('d-prime +/- 1 s.d.') -subplt.set_title('Each line represents a different subject') +subplt, barplt = ea.barplot( + results, + axis=0, + err_bars="sd", + lines=True, + brackets=[(0, 1)], + bracket_text=[r"$p < 10^{-9}$"], +) +subplt.yaxis.set_label_text("d-prime +/- 1 s.d.") +subplt.set_title("Each line represents a different subject") # significance brackets example trials_per_cond = 100 -conds = ['ctrl', 'test'] -diffs = ['easy', 'hard'] -colnames = ['-'.join([x, y]) for x, y in zip(conds * 2, - np.tile(diffs, (2, 1)).T.ravel().tolist())] +conds = ["ctrl", "test"] +diffs = ["easy", "hard"] +colnames = [ + "-".join([x, y]) + for x, y in zip(conds * 2, np.tile(diffs, (2, 1)).T.ravel().tolist()) +] cond_prob = [0.9, 0.8] diff_prob = [0.9, 0.7] cond_block = np.tile(np.atleast_2d(cond_prob).T, (2, len(subjs))).T @@ -62,19 +79,26 @@ shape = (len(subjs), len(conds) * len(diffs)) rawscores_targ = np.random.binomial(trials_per_cond, probs, shape) rawscores_foil = np.random.binomial(trials_per_cond, probs, shape) -hmfc = np.c_[rawscores_targ.ravel(), - (trials_per_cond - rawscores_targ).ravel(), - (trials_per_cond - rawscores_foil).ravel(), - rawscores_foil.ravel()] +hmfc = np.c_[ + rawscores_targ.ravel(), + (trials_per_cond - rawscores_targ).ravel(), + (trials_per_cond - rawscores_foil).ravel(), + rawscores_foil.ravel(), +] dprimes = ea.dprime(hmfc).reshape(shape) results = pd.DataFrame(dprimes, index=subjs, columns=colnames) -subplt, barplt = ea.barplot(results, axis=0, err_bars='sd', lines=True, - groups=[(0, 1), (2, 3)], group_names=diffs, - bar_names=conds * 2, bracket_group_lines=True, - brackets=[(0, 1), (2, 3), (0, 2), (1, 3), - ([0, 1], 3)], # [2, 3] - bracket_text=['foo', 'bar', 'baz', 'snafu', - 'foobar']) -subplt.yaxis.set_label_text('d-prime +/- 1 s.d.') -subplt.set_title('Each line represents a different subject') +subplt, barplt = ea.barplot( + results, + axis=0, + err_bars="sd", + lines=True, + groups=[(0, 1), (2, 3)], + group_names=diffs, + bar_names=conds * 2, + bracket_group_lines=True, + brackets=[(0, 1), (2, 3), (0, 2), (1, 3), ([0, 1], 3)], # [2, 3] + bracket_text=["foo", "bar", "baz", "snafu", "foobar"], +) +subplt.yaxis.set_label_text("d-prime +/- 1 s.d.") +subplt.set_title("Each line represents a different subject") plt.show() diff --git a/examples/analysis/parse_demo.py b/examples/analysis/parse_demo.py index 1470b8ac..119e2e88 100644 --- a/examples/analysis/parse_demo.py +++ b/examples/analysis/parse_demo.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ============ Parsing demo @@ -17,13 +16,13 @@ print(__doc__) -data = read_tab('sample.tab') # from simple_experiment -print('Number of trials: %s' % len(data)) +data = read_tab("sample.tab") # from simple_experiment +print("Number of trials: %s" % len(data)) keys = list(data[0].keys()) -print('Data keys: %s\n' % keys) +print("Data keys: %s\n" % keys) for di, d in enumerate(data): - if d['trial_id'][0][0] == 'multi-tone': - print('Trial %s multi-tone' % (di + 1)) - targs = ast.literal_eval(d['multi-tone trial'][0][0]) - presses = [int(k[0]) for k in d['keypress']] - print(' Targs: %s\n Press: %s' % (targs, presses)) + if d["trial_id"][0][0] == "multi-tone": + print("Trial %s multi-tone" % (di + 1)) + targs = ast.literal_eval(d["multi-tone trial"][0][0]) + presses = [int(k[0]) for k in d["keypress"]] + print(" Targs: %s\n Press: %s" % (targs, presses)) diff --git a/examples/basic_experiment.py b/examples/basic_experiment.py index 1c2d549b..e16243ef 100644 --- a/examples/basic_experiment.py +++ b/examples/basic_experiment.py @@ -18,26 +18,26 @@ print(__doc__) # set configuration -fs = 24414. # default for ExperimentController +fs = 24414.0 # default for ExperimentController dur = 1.0 tone = np.sin(2 * np.pi * 1000 * np.arange(int(fs * dur)) / float(fs)) tone *= 0.01 * np.sqrt(2) # Set RMS to 0.01 -max_wait = 1. if not building_doc else 0. +max_wait = 1.0 if not building_doc else 0.0 -with ExperimentController('testExp', participant='foo', session='001', - output_dir=None, version='dev') as ec: - ec.screen_prompt('Press a button when you hear the tone', - max_wait=max_wait) +with ExperimentController( + "testExp", participant="foo", session="001", output_dir=None, version="dev" +) as ec: + ec.screen_prompt("Press a button when you hear the tone", max_wait=max_wait) dot = FixationDot(ec) ec.load_buffer(tone) dot.draw() screenshot = ec.screenshot() # only because we want to show it in the docs - ec.identify_trial(ec_id='tone', ttl_id=[0, 0]) + ec.identify_trial(ec_id="tone", ttl_id=[0, 0]) ec.start_stimulus() - presses = ec.wait_for_presses(dur if not building_doc else 0.) + presses = ec.wait_for_presses(dur if not building_doc else 0.0) ec.trial_ok() - print('Presses:\n{}'.format(presses)) + print(f"Presses:\n{presses}") analyze.plot_screen(screenshot) diff --git a/examples/experiments/drawing_methods.py b/examples/experiments/drawing_methods.py index dbe85af1..4265b8a9 100644 --- a/examples/experiments/drawing_methods.py +++ b/examples/experiments/drawing_methods.py @@ -11,16 +11,22 @@ import numpy as np -from expyfun import visual, ExperimentController import expyfun.analyze as ea +from expyfun import ExperimentController, visual print(__doc__) -with ExperimentController('test', session='1', participant='2', - full_screen=False, window_size=[600, 600], - output_dir=None, version='dev') as ec: - ec.screen_text('hello') +with ExperimentController( + "test", + session="1", + participant="2", + full_screen=False, + window_size=[600, 600], + output_dir=None, + version="dev", +) as ec: + ec.screen_text("hello") # make an image with alpha the x-dimension (columns), RGB upward img_buffer = np.zeros((120, 100, 4)) @@ -28,17 +34,29 @@ img_buffer[:, 50:, 3] = 0.5 img_buffer[0] = 1 for ii in range(3): - img_buffer[ii * 40:(ii + 1) * 40, :, ii] = 1.0 - img = visual.RawImage(ec, img_buffer, scale=2.) + img_buffer[ii * 40 : (ii + 1) * 40, :, ii] = 1.0 + img = visual.RawImage(ec, img_buffer, scale=2.0) # make a line, rectangle, diamond, and circle - line = visual.Line(ec, [[-2, 2, 2, -2], [-2, 2, -2, -2]], units='deg', - line_color='w', line_width=2.0) - rect = visual.Rectangle(ec, [0, 0, 2, 2], units='deg', fill_color='y') - diamond = visual.Diamond(ec, [0, 0, 4, 4], units='deg', fill_color=None, - line_color='gray', line_width=2.0) - circle = visual.Circle(ec, 1, units='deg', line_color='w', fill_color='k', - line_width=2.0) + line = visual.Line( + ec, + [[-2, 2, 2, -2], [-2, 2, -2, -2]], + units="deg", + line_color="w", + line_width=2.0, + ) + rect = visual.Rectangle(ec, [0, 0, 2, 2], units="deg", fill_color="y") + diamond = visual.Diamond( + ec, + [0, 0, 4, 4], + units="deg", + fill_color=None, + line_color="gray", + line_width=2.0, + ) + circle = visual.Circle( + ec, 1, units="deg", line_color="w", fill_color="k", line_width=2.0 + ) # do the drawing, then flip for obj in [img, line, rect, diamond, circle]: diff --git a/examples/experiments/eyetracking_experiment_.py b/examples/experiments/eyetracking_experiment_.py index e1270600..3e2ff6d1 100644 --- a/examples/experiments/eyetracking_experiment_.py +++ b/examples/experiments/eyetracking_experiment_.py @@ -12,53 +12,70 @@ import numpy as np -from expyfun import ExperimentController, EyelinkController, visual import expyfun.analyze as ea +from expyfun import ExperimentController, EyelinkController, visual print(__doc__) -with ExperimentController('testExp', full_screen=True, participant='foo', - session='001', output_dir=None, version='dev') as ec: +with ExperimentController( + "testExp", + full_screen=True, + participant="foo", + session="001", + output_dir=None, + version="dev", +) as ec: el = EyelinkController(ec) - ec.screen_prompt('Welcome to the experiment!\n\nFirst, we will ' - 'perform a screen calibration.\n\nPress a button ' - 'to continue.') + ec.screen_prompt( + "Welcome to the experiment!\n\nFirst, we will " + "perform a screen calibration.\n\nPress a button " + "to continue." + ) el.calibrate() # by default this starts recording EyeLink data - ec.screen_prompt('Excellent! Now, follow the red circle around the edge ' - 'of the big white circle.\n\nPress a button to ' - 'continue') + ec.screen_prompt( + "Excellent! Now, follow the red circle around the edge " + "of the big white circle.\n\nPress a button to " + "continue" + ) # make some circles to be drawn radius = 7.5 # degrees targ_rad = 0.2 # degrees - theta = np.linspace(np.pi / 2., 2.5 * np.pi, 200) + theta = np.linspace(np.pi / 2.0, 2.5 * np.pi, 200) x_pos, y_pos = radius * np.cos(theta), radius * np.sin(theta) - big_circ = visual.Circle(ec, radius, (0, 0), units='deg', - fill_color=None, line_color='white', - line_width=3.0) - targ_circ = visual.Circle(ec, targ_rad, (x_pos[0], y_pos[0]), - units='deg', fill_color='red') + big_circ = visual.Circle( + ec, + radius, + (0, 0), + units="deg", + fill_color=None, + line_color="white", + line_width=3.0, + ) + targ_circ = visual.Circle( + ec, targ_rad, (x_pos[0], y_pos[0]), units="deg", fill_color="red" + ) fix_pos = (x_pos[0], y_pos[0]) # start out by waiting for a 1 sec fixation at the start big_circ.draw() targ_circ.draw() screenshot = ec.screenshot() - ec.identify_trial(ec_id='Circle', ttl_id=[0], el_id=[0]) + ec.identify_trial(ec_id="Circle", ttl_id=[0], el_id=[0]) ec.start_stimulus() # automatically stamps to EL - if not el.wait_for_fix(fix_pos, 1., max_wait=5., units='deg'): - print('Initial fixation failed') + if not el.wait_for_fix(fix_pos, 1.0, max_wait=5.0, units="deg"): + print("Initial fixation failed") for ii, (x, y) in enumerate(zip(x_pos[1:], y_pos[1:])): - targ_circ.set_pos((x, y), units='deg') + targ_circ.set_pos((x, y), units="deg") big_circ.draw() targ_circ.draw() ec.flip() - if not el.wait_for_fix([x, y], max_wait=5., units='deg'): - print('Fixation {0} failed'.format(ii + 1)) + if not el.wait_for_fix([x, y], max_wait=5.0, units="deg"): + print(f"Fixation {ii + 1} failed") ec.trial_ok() el.stop() # stop recording to save the file - ec.screen_prompt('All done!', max_wait=1.0) + ec.screen_prompt("All done!", max_wait=1.0) # eyelink auto-closes (el.close()) because it gets registered with EC ea.plot_screen(screenshot) diff --git a/examples/experiments/formatted_text.py b/examples/experiments/formatted_text.py index 4d3d7f66..783733c8 100644 --- a/examples/experiments/formatted_text.py +++ b/examples/experiments/formatted_text.py @@ -15,34 +15,42 @@ print(__doc__) # Colors -blue = _convert_color('#00CEE9') -pink = _convert_color('#FF97AF') +blue = _convert_color("#00CEE9") +pink = _convert_color("#FF97AF") white = (255, 255, 255, 255) # Text -one = ('This text can only have a single color, font, and size for the whole ' - 'sentence, because it is specified as attr=False') -two = ('Additional calls to ec.screen_text() can have different formatting,' - 'but have to be manually positioned.') -thr = ('This text can have {{color {0}}}different {{color {1}}}colors ' - 'speci{{color {2}}}fied inline, because its {{color {0}}}attr ' - '{{color {2}}}argument is {{color {1}}}True. {{color {2}}}' - 'Specifying different typefaces or sizes inline is buggy and ' - 'not recommended.').format(blue, pink, white) -fou = 'Press any key to change all the text to pink using .set_color().' -fiv = 'Press any key to quit.' -max_wait = float('inf') if not building_doc else 0. +one = ( + "This text can only have a single color, font, and size for the whole " + "sentence, because it is specified as attr=False" +) +two = ( + "Additional calls to ec.screen_text() can have different formatting," + "but have to be manually positioned." +) +thr = ( + f"This text can have {{color {blue}}}different {{color {pink}}}colors " + f"speci{{color {white}}}fied inline, because its {{color {blue}}}attr " + f"{{color {white}}}argument is {{color {pink}}}True. {{color {white}}}" + "Specifying different typefaces or sizes inline is buggy and " + "not recommended." +) +fou = "Press any key to change all the text to pink using .set_color()." +fiv = "Press any key to quit." +max_wait = float("inf") if not building_doc else 0.0 -with ExperimentController('textDemo', participant='foo', session='001', - output_dir=None, version='dev') as ec: +with ExperimentController( + "textDemo", participant="foo", session="001", output_dir=None, version="dev" +) as ec: ec.wait_secs(0.1) # without this, first flip doesn't show on some systems txt_one = ec.screen_text(one, pos=[0, 0.5], attr=False) - txt_two = ec.screen_text(two, pos=[0, 0.2], font_name='Times New Roman', - font_size=32, color='#00CEE9') + txt_two = ec.screen_text( + two, pos=[0, 0.2], font_name="Times New Roman", font_size=32, color="#00CEE9" + ) txt_thr = ec.screen_text(thr, pos=[0, -0.2]) screenshot = ec.screenshot() ec.screen_prompt(fou, pos=[0, -0.5], max_wait=max_wait) for txt in (txt_one, txt_two, txt_thr): - txt.set_color('#FF97AF') + txt.set_color("#FF97AF") txt.draw() ec.screen_prompt(fiv, pos=[0, -0.5], max_wait=max_wait) diff --git a/examples/experiments/joystick_experiment.py b/examples/experiments/joystick_experiment.py index 4feb5176..40e1a2a0 100644 --- a/examples/experiments/joystick_experiment.py +++ b/examples/experiments/joystick_experiment.py @@ -21,26 +21,31 @@ noise_thresh = 0.01 # permit slight miscalibration # on a Logitech Cordless Rumblepad, the right stick is the analog one, # and it has values stored in z and rz -joy_keys = ('z', 'rz') -with ExperimentController('joyExp', participant='foo', session='001', - output_dir=None, version='dev', - joystick=joystick) as ec: - circles = [Circle(ec, 0.5, units='deg', - fill_color=(1., 1., 1., 0.2), line_color='w')] +joy_keys = ("z", "rz") +with ExperimentController( + "joyExp", + participant="foo", + session="001", + output_dir=None, + version="dev", + joystick=joystick, +) as ec: + circles = [ + Circle(ec, 0.5, units="deg", fill_color=(1.0, 1.0, 1.0, 0.2), line_color="w") + ] # We use normalized units for "pos" so we need to compensate in movement # so that X/Y movement is even - ratios = [1., ec.window_size_pix[0] / float(ec.window_size_pix[1])] - pressed = '' + ratios = [1.0, ec.window_size_pix[0] / float(ec.window_size_pix[1])] + pressed = "" if not building_doc: ec.listen_joystick_button_presses() count = 0 screenshot = None - pos = [0., 0.] - while pressed != '2': # enable a clean quit (button number 3) + pos = [0.0, 0.0] + while pressed != "2": # enable a clean quit (button number 3) ####################################################################### # Draw things - Text(ec, str(count), pos=(1, -1), - anchor_x='right', anchor_y='bottom').draw() + Text(ec, str(count), pos=(1, -1), anchor_x="right", anchor_y="bottom").draw() for circle in circles[::-1]: circle.draw() screenshot = ec.screenshot() if screenshot is None else screenshot @@ -52,21 +57,21 @@ pressed = ec.get_joystick_button_presses() ec.listen_joystick_button_presses() # clear events else: - pressed = [('2',)] + pressed = [("2",)] count += len(pressed) ####################################################################### # Move the cursor for idx, (key, ratio) in enumerate(zip(joy_keys, ratios)): - delta = 0. if building_doc else ec.get_joystick_value(key) + delta = 0.0 if building_doc else ec.get_joystick_value(key) if abs(delta) > noise_thresh: # remove noise - pos[idx] = max(min( - pos[idx] + move_rate * ratio * delta, 1), -1) - circles[0].set_pos(pos, units='norm') + pos[idx] = max(min(pos[idx] + move_rate * ratio * delta, 1), -1) + circles[0].set_pos(pos, units="norm") if pressed: - circles.insert(1, Circle(ec, 1, units='deg', - fill_color='r', line_color='w')) - circles[1].set_pos(pos, units='norm') + circles.insert( + 1, Circle(ec, 1, units="deg", fill_color="r", line_color="w") + ) + circles[1].set_pos(pos, units="norm") if len(circles) > 5: circles.pop(-1) pressed = pressed[0][0] # for exit condition diff --git a/examples/experiments/keypress.py b/examples/experiments/keypress.py index c0f657aa..c5b28d18 100644 --- a/examples/experiments/keypress.py +++ b/examples/experiments/keypress.py @@ -10,104 +10,119 @@ # # License: BSD (3-clause) -from expyfun import ExperimentController, building_doc import expyfun.analyze as ea +from expyfun import ExperimentController, building_doc print(__doc__) isi = 0.5 -wait_dur = 3.0 if not building_doc else 0. -msg_dur = 3.0 if not building_doc else 0. +wait_dur = 3.0 if not building_doc else 0.0 +msg_dur = 3.0 if not building_doc else 0.0 -with ExperimentController('KeypressDemo', screen_num=0, - window_size=[640, 480], full_screen=False, - stim_db=0, noise_db=0, output_dir=None, - participant='foo', session='001', - version='dev') as ec: +with ExperimentController( + "KeypressDemo", + screen_num=0, + window_size=[640, 480], + full_screen=False, + stim_db=0, + noise_db=0, + output_dir=None, + participant="foo", + session="001", + version="dev", +) as ec: ec.wait_secs(isi) ############### # screen_prompt - pressed = ec.screen_prompt('press any key\n\nscreen_prompt(' - 'max_wait={})'.format(wait_dur), - max_wait=wait_dur, timestamp=True) - ec.write_data_line('screen_prompt', pressed) + pressed = ec.screen_prompt( + "press any key\n\nscreen_prompt(" f"max_wait={wait_dur})", + max_wait=wait_dur, + timestamp=True, + ) + ec.write_data_line("screen_prompt", pressed) if pressed[0] is None: - message = 'no keys pressed' + message = "no keys pressed" else: - message = '{} pressed after {} secs'.format(pressed[0], - round(pressed[1], 4)) + message = f"{pressed[0]} pressed after {round(pressed[1], 4)} secs" ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) ################## # wait_for_presses - ec.screen_text('press some keys\n\nwait_for_presses(max_wait={})' - ''.format(wait_dur)) + ec.screen_text(f"press some keys\n\nwait_for_presses(max_wait={wait_dur})" "") screenshot = ec.screenshot() ec.flip() pressed = ec.wait_for_presses(wait_dur) - ec.write_data_line('wait_for_presses', pressed) + ec.write_data_line("wait_for_presses", pressed) if not len(pressed): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} pressed after {} secs\n' - ''.format(key, round(time, 4)) for key, time in pressed] - message = ''.join(message) + message = [ + f"{key} pressed after {round(time, 4)} secs\n" "" for key, time in pressed + ] + message = "".join(message) ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) ############################################ # wait_for_presses, relative to master clock - ec.screen_text('press some keys\n\nwait_for_presses(max_wait={}, ' - 'relative_to=0.0)'.format(wait_dur)) + ec.screen_text( + f"press some keys\n\nwait_for_presses(max_wait={wait_dur}, " "relative_to=0.0)" + ) ec.flip() pressed = ec.wait_for_presses(wait_dur, relative_to=0.0) - ec.write_data_line('wait_for_presses relative_to 0.0', pressed) + ec.write_data_line("wait_for_presses relative_to 0.0", pressed) if not len(pressed): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} pressed at {} secs\n' - ''.format(key, round(time, 4)) for key, time in pressed] - message = ''.join(message) + message = [ + f"{key} pressed at {round(time, 4)} secs\n" "" for key, time in pressed + ] + message = "".join(message) ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) ########################################## # listen_presses / wait_secs / get_presses - ec.screen_text('press some keys\n\nlisten_presses()\nwait_secs({0})' - '\nget_presses()'.format(wait_dur)) + ec.screen_text( + f"press some keys\n\nlisten_presses()\nwait_secs({wait_dur})" "\nget_presses()" + ) ec.flip() ec.listen_presses() ec.wait_secs(wait_dur) pressed = ec.get_presses() # relative_to=0.0 - ec.write_data_line('listen / wait / get_presses', pressed) + ec.write_data_line("listen / wait / get_presses", pressed) if not len(pressed): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} pressed after {} secs\n' - ''.format(key, round(time, 4)) for key, time in pressed] - message = ''.join(message) + message = [ + f"{key} pressed after {round(time, 4)} secs\n" "" for key, time in pressed + ] + message = "".join(message) ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) #################################################################### # listen_presses / wait_secs / get_presses, relative to master clock - ec.screen_text('press a few keys\n\nlisten_presses()' - '\nwait_secs({0})\nget_presses(relative_to=0.0)' - ''.format(wait_dur)) + ec.screen_text( + "press a few keys\n\nlisten_presses()" + f"\nwait_secs({wait_dur})\nget_presses(relative_to=0.0)" + "" + ) ec.flip() ec.listen_presses() ec.wait_secs(wait_dur) pressed = ec.get_presses(relative_to=0.0) - ec.write_data_line('listen / wait / get_presses relative_to 0.0', pressed) + ec.write_data_line("listen / wait / get_presses relative_to 0.0", pressed) if not len(pressed): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} pressed at {} secs\n' - ''.format(key, round(time, 4)) for key, time in pressed] - message = ''.join(message) + message = [ + f"{key} pressed at {round(time, 4)} secs\n" "" for key, time in pressed + ] + message = "".join(message) ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) @@ -116,25 +131,29 @@ disp_time = wait_dur countdown = ec.current_time + disp_time ec.call_on_next_flip(ec.listen_presses) - ec.screen_text('press some keys\n\nlisten_presses()' - '\nwhile loop {}\nget_presses()'.format(disp_time)) + ec.screen_text( + "press some keys\n\nlisten_presses()" f"\nwhile loop {disp_time}\nget_presses()" + ) ec.flip() while ec.current_time < countdown: cur_time = round(countdown - ec.current_time, 1) if cur_time != disp_time: disp_time = cur_time # redraw text with updated disp_time - ec.screen_text('press some keys\n\nlisten_presses() ' - '\nwhile loop {}\nget_presses()'.format(disp_time)) + ec.screen_text( + "press some keys\n\nlisten_presses() " + f"\nwhile loop {disp_time}\nget_presses()" + ) ec.flip() pressed = ec.get_presses() - ec.write_data_line('listen / while / get_presses', pressed) + ec.write_data_line("listen / while / get_presses", pressed) if not len(pressed): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} pressed after {} secs\n' - ''.format(key, round(time, 4)) for key, time in pressed] - message = ''.join(message) + message = [ + f"{key} pressed after {round(time, 4)} secs\n" "" for key, time in pressed + ] + message = "".join(message) ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) @@ -143,26 +162,31 @@ disp_time = wait_dur countdown = ec.current_time + disp_time ec.call_on_next_flip(ec.listen_presses) - ec.screen_text('press some keys\n\nlisten_presses()\nwhile loop ' - '{}\nget_presses(relative_to=0.0)'.format(disp_time)) + ec.screen_text( + "press some keys\n\nlisten_presses()\nwhile loop " + f"{disp_time}\nget_presses(relative_to=0.0)" + ) ec.flip() while ec.current_time < countdown: cur_time = round(countdown - ec.current_time, 1) if cur_time != disp_time: disp_time = cur_time # redraw text with updated disp_time - ec.screen_text('press some keys\n\nlisten_presses()\nwhile ' - 'loop {}\nget_presses(relative_to=0.0)' - ''.format(disp_time)) + ec.screen_text( + "press some keys\n\nlisten_presses()\nwhile " + f"loop {disp_time}\nget_presses(relative_to=0.0)" + "" + ) ec.flip() pressed = ec.get_presses(relative_to=0.0) - ec.write_data_line('listen / while / get_presses relative_to 0.0', pressed) + ec.write_data_line("listen / while / get_presses relative_to 0.0", pressed) if not len(pressed): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} pressed at {} secs\n' - ''.format(key, round(time, 4)) for key, time in pressed] - message = ''.join(message) + message = [ + f"{key} pressed at {round(time, 4)} secs\n" "" for key, time in pressed + ] + message = "".join(message) ec.screen_prompt(message, msg_dur) ea.plot_screen(screenshot) diff --git a/examples/experiments/keyrelease.py b/examples/experiments/keyrelease.py index 88033e8a..2c9be464 100644 --- a/examples/experiments/keyrelease.py +++ b/examples/experiments/keyrelease.py @@ -20,26 +20,37 @@ # # License: BSD (3-clause) -from expyfun import ExperimentController, building_doc, analyze as ea +from expyfun import ExperimentController, building_doc +from expyfun import analyze as ea print(__doc__) isi = 0.5 -wait_dur = 3.0 if not building_doc else 0. -msg_dur = 3.0 if not building_doc else 0. +wait_dur = 3.0 if not building_doc else 0.0 +msg_dur = 3.0 if not building_doc else 0.0 -with ExperimentController('KeyPressAndReleaseDemo', screen_num=0, - window_size=[1280, 960], full_screen=False, - stim_db=0, noise_db=0, output_dir=None, - participant='foo', session='001', - version='dev', response_device='keyboard') as ec: +with ExperimentController( + "KeyPressAndReleaseDemo", + screen_num=0, + window_size=[1280, 960], + full_screen=False, + stim_db=0, + noise_db=0, + output_dir=None, + participant="foo", + session="001", + version="dev", + response_device="keyboard", +) as ec: ec.wait_secs(isi) ########################################### # listen_presses / while loop / get_presses(kind='both') - instruction = ("Press and release some keys\n\nlisten_presses()" - "\nwhile loop {}\n" - "get_presses(kind='both', return_kinds=True)") + instruction = ( + "Press and release some keys\n\nlisten_presses()" + "\nwhile loop {}\n" + "get_presses(kind='both', return_kinds=True)" + ) disp_time = wait_dur countdown = ec.current_time + disp_time ec.call_on_next_flip(ec.listen_presses) @@ -53,14 +64,13 @@ # redraw text with updated disp_time ec.screen_text(instruction.format(disp_time)) ec.flip() - events = ec.get_presses(kind='both', return_kinds=True) - ec.write_data_line('listen / while / get_presses', events) + events = ec.get_presses(kind="both", return_kinds=True) + ec.write_data_line("listen / while / get_presses", events) if not len(events): - message = 'no keys pressed' + message = "no keys pressed" else: - message = ['{} {} after {} secs\n' - ''.format(k, r, round(t, 4)) for k, t, r in events] - message = ''.join(message) + message = [f"{k} {r} after {round(t, 4)} secs\n" "" for k, t, r in events] + message = "".join(message) ec.screen_prompt(message, msg_dur) ec.wait_secs(isi) diff --git a/examples/experiments/level_test.py b/examples/experiments/level_test.py index 1f035878..e3cb616f 100644 --- a/examples/experiments/level_test.py +++ b/examples/experiments/level_test.py @@ -16,35 +16,47 @@ import numpy as np +import expyfun.analyze as ea from expyfun import ExperimentController, building_doc from expyfun.visual import Rectangle -import expyfun.analyze as ea print(__doc__) -with ExperimentController('LevelTest', full_screen=True, noise_db=-np.inf, - participant='s', session='0', output_dir=None, - suppress_resamp=True, check_rms=None, - stim_db=80, version='dev') as ec: - tone = (0.01 * np.sqrt(2.) * - np.sin(2 * np.pi * 1000. * np.arange(0, 10, 1. / ec.fs))) +with ExperimentController( + "LevelTest", + full_screen=True, + noise_db=-np.inf, + participant="s", + session="0", + output_dir=None, + suppress_resamp=True, + check_rms=None, + stim_db=80, + version="dev", +) as ec: + tone = ( + 0.01 * np.sqrt(2.0) * np.sin(2 * np.pi * 1000.0 * np.arange(0, 10, 1.0 / ec.fs)) + ) assert np.allclose(np.sqrt(np.mean(tone * tone)), 0.01) - square = Rectangle(ec, (0, 0, 10, 10), units='deg', fill_color='r') - cm = np.diff(ec._convert_units([[0, 5], [0, 5]], 'deg', 'pix'), - axis=-1)[0] / ec.dpi / 0.39370 + square = Rectangle(ec, (0, 0, 10, 10), units="deg", fill_color="r") + cm = ( + np.diff(ec._convert_units([[0, 5], [0, 5]], "deg", "pix"), axis=-1)[0] + / ec.dpi + / 0.39370 + ) ec.load_buffer(tone) # RMS == 0.01 pressed = None screenshot = None - while pressed != '8': # enable a clean quit if required + while pressed != "8": # enable a clean quit if required square.draw() - ec.screen_text('Width: {} cm'.format(np.round(2 * cm, 1)), wrap=False) - ec.screen_text('Output level: {} dB'.format(ec.stim_db), wrap=True) + ec.screen_text(f"Width: {np.round(2 * cm, 1)} cm", wrap=False) + ec.screen_text(f"Output level: {ec.stim_db} dB", wrap=True) screenshot = ec.screenshot() if screenshot is None else screenshot t1 = ec.start_stimulus(start_of_trial=False) # skip checks - pressed = ec.wait_one_press(10)[0] if not building_doc else '8' + pressed = ec.wait_one_press(10)[0] if not building_doc else "8" ec.flip() - ec.wait_one_press(0.5 if not building_doc else 0.) + ec.wait_one_press(0.5 if not building_doc else 0.0) ec.stop() ea.plot_screen(screenshot) diff --git a/examples/experiments/mouse.py b/examples/experiments/mouse.py index 619bef95..b5a6e194 100644 --- a/examples/experiments/mouse.py +++ b/examples/experiments/mouse.py @@ -10,65 +10,73 @@ # # License: BSD (3-clause) -from expyfun import ExperimentController, building_doc import expyfun.analyze as ea -from expyfun.visual import (Circle, Rectangle, Diamond, ConcentricCircles, - FixationDot) +from expyfun import ExperimentController, building_doc +from expyfun.visual import Circle, ConcentricCircles, Diamond, FixationDot, Rectangle print(__doc__) -wait_dur = 3.0 if not building_doc else 0. -msg_dur = 1.5 if not building_doc else 0. -max_wait = float('inf') if not building_doc else 0. +wait_dur = 3.0 if not building_doc else 0.0 +msg_dur = 1.5 if not building_doc else 0.0 +max_wait = float("inf") if not building_doc else 0.0 -with ExperimentController('MouseDemo', screen_num=0, - window_size=[640, 480], full_screen=False, - stim_db=0, noise_db=0, output_dir=None, - participant='foo', session='001', - version='dev') as ec: +with ExperimentController( + "MouseDemo", + screen_num=0, + window_size=[640, 480], + full_screen=False, + stim_db=0, + noise_db=0, + output_dir=None, + participant="foo", + session="001", + version="dev", +) as ec: ################################# # toggle_cursor and move_mouse_to ec.toggle_cursor(True) ec.move_mouse_to((0, 0)) - ec.screen_prompt('Now you see it (centered on the window).', - max_wait=msg_dur, wrap=False) + ec.screen_prompt( + "Now you see it (centered on the window).", max_wait=msg_dur, wrap=False + ) ec.toggle_cursor(False) - ec.screen_prompt("Now you don't (maybe--Windows is buggy)", - max_wait=msg_dur, wrap=False) + ec.screen_prompt( + "Now you don't (maybe--Windows is buggy)", max_wait=msg_dur, wrap=False + ) ec.toggle_cursor(True) ################ # wait_one_click - ec.screen_text('Press any mouse button.', wrap=False) + ec.screen_text("Press any mouse button.", wrap=False) ec.flip() ec.wait_one_click(max_wait=max_wait) ec.toggle_cursor(False) - ec.screen_text('Press the left button.', wrap=False) + ec.screen_text("Press the left button.", wrap=False) ec.flip() - ec.wait_one_click(live_buttons=['left'], visible=True, max_wait=max_wait) + ec.wait_one_click(live_buttons=["left"], visible=True, max_wait=max_wait) ec.wait_secs(0.5) ec.toggle_cursor(True) ########################### # listen_clicks, get_clicks - ec.screen_text('Press a few buttons in a row.', wrap=False) + ec.screen_text("Press a few buttons in a row.", wrap=False) ec.flip() ec.listen_clicks() ec.wait_secs(wait_dur) clicks = ec.get_clicks() - ec.screen_prompt('Your clicks:\n%s' % str(clicks), max_wait=msg_dur) + ec.screen_prompt("Your clicks:\n%s" % str(clicks), max_wait=msg_dur) #################### # get_mouse_position - ec.screen_prompt('Move the mouse around...', max_wait=msg_dur, wrap=False) + ec.screen_prompt("Move the mouse around...", max_wait=msg_dur, wrap=False) stop_time = ec.current_time + wait_dur while ec.current_time < stop_time: - ec.screen_text('%i, %i' % tuple([p for p in - ec.get_mouse_position()]), - wrap=False) + ec.screen_text( + "%i, %i" % tuple([p for p in ec.get_mouse_position()]), wrap=False + ) ec.check_force_quit() ec.flip() @@ -76,15 +84,16 @@ # wait_for_click_on ec.toggle_cursor(False) ec.wait_secs(1) - c = Circle(ec, 150, units='pix') - r = Rectangle(ec, (0.5, 0.5, 0.2, 0.2), units='norm', fill_color='r') - cc = ConcentricCircles(ec, pos=[0.6, -0.4], - colors=[[0.2, 0.2, 0.2], [0.6, 0.6, 0.6]]) - d = Diamond(ec, (-0.5, 0.5, 0.4, 0.25), fill_color='b') + c = Circle(ec, 150, units="pix") + r = Rectangle(ec, (0.5, 0.5, 0.2, 0.2), units="norm", fill_color="r") + cc = ConcentricCircles( + ec, pos=[0.6, -0.4], colors=[[0.2, 0.2, 0.2], [0.6, 0.6, 0.6]] + ) + d = Diamond(ec, (-0.5, 0.5, 0.4, 0.25), fill_color="b") dot = FixationDot(ec) objects = [c, r, cc, d, dot] - ec.screen_prompt('Click on some objects...', max_wait=msg_dur, wrap=False) + ec.screen_prompt("Click on some objects...", max_wait=msg_dur, wrap=False) for ti in range(3): for o in objects: o.draw() diff --git a/examples/experiments/progress_bar.py b/examples/experiments/progress_bar.py index a1e0c3f2..ec6c62c0 100644 --- a/examples/experiments/progress_bar.py +++ b/examples/experiments/progress_bar.py @@ -1,5 +1,4 @@ #!/usr/bin/env python2 -# -*- coding: utf-8 -*- """ ================ ProgressBar demo @@ -8,24 +7,34 @@ This example shows how to display progress between trials using :class:`expyfun.visual.ProgressBar`. """ + +import numpy as np + +import expyfun.analyze as ea from expyfun import ExperimentController, building_doc from expyfun.visual import ProgressBar -import expyfun.analyze as ea -import numpy as np n_trials = 6 max_wait = 0.1 if building_doc else np.inf wait_dur = 0.1 if building_doc else 0.5 -with ExperimentController('name', version='dev', window_size=[800, 600], - full_screen=False, session='foo', - participant='foo') as ec: - +with ExperimentController( + "name", + version="dev", + window_size=[800, 600], + full_screen=False, + session="foo", + participant="foo", +) as ec: # initialize the progress bar - pb = ProgressBar(ec, [0, -.1, 1.5, .1], units='norm') + pb = ProgressBar(ec, [0, -0.1, 1.5, 0.1], units="norm") - ec.screen_prompt('Press the number shown on the screen. Start by pressing' - ' 1.', font_size=16, live_keys=[1], max_wait=max_wait) + ec.screen_prompt( + "Press the number shown on the screen. Start by pressing" " 1.", + font_size=16, + live_keys=[1], + max_wait=max_wait, + ) for n in np.arange(n_trials) + 1: # subject does some task @@ -41,16 +50,19 @@ percent = int(n * 100 / n_trials) pb.update_bar(percent) # display the progress bar with some text - ec.screen_text('You\'ve completed {} %. Press any key to proceed.' - ''.format(percent), [0, .1], wrap=False, - font_size=16) + ec.screen_text( + f"You've completed {percent} %. Press any key to proceed." "", + [0, 0.1], + wrap=False, + font_size=16, + ) pb.draw() if n == 4: screenshot = ec.screenshot() ec.flip() # subject uses any key press to proceed ec.wait_one_press(max_wait=max_wait) - ec.screen_text('This example is complete.') + ec.screen_text("This example is complete.") ec.flip() ec.wait_secs(1) diff --git a/examples/experiments/pupillometry_experiment_.py b/examples/experiments/pupillometry_experiment_.py index 0d3c09e5..0ba3dc91 100644 --- a/examples/experiments/pupillometry_experiment_.py +++ b/examples/experiments/pupillometry_experiment_.py @@ -10,43 +10,50 @@ # # License: BSD (3-clause) -import numpy as np import matplotlib.pyplot as plt +import numpy as np from expyfun import ExperimentController, EyelinkController -from expyfun.codeblocks import (find_pupil_dynamic_range, - find_pupil_tone_impulse_response) +from expyfun.codeblocks import ( + find_pupil_dynamic_range, + find_pupil_tone_impulse_response, +) print(__doc__) -with ExperimentController('pupilExp', full_screen=True, participant='foo', - session='001', output_dir=None, version='dev') as ec: +with ExperimentController( + "pupilExp", + full_screen=True, + participant="foo", + session="001", + output_dir=None, + version="dev", +) as ec: el = EyelinkController(ec) bgcolor, fcolor, lev, resp = find_pupil_dynamic_range(ec, el) - prf, t_srf, e_prf = find_pupil_tone_impulse_response(ec, el, bgcolor, - fcolor) + prf, t_srf, e_prf = find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor) uni_lev = np.unique(lev) uni_lev_label = (255 * uni_lev).astype(int) -uni_lev[uni_lev == 0] = np.sort(uni_lev)[1] / 2. +uni_lev[uni_lev == 0] = np.sort(uni_lev)[1] / 2.0 r = resp.reshape((len(lev) // len(uni_lev), len(uni_lev))) r_span = [r.min(), r.max()] # Grayscale responses -ax = plt.subplot(2, 1, 1, xlabel='Screen level', ylabel='Pupil dilation (AU)') -ax.plot([bgcolor, bgcolor], r_span, linestyle='--', color='r') -ax.fill_between(uni_lev, np.min(r, 0), np.max(r, 0), facecolor=(1, 1, 0), - edgecolor='none') -ax.semilogx(uni_lev, np.mean(r, 0), color='k') +ax = plt.subplot(2, 1, 1, xlabel="Screen level", ylabel="Pupil dilation (AU)") +ax.plot([bgcolor, bgcolor], r_span, linestyle="--", color="r") +ax.fill_between( + uni_lev, np.min(r, 0), np.max(r, 0), facecolor=(1, 1, 0), edgecolor="none" +) +ax.semilogx(uni_lev, np.mean(r, 0), color="k") ax.set_xlim(uni_lev[[0, -1]]) ax.set_ylim(r_span) plt.xticks(uni_lev, uni_lev_label) # PRF -ax = plt.subplot(2, 1, 2, xlabel='Time (s)', ylabel='Pupil response (AU)') -ax.fill_between(t_srf, prf - e_prf, prf + e_prf, facecolor=(1, 1, 0), - edgecolor='none') -ax.plot(t_srf, prf, color='k') +ax = plt.subplot(2, 1, 2, xlabel="Time (s)", ylabel="Pupil response (AU)") +ax.fill_between(t_srf, prf - e_prf, prf + e_prf, facecolor=(1, 1, 0), edgecolor="none") +ax.plot(t_srf, prf, color="k") ax.set_xlim(t_srf[[0, -1]]) plt.tight_layout() plt.show() diff --git a/examples/experiments/tracker_dealer.py b/examples/experiments/tracker_dealer.py index 174a143a..a64f3189 100644 --- a/examples/experiments/tracker_dealer.py +++ b/examples/experiments/tracker_dealer.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ========================================================================== Adaptive tracking for two trial types and tracker reconstruction from .tab @@ -15,12 +14,13 @@ @author: maddycapp27 """ +import matplotlib.pyplot as plt import numpy as np + from expyfun import ExperimentController -from expyfun.stimuli import TrackerUD, TrackerDealer from expyfun.analyze import sigmoid from expyfun.io import reconstruct_dealer -import matplotlib.pyplot as plt +from expyfun.stimuli import TrackerDealer, TrackerUD # define parameters of modeled subject (using sigmoid probability) true_thresh = [30, 40] # true thresholds for trial types 1 and 2 @@ -45,13 +45,13 @@ stop_trials = np.inf start_value = 45 change_indices = [5] -change_rule = 'reversals' +change_rule = "reversals" x_min = 0 x_max = 90 # parameters for the tracker dealer max_lag = 2 -pace_rule = 'reversals' +pace_rule = "reversals" rng_dealer = np.random.RandomState(3) # random seed to select trial type ############################################################################### @@ -63,18 +63,37 @@ # for that trial can be acquired. :class:`expyfun.ExperimentController` is used # to generate log files with :class:`expyfun.stimuli.TrackerUD` and # :class:`expyfun.stimuli.TrackerDealer` information. -std_args = ['test'] # experiment name -std_kwargs = dict(full_screen=False, window_size=(1, 1), participant='foo', - session='01', stim_db=0.0, noise_db=0.0, verbose=True, - version='dev') +std_args = ["test"] # experiment name +std_kwargs = dict( + full_screen=False, + window_size=(1, 1), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + verbose=True, + version="dev", +) with ExperimentController(*std_args, **std_kwargs) as ec: - # initialize two tracker objects--one for each trial type - tr_ud = [TrackerUD(ec, up, down, step_size_up, step_size_down, - stop_reversals, stop_trials, start_value, - change_indices, change_rule, x_min, - x_max) for _ in range(2)] + tr_ud = [ + TrackerUD( + ec, + up, + down, + step_size_up, + step_size_down, + stop_reversals, + stop_trials, + start_value, + change_indices, + change_rule, + x_min, + x_max, + ) + for _ in range(2) + ] # initialize TrackerDealer object td = TrackerDealer(ec, tr_ud, max_lag, pace_rule, rng_dealer) @@ -85,8 +104,10 @@ for ss, level in td: # Get information of which trial type is next and what the level is at # that time from TrackerDealer - td.respond(rng_human.rand() < sigmoid(level - true_thresh[sum(ss)], - lower=chance, slope=slope)) + td.respond( + rng_human.rand() + < sigmoid(level - true_thresh[sum(ss)], lower=chance, slope=slope) + ) ############################################################################### # Reconstructing the TrackerDealer Object @@ -107,7 +128,9 @@ for i in [0, 1]: fig, ax, lines = td_tab.trackers.ravel()[i].plot(ax=axes[i], n_skip=4) - ax.legend(loc='best') - ax.set_title('Adaptive track of model human trial type {} (true threshold ' - 'is {})'.format(i + 1, true_thresh[i])) + ax.legend(loc="best") + ax.set_title( + f"Adaptive track of model human trial type {i + 1} (true threshold " + f"is {true_thresh[i]})" + ) fig.tight_layout() diff --git a/examples/experiments/tracker_dealer_doublesided.py b/examples/experiments/tracker_dealer_doublesided.py index aed5a7a9..f3f04b10 100644 --- a/examples/experiments/tracker_dealer_doublesided.py +++ b/examples/experiments/tracker_dealer_doublesided.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ====================================== Adaptive tracking from above and below @@ -11,10 +10,11 @@ @author: maddycapp27 """ +import matplotlib.pyplot as plt import numpy as np -from expyfun.stimuli import TrackerUD, TrackerDealer + from expyfun.analyze import sigmoid -import matplotlib.pyplot as plt +from expyfun.stimuli import TrackerDealer, TrackerUD # define parameters of modeled subject (using sigmoid probability) true_thresh = 30 # true thresholds for trial types 1 and 2 @@ -40,19 +40,19 @@ stop_trials = np.inf start_value = [15, 45] change_indices = [5] -change_rule = 'reversals' +change_rule = "reversals" x_min = 0 x_max = 90 # callback function that prints to console def callback(event_type, value=None, timestamp=None): - print((str(event_type) + ':').ljust(40) + str(value)) + print((str(event_type) + ":").ljust(40) + str(value)) # parameters for the tracker dealer max_lag = 2 -pace_rule = 'reversals' +pace_rule = "reversals" rng_dealer = np.random.RandomState(4) # random seed for selecting trial type ############################################################################## @@ -65,9 +65,23 @@ def callback(event_type, value=None, timestamp=None): # acquired. # initialize two tracker objects--one for each start value -tr_ud = [TrackerUD(callback, up, down, step_size_up, step_size_down, - stop_reversals, stop_trials, sv, change_indices, - change_rule, x_min, x_max) for sv in start_value] +tr_ud = [ + TrackerUD( + callback, + up, + down, + step_size_up, + step_size_down, + stop_reversals, + stop_trials, + sv, + change_indices, + change_rule, + x_min, + x_max, + ) + for sv in start_value +] # initialize TrackerDealer object td = TrackerDealer(callback, tr_ud, max_lag, pace_rule, rng_dealer) @@ -78,8 +92,9 @@ def callback(event_type, value=None, timestamp=None): for _, level in td: # Get information of which trial type is next and what the level is at # that time from TrackerDealer - td.respond(rng_human.rand() < sigmoid(level - true_thresh, - lower=chance, slope=slope)) + td.respond( + rng_human.rand() < sigmoid(level - true_thresh, lower=chance, slope=slope) + ) ############################################################################## # Plotting the Results @@ -88,7 +103,9 @@ def callback(event_type, value=None, timestamp=None): for i in [0, 1]: fig, ax, lines = td.trackers.ravel()[i].plot(ax=axes[i], n_skip=4) - ax.legend(loc='best') - ax.set_title('Adaptive track with start value {} (true threshold ' - 'is {})'.format(start_value[i], true_thresh)) + ax.legend(loc="best") + ax.set_title( + f"Adaptive track with start value {start_value[i]} (true threshold " + f"is {true_thresh})" + ) fig.tight_layout() diff --git a/examples/experiments/version_checking_.py b/examples/experiments/version_checking_.py index c7ed25d8..78df6e82 100644 --- a/examples/experiments/version_checking_.py +++ b/examples/experiments/version_checking_.py @@ -22,7 +22,7 @@ # directory so we don't break any other code examples, but usually you'd # want to do it in the experiment directory: temp_dir = tempfile.mkdtemp() -download_version('c18133c', temp_dir) +download_version("c18133c", temp_dir) # Now we would normally need to restart Python so the next ``import expyfun`` # call imported the proper version. We'd want to add an ``assert_version`` @@ -36,11 +36,11 @@ assert_version('c18133c') """ try: - run_subprocess(['python', '-c', cmd], cwd=temp_dir) + run_subprocess(["python", "-c", cmd], cwd=temp_dir) except Exception as exp: - print('Failure: {0}'.format(exp)) + print(f"Failure: {exp}") else: - print('Success!') + print("Success!") # Try modifying the commit number to something invalid, and you should # see a failure. diff --git a/examples/generate_simple_stimuli.py b/examples/generate_simple_stimuli.py index 8b4b8683..a3c20b79 100644 --- a/examples/generate_simple_stimuli.py +++ b/examples/generate_simple_stimuli.py @@ -8,8 +8,9 @@ """ from os import path as op -import numpy as np + import matplotlib.pyplot as plt +import numpy as np from expyfun.io import write_hdf5, write_wav from expyfun.stimuli import play_sound @@ -17,9 +18,18 @@ print(__doc__) -def generate_stimuli(num_trials=10, num_freqs=4, stim_dur=0.5, min_freq=500.0, - max_freq=4000.0, fs=24414.0625, rms=0.01, output_dir='.', - save_as='hdf5', rand_seed=0): +def generate_stimuli( + num_trials=10, + num_freqs=4, + stim_dur=0.5, + min_freq=500.0, + max_freq=4000.0, + fs=24414.0625, + rms=0.01, + output_dir=".", + save_as="hdf5", + rand_seed=0, +): """Make some sine waves and save in various formats. Optimized for saving as MAT files, but can also save directly as WAV files, or can return a python dictionary with sinewave data as values. @@ -65,42 +75,47 @@ def generate_stimuli(num_trials=10, num_freqs=4, stim_dur=0.5, min_freq=500.0, rng = np.random.RandomState(rand_seed) # check input arguments - if save_as is not None and save_as not in ['dict', 'wav', 'hdf5']: + if save_as is not None and save_as not in ["dict", "wav", "hdf5"]: raise ValueError('"save_as" must be "dict", "wav", or "hdf5"') fs = float(fs) t = np.arange(np.round(stim_dur * fs)) / fs # frequencies equally spaced on a log-2 scale - freqs = min_freq * np.logspace(0, np.log2(max_freq / float(min_freq)), - num_freqs, endpoint=True, base=2) + freqs = min_freq * np.logspace( + 0, np.log2(max_freq / float(min_freq)), num_freqs, endpoint=True, base=2 + ) # strings for the filenames / dictionary keys freq_names = [str(int(f)) for f in freqs] - names = ['stim_%s_%s' % (n, f) for n, f in enumerate(freq_names)] + names = ["stim_%s_%s" % (n, f) for n, f in enumerate(freq_names)] # generate sinewaves & RMS normalize wavs = [np.sin(2 * np.pi * f * t) for f in freqs] - wavs = [rms / np.sqrt(np.mean(w ** 2)) * w for w in wavs] + wavs = [rms / np.sqrt(np.mean(w**2)) * w for w in wavs] # collect into dictionary & save wav_dict = {n: w for (n, w) in zip(names, wavs)} - if save_as == 'hdf5': + if save_as == "hdf5": num_reps = num_trials // num_freqs + 1 trials = np.tile(range(num_freqs), num_reps) trial_order = rng.permutation(trials[0:num_trials]) - wav_dict.update({'trial_order': trial_order, 'freqs': freqs, 'fs': fs, - 'rms': rms}) - write_hdf5(op.join(output_dir, 'equally_spaced_sinewaves.hdf5'), - wav_dict, overwrite=True) - elif save_as == 'wav': + wav_dict.update( + {"trial_order": trial_order, "freqs": freqs, "fs": fs, "rms": rms} + ) + write_hdf5( + op.join(output_dir, "equally_spaced_sinewaves.hdf5"), + wav_dict, + overwrite=True, + ) + elif save_as == "wav": for n in names: - write_wav(op.join(output_dir, n + '.wav'), wav_dict[n], int(fs)) + write_wav(op.join(output_dir, n + ".wav"), wav_dict[n], int(fs)) return wav_dict -if __name__ == '__main__': +if __name__ == "__main__": wav_dict = generate_stimuli(save_as=None) - plt.plot(wav_dict['stim_0_500'][:1000]) - play_sound(wav_dict['stim_0_500']) + plt.plot(wav_dict["stim_0_500"][:1000]) + play_sound(wav_dict["stim_0_500"]) plt.show() diff --git a/examples/simple_experiment.py b/examples/simple_experiment.py index 5b8278e0..54804e8a 100644 --- a/examples/simple_experiment.py +++ b/examples/simple_experiment.py @@ -13,16 +13,21 @@ import os import sys from os import path as op + import numpy as np -from expyfun import (ExperimentController, get_keyboard_input, set_log_level, - building_doc) -from expyfun.io import read_hdf5 import expyfun.analyze as ea +from expyfun import ( + ExperimentController, + building_doc, + get_keyboard_input, + set_log_level, +) +from expyfun.io import read_hdf5 print(__doc__) -set_log_level('INFO') +set_log_level("INFO") # set configuration noise_db = 45 # dB for background noise @@ -35,41 +40,54 @@ running_total = 0 # make the stimuli if necessary and then load them -fname = 'equally_spaced_sinewaves.hdf5' +fname = "equally_spaced_sinewaves.hdf5" if not op.isfile(fname): # This sys.path wrangling is only necessary for Sphinx automatic # documentation building sys.path.insert(0, os.getcwd()) from generate_simple_stimuli import generate_stimuli + generate_stimuli() stims = read_hdf5(fname) -orig_rms = stims['rms'] -freqs = stims['freqs'] -fs = stims['fs'] -trial_order = stims['trial_order'] +orig_rms = stims["rms"] +freqs = stims["freqs"] +fs = stims["fs"] +trial_order = stims["trial_order"] num_trials = len(trial_order) num_freqs = len(freqs) if num_freqs > 8: - raise RuntimeError('Too many frequencies, not enough buttons.') + raise RuntimeError("Too many frequencies, not enough buttons.") # keep only sinusoids, order low-high, convert to list of arrays -wavs = [stims[k] for k in sorted(stims.keys()) if k.startswith('stim_')] +wavs = [stims[k] for k in sorted(stims.keys()) if k.startswith("stim_")] # instructions -instructions = ('You will hear tones at {0} different frequencies. Your job is' - ' to press the button corresponding to that frequency. Please ' - 'press buttons 1-{0} now to hear each tone.').format(num_freqs) - -instr_finished = ('Okay, now press any of those buttons to start the real ' - 'thing. There will be background noise.') - -with ExperimentController('testExp', verbose=True, screen_num=0, - window_size=[800, 600], full_screen=False, - stim_db=stim_db, noise_db=noise_db, stim_fs=fs, - participant='foo', session='001', - version='dev', output_dir=None) as ec: - +instructions = ( + f"You will hear tones at {num_freqs} different frequencies. Your job is" + " to press the button corresponding to that frequency. Please " + f"press buttons 1-{num_freqs} now to hear each tone." +) + +instr_finished = ( + "Okay, now press any of those buttons to start the real " + "thing. There will be background noise." +) + +with ExperimentController( + "testExp", + verbose=True, + screen_num=0, + window_size=[800, 600], + full_screen=False, + stim_db=stim_db, + noise_db=noise_db, + stim_fs=fs, + participant="foo", + session="001", + version="dev", + output_dir=None, +) as ec: # define usable buttons / keys live_keys = [x + 1 for x in range(num_freqs)] @@ -80,8 +98,7 @@ max_wait = max_resp_time = min_resp_time = train = feedback_dur = 0 long_resp_time = 0 else: - train = get_keyboard_input('Run training (0=no, 1=yes [default]): ', - 1, int) + train = get_keyboard_input("Run training (0=no, 1=yes [default]): ", 1, int) ec.set_visible(True) if train: @@ -108,72 +125,75 @@ ec.wait_secs(isi) ec.call_on_next_flip(ec.start_noise()) - ec.screen_text('OK, here we go!', wrap=False) + ec.screen_text("OK, here we go!", wrap=False) screenshot = ec.screenshot() ec.wait_one_press(max_wait=feedback_dur, live_keys=None) ec.wait_secs(isi) single_trial_order = trial_order[range(len(trial_order) // 2)] - mass_trial_order = trial_order[len(trial_order) // 2:] + mass_trial_order = trial_order[len(trial_order) // 2 :] # run the single-tone trials for stim_num in single_trial_order: ec.load_buffer(wavs[stim_num]) ec.identify_trial(ec_id=stim_num, ttl_id=[0, 0]) - ec.write_data_line('one-tone trial', stim_num + 1) + ec.write_data_line("one-tone trial", stim_num + 1) ec.start_stimulus() - pressed, timestamp = ec.wait_one_press(max_resp_time, min_resp_time, - live_keys) + pressed, timestamp = ec.wait_one_press(max_resp_time, min_resp_time, live_keys) ec.stop() # will stop stim playback as soon as response logged ec.trial_ok() # some feedback if pressed is None: - message = 'Too slow!' + message = "Too slow!" elif int(pressed) == stim_num + 1: running_total += 1 - message = ('Correct! Your reaction time was ' - '{}').format(round(timestamp, 3)) + message = "Correct! Your reaction time was " f"{round(timestamp, 3)}" else: - message = ('You pressed {0}, the correct answer was ' - '{1}.').format(pressed, stim_num + 1) + message = ( + f"You pressed {pressed}, the correct answer was " f"{stim_num + 1}." + ) ec.screen_prompt(message, max_wait=feedback_dur) ec.wait_secs(isi) # create 100 ms pause to play between stims and concatenate pause = np.zeros(int(ec.fs / 10)) concat_wavs = wavs[mass_trial_order[0]] - for num in mass_trial_order[1:len(mass_trial_order)]: + for num in mass_trial_order[1 : len(mass_trial_order)]: concat_wavs = np.r_[concat_wavs, pause, wavs[num]] concat_dur = len(concat_wavs) / float(ec.fs) # run mass trial - ec.screen_prompt('Now you will hear {0} tones in a row. After they stop, ' - 'wait for the "Go!" prompt, then you will have {1} ' - 'seconds to push the buttons in the order that the tones ' - 'played in. Press one of the buttons to begin.' - ''.format(len(mass_trial_order), max_resp_time), - live_keys=live_keys, max_wait=max_wait) + ec.screen_prompt( + f"Now you will hear {len(mass_trial_order)} tones in a row. After they stop, " + f'wait for the "Go!" prompt, then you will have {max_resp_time} ' + "seconds to push the buttons in the order that the tones " + "played in. Press one of the buttons to begin." + "", + live_keys=live_keys, + max_wait=max_wait, + ) ec.load_buffer(concat_wavs) - ec.identify_trial(ec_id='multi-tone', ttl_id=[0, 1]) - ec.write_data_line('multi-tone trial', [x + 1 for x in mass_trial_order]) + ec.identify_trial(ec_id="multi-tone", ttl_id=[0, 1]) + ec.write_data_line("multi-tone trial", [x + 1 for x in mass_trial_order]) ec.start_stimulus() - ec.wait_secs(len(concat_wavs) / float(ec.stim_fs) if not building_doc else - 0) - ec.screen_text('Go!', wrap=False) + ec.wait_secs(len(concat_wavs) / float(ec.stim_fs) if not building_doc else 0) + ec.screen_text("Go!", wrap=False) ec.flip() - pressed = ec.wait_for_presses(long_resp_time, min_resp_time, - live_keys, False) + pressed = ec.wait_for_presses(long_resp_time, min_resp_time, live_keys, False) answers = [str(x + 1) for x in mass_trial_order] correct = [press == ans for press, ans in zip(pressed, answers)] running_total += sum(correct) ec.call_on_next_flip(ec.stop_noise()) - ec.screen_prompt('You got {0} out of {1} correct.' - ''.format(sum(correct), len(answers)), - max_wait=feedback_dur) + ec.screen_prompt( + f"You got {sum(correct)} out of {len(answers)} correct." "", + max_wait=feedback_dur, + ) ec.trial_ok() # end experiment - ec.screen_prompt('All done! You got {0} correct out of {1} tones. Press ' - 'any key to close.'.format(running_total, num_trials), - max_wait=max_wait) + ec.screen_prompt( + f"All done! You got {running_total} correct out of {num_trials} tones. Press " + "any key to close.", + max_wait=max_wait, + ) ea.plot_screen(screenshot) diff --git a/examples/stimuli/advanced_stimuli.py b/examples/stimuli/advanced_stimuli.py index 15d79f37..31681976 100644 --- a/examples/stimuli/advanced_stimuli.py +++ b/examples/stimuli/advanced_stimuli.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ======================================= Generate more advanced auditory stimuli @@ -8,28 +7,29 @@ of more advanced stimuli. """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np from expyfun import building_doc from expyfun.stimuli import convolve_hrtf, play_sound, window_edges fs = 24414 dur = 0.5 -freq = 500. +freq = 500.0 # let's make a square wave sig = np.sin(freq * 2 * np.pi * np.arange(dur * fs, dtype=float) / fs) -sig = ((sig > 0) - 0.5) / 5. # make it reasonably quiet for play_sound +sig = ((sig > 0) - 0.5) / 5.0 # make it reasonably quiet for play_sound sig = window_edges(sig, fs) play_sound(sig, fs, norm=False, wait=True) -move_sig = np.concatenate([convolve_hrtf(sig, fs, ang) - for ang in range(-90, 91, 15)], axis=1) +move_sig = np.concatenate( + [convolve_hrtf(sig, fs, ang) for ang in range(-90, 91, 15)], axis=1 +) if not building_doc: play_sound(move_sig, fs, norm=False, wait=True) t = np.arange(move_sig.shape[1]) / float(fs) plt.plot(t, move_sig.T) -plt.xlabel('Time (sec)') +plt.xlabel("Time (sec)") plt.show() diff --git a/examples/stimuli/advanced_video.py b/examples/stimuli/advanced_video.py index f01aeef7..a0da5fdf 100644 --- a/examples/stimuli/advanced_video.py +++ b/examples/stimuli/advanced_video.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ====================== Video property control @@ -10,36 +9,50 @@ """ import numpy as np -from expyfun import (ExperimentController, fetch_data_file, building_doc, - analyze as ea, visual) + +from expyfun import ( + ExperimentController, + building_doc, + fetch_data_file, + visual, +) +from expyfun import ( + analyze as ea, +) print(__doc__) -movie_path = fetch_data_file('video/example-video.mp4') +movie_path = fetch_data_file("video/example-video.mp4") -ec_args = dict(exp_name='advanced video example', window_size=(720, 480), - full_screen=False, participant='foo', session='foo', - version='dev', output_dir=None) -colors = [x for x in 'rgbcmyk'] +ec_args = dict( + exp_name="advanced video example", + window_size=(720, 480), + full_screen=False, + participant="foo", + session="foo", + version="dev", + output_dir=None, +) +colors = [x for x in "rgbcmyk"] with ExperimentController(**ec_args) as ec: - screen_period = 1. / ec.estimate_screen_fs() + screen_period = 1.0 / ec.estimate_screen_fs() all_presses = list() fix = visual.FixationDot(ec) - text = text = visual.Text(ec, "Running ...", (0, -0.1), 'k') + text = text = visual.Text(ec, "Running ...", (0, -0.1), "k") screenshot = None # don't have one yet ec.load_video(movie_path) - ec.video.set_scale('fill') - ec.screen_prompt('press 1 during video to toggle pause.', max_wait=1.) + ec.video.set_scale("fill") + ec.screen_prompt("press 1 during video to toggle pause.", max_wait=1.0) ec.listen_presses() # to catch presses on first pass of while loop t_zero = ec.video.play(auto_draw=False) - this_sec = 0. + this_sec = 0.0 while not ec.video.finished: if ec.video.playing: ec.video.draw() else: - ec.screen_text('paused!', color='y', font_size=32, wrap=False) + ec.screen_text("paused!", color="y", font_size=32, wrap=False) text.draw() fix.draw() if screenshot is None: @@ -50,8 +63,7 @@ # change the background color every 1 second if this_sec != int(ec.video.time): this_sec = int(ec.video.time) - text = visual.Text( - ec, str(colors[this_sec]), (0, -0.1), 'k') + text = visual.Text(ec, str(colors[this_sec]), (0, -0.1), "k") ec.set_background_color(colors[this_sec]) # shrink the video, then move it rightward if ec.video.playing: @@ -70,9 +82,9 @@ if building_doc: break ec.delete_video() - preamble = 'press times:' if len(all_presses) else 'no presses' - msg = ', '.join(['{0:.3f}'.format(x[1]) for x in all_presses]) + preamble = "press times:" if len(all_presses) else "no presses" + msg = ", ".join([f"{x[1]:.3f}" for x in all_presses]) ec.flip() - ec.screen_prompt('\n'.join([preamble, msg]), max_wait=1.) + ec.screen_prompt("\n".join([preamble, msg]), max_wait=1.0) ea.plot_screen(screenshot) diff --git a/examples/stimuli/crm_stimuli.py b/examples/stimuli/crm_stimuli.py index 49ad869b..263d3017 100644 --- a/examples/stimuli/crm_stimuli.py +++ b/examples/stimuli/crm_stimuli.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ================== Use the CRM corpus @@ -9,13 +8,19 @@ @author: rkmaddox """ -from expyfun._utils import _TempDir -from expyfun import ExperimentController, analyze, building_doc -from expyfun.stimuli import (crm_prepare_corpus, crm_sentence, crm_info, - crm_response_menu, add_pad, CRMPreload) - import numpy as np +from expyfun import ExperimentController, analyze, building_doc +from expyfun._utils import _TempDir +from expyfun.stimuli import ( + CRMPreload, + add_pad, + crm_info, + crm_prepare_corpus, + crm_response_menu, + crm_sentence, +) + print(__doc__) crm_path = _TempDir() @@ -40,46 +45,56 @@ # >>> crm_prepare_corpus(24414) # -crm_prepare_corpus(fs, path_out=crm_path, overwrite=True, - talker_list=[dict(sex=0, talker_num=0), - dict(sex=1, talker_num=0)]) +crm_prepare_corpus( + fs, + path_out=crm_path, + overwrite=True, + talker_list=[dict(sex=0, talker_num=0), dict(sex=1, talker_num=0)], +) # print the valid callsigns -print('Valid callsigns are {0}'.format(crm_info()['callsign'])) +print(f'Valid callsigns are {crm_info()["callsign"]}') # read a sentence in from the hard drive -x1 = 0.5 * crm_sentence(fs, 'm', '0', 'c', 'r', '5', path=crm_path) +x1 = 0.5 * crm_sentence(fs, "m", "0", "c", "r", "5", path=crm_path) # preload all the talkers and get a second sentence from memory crm = CRMPreload(fs, path=crm_path) -x2 = crm.sentence('f', '0', 'ringo', 'green', '6') +x2 = crm.sentence("f", "0", "ringo", "green", "6") -x = add_pad([x1, x2], alignment='start') +x = add_pad([x1, x2], alignment="start") ############################################################################### # Now we actually run the experiment. max_wait = 0.01 if building_doc else 3 with ExperimentController( - exp_name='CRM corpus example', window_size=(720, 480), - full_screen=False, participant='foo', session='foo', version='dev', - output_dir=None, stim_fs=40000) as ec: - ec.screen_text('Report the color and number spoken by the female ' - 'talker.', wrap=True) + exp_name="CRM corpus example", + window_size=(720, 480), + full_screen=False, + participant="foo", + session="foo", + version="dev", + output_dir=None, + stim_fs=40000, +) as ec: + ec.screen_text( + "Report the color and number spoken by the female " "talker.", wrap=True + ) screenshot = ec.screenshot() ec.flip() ec.wait_secs(max_wait) ec.load_buffer(x) - ec.identify_trial(ec_id='', ttl_id=[]) + ec.identify_trial(ec_id="", ttl_id=[]) ec.start_stimulus() ec.wait_secs(x.shape[-1] / float(fs)) resp = crm_response_menu(ec, max_wait=0.01 if building_doc else np.inf) - if resp == ('g', '6'): - ec.screen_prompt('Correct!', max_wait=max_wait) + if resp == ("g", "6"): + ec.screen_prompt("Correct!", max_wait=max_wait) else: - ec.screen_prompt('Incorrect.', max_wait=max_wait) + ec.screen_prompt("Incorrect.", max_wait=max_wait) ec.trial_ok() analyze.plot_screen(screenshot) diff --git a/examples/stimuli/simple_video.py b/examples/stimuli/simple_video.py index 335da5bd..90dc1afc 100644 --- a/examples/stimuli/simple_video.py +++ b/examples/stimuli/simple_video.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ========================= Video playing made simple @@ -10,21 +9,27 @@ @author: drmccloy """ -from expyfun import (ExperimentController, fetch_data_file, analyze as ea, - building_doc) +from expyfun import ExperimentController, building_doc, fetch_data_file +from expyfun import analyze as ea print(__doc__) -movie_path = fetch_data_file('video/example-video.mp4') - -ec_args = dict(exp_name='simple video example', window_size=(720, 480), - full_screen=False, participant='foo', session='foo', - version='dev', output_dir=None) +movie_path = fetch_data_file("video/example-video.mp4") + +ec_args = dict( + exp_name="simple video example", + window_size=(720, 480), + full_screen=False, + participant="foo", + session="foo", + version="dev", + output_dir=None, +) screenshot = None with ExperimentController(**ec_args) as ec: ec.load_video(movie_path) - ec.video.set_scale('fit') + ec.video.set_scale("fit") t_zero = ec.video.play() while not ec.video.finished: if ec.video.playing: @@ -36,6 +41,6 @@ ec.check_force_quit() ec.delete_video() ec.flip() - ec.screen_prompt('video over', max_wait=1.) + ec.screen_prompt("video over", max_wait=1.0) ea.plot_screen(screenshot) diff --git a/examples/stimuli/stimulus_power.py b/examples/stimuli/stimulus_power.py index f7ceab18..60bff2d1 100644 --- a/examples/stimuli/stimulus_power.py +++ b/examples/stimuli/stimulus_power.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ===================================== Examine and manipulate stimulus power @@ -7,11 +6,11 @@ This shows how to make stimuli that play at different SNRs and db SPL. """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np -from expyfun.stimuli import window_edges, read_wav, rms from expyfun import fetch_data_file +from expyfun.stimuli import read_wav, rms, window_edges print(__doc__) @@ -19,7 +18,7 @@ # Load data # --------- # Get 2 seconds of data -data_orig, fs = read_wav(fetch_data_file('audio/dream.wav')) +data_orig, fs = read_wav(fetch_data_file("audio/dream.wav")) stop = int(round(fs * 2)) data_orig = window_edges(data_orig[0, :stop], fs) t = np.arange(data_orig.size) / float(fs) @@ -27,8 +26,7 @@ # look at the waveform fig, ax = plt.subplots() ax.plot(t, data_orig) -ax.set(xlabel='Time (sec)', ylabel='Amplitude', title='Original', - xlim=t[[0, -1]]) +ax.set(xlabel="Time (sec)", ylabel="Amplitude", title="Original", xlim=t[[0, -1]]) fig.tight_layout() ############################################################################### @@ -45,7 +43,7 @@ target *= 0.01 # do manual calculation same as ``rms``, result should be 0.01 # (to numerical precision) -print(np.sqrt(np.mean(target ** 2))) +print(np.sqrt(np.mean(target**2))) ############################################################################### # One important thing to note about this stimulus is that its long-term RMS @@ -67,26 +65,25 @@ # during your experiment. # Good idea to use a seed for reproducibility! -ratio_dB = -6. # dB +ratio_dB = -6.0 # dB rng = np.random.RandomState(0) masker = rng.randn(len(target)) masker /= rms(masker) # now has unit RMS masker *= 0.01 # now has RMS=0.01, same as target -ratio_amplitude = 10 ** (ratio_dB / 20.) # conversion from dB to amplitude +ratio_amplitude = 10 ** (ratio_dB / 20.0) # conversion from dB to amplitude masker *= ratio_amplitude ############################################################################### # Looking at the overlaid traces, you can see that the resulting SNR varies as # a function of time. -colors = ['#4477AA', '#EE7733'] +colors = ["#4477AA", "#EE7733"] fig, ax = plt.subplots() -ax.plot(t, target, label='target', alpha=0.5, color=colors[0], lw=0.5) -ax.plot(t, masker, label='masker', alpha=0.5, color=colors[1], lw=0.5) -ax.axhline(0.01, label='target RMS', color=colors[0], lw=1) -ax.axhline(0.01 * ratio_amplitude, label='masker RMS', color=colors[1], lw=1) -ax.set(xlabel='Time (sec)', ylabel='Amplitude', title='Calibrated', - xlim=t[[0, -1]]) +ax.plot(t, target, label="target", alpha=0.5, color=colors[0], lw=0.5) +ax.plot(t, masker, label="masker", alpha=0.5, color=colors[1], lw=0.5) +ax.axhline(0.01, label="target RMS", color=colors[0], lw=1) +ax.axhline(0.01 * ratio_amplitude, label="masker RMS", color=colors[1], lw=1) +ax.set(xlabel="Time (sec)", ylabel="Amplitude", title="Calibrated", xlim=t[[0, -1]]) ax.legend() fig.tight_layout() @@ -97,19 +94,19 @@ # SNR varies as a function of frequency. from scipy.fft import rfft, rfftfreq # noqa -f = rfftfreq(len(target), 1. / fs) + +f = rfftfreq(len(target), 1.0 / fs) T = np.abs(rfft(target)) / np.sqrt(len(target)) # normalize the FFT properly M = np.abs(rfft(masker)) / np.sqrt(len(target)) fig, ax = plt.subplots() -ax.plot(f, T, label='target', alpha=0.5, color=colors[0], lw=0.5) -ax.plot(f, M, label='masker', alpha=0.5, color=colors[1], lw=0.5) +ax.plot(f, T, label="target", alpha=0.5, color=colors[0], lw=0.5) +ax.plot(f, M, label="masker", alpha=0.5, color=colors[1], lw=0.5) T_rms = rms(T) M_rms = rms(M) -print('Parseval\'s theorem: target RMS still %s' % (T_rms,)) -print('dB TMR is still %s' % (20 * np.log10(T_rms / M_rms),)) -ax.axhline(T_rms, label='target RMS', color=colors[0], lw=1) -ax.axhline(M_rms, label='masker RMS', color=colors[1], lw=1) -ax.set(xlabel='Freq (Hz)', ylabel='Amplitude', title='Spectrum', - xlim=f[[0, -1]]) +print("Parseval's theorem: target RMS still %s" % (T_rms,)) +print("dB TMR is still %s" % (20 * np.log10(T_rms / M_rms),)) +ax.axhline(T_rms, label="target RMS", color=colors[0], lw=1) +ax.axhline(M_rms, label="masker RMS", color=colors[1], lw=1) +ax.set(xlabel="Freq (Hz)", ylabel="Amplitude", title="Spectrum", xlim=f[[0, -1]]) ax.legend() fig.tight_layout() diff --git a/examples/stimuli/texture_stimuli.py b/examples/stimuli/texture_stimuli.py index 70e6cc6a..5d21b25a 100644 --- a/examples/stimuli/texture_stimuli.py +++ b/examples/stimuli/texture_stimuli.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ======================== Generate texture stimuli @@ -7,25 +6,25 @@ This shows how to generate texture coherence stimuli. """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np -from expyfun.stimuli import texture_ERB, play_sound +from expyfun.stimuli import play_sound, texture_ERB fs = 24414 n_freqs = 20 n_coh = 18 # very coherent example # let's make a textured stimilus and play it -sig = texture_ERB(n_freqs, n_coh, fs=fs, seq=('inc', 'nb', 'sam')) +sig = texture_ERB(n_freqs, n_coh, fs=fs, seq=("inc", "nb", "sam")) play_sound(sig, fs, norm=True, wait=True) ############################################################################### # Let's look at the time course t = np.arange(len(sig)) / float(fs) fig, ax = plt.subplots(1) -ax.plot(t, sig.T, color='k') -ax.set(xlabel='Time (sec)', ylabel='Amplitude (normalized)', xlim=t[[0, -1]]) +ax.plot(t, sig.T, color="k") +ax.set(xlabel="Time (sec)", ylabel="Amplitude (normalized)", xlim=t[[0, -1]]) fig.tight_layout() ############################################################################### @@ -33,17 +32,15 @@ fig, ax = plt.subplots(1, figsize=(8, 2)) img = ax.specgram(sig, NFFT=1024, Fs=fs, noverlap=800)[3] img.set_clim([img.get_clim()[1] - 50, img.get_clim()[1]]) -ax.set(xlim=t[[0, -1]], ylim=[0, 10000], xlabel='Time (sec)', - ylabel='Freq (Hz)') +ax.set(xlim=t[[0, -1]], ylim=[0, 10000], xlabel="Time (sec)", ylabel="Freq (Hz)") fig.tight_layout() ############################################################################### # And the long-term spectrum: fig, ax = plt.subplots(1) -ax.psd(sig, NFFT=16384, Fs=fs, color='k') +ax.psd(sig, NFFT=16384, Fs=fs, color="k") xticks = [250, 500, 1000, 2000, 4000, 8000] -ax.set(xlabel='Frequency (Hz)', ylabel='Power (dB)', xlim=[100, 10000], - xscale='log') +ax.set(xlabel="Frequency (Hz)", ylabel="Power (dB)", xlim=[100, 10000], xscale="log") ax.set(xticks=xticks) ax.set(xticklabels=xticks) fig.tight_layout() diff --git a/examples/stimuli/tracker_staircase.py b/examples/stimuli/tracker_staircase.py index b4d4ca09..1b6f9e56 100644 --- a/examples/stimuli/tracker_staircase.py +++ b/examples/stimuli/tracker_staircase.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ============================== Do an adaptive track staircase @@ -12,15 +11,15 @@ import numpy as np -from expyfun.stimuli import TrackerUD from expyfun.analyze import sigmoid +from expyfun.stimuli import TrackerUD print(__doc__) # Make a callback function that prints to the console, rather than log file def callback(event_type, value=None, timestamp=None): - print((str(event_type) + ':').ljust(40) + str(value)) + print((str(event_type) + ":").ljust(40) + str(value)) # Define parameters for modeled human subject (sigmoid probability) @@ -36,12 +35,12 @@ def callback(event_type, value=None, timestamp=None): # Do the task until the tracker stops while not tr.stopped: - tr.respond(rng.rand() < sigmoid(tr.x_current - true_thresh, - lower=chance, slope=slope)) + tr.respond( + rng.rand() < sigmoid(tr.x_current - true_thresh, lower=chance, slope=slope) + ) # Plot the results fig, ax, lines = tr.plot() lines += tr.plot_thresh(4, ax=ax) -ax.set_title('Adaptive track of model human (true threshold is {})' - .format(true_thresh)) +ax.set_title(f"Adaptive track of model human (true threshold is {true_thresh})") diff --git a/examples/stimuli/tracker_staircase_MHW.py b/examples/stimuli/tracker_staircase_MHW.py index d38d271e..959c6cb9 100644 --- a/examples/stimuli/tracker_staircase_MHW.py +++ b/examples/stimuli/tracker_staircase_MHW.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ================================================= Do an adaptive track staircase with MHW procedure @@ -12,13 +11,13 @@ import numpy as np -from expyfun.stimuli import TrackerMHW from expyfun.analyze import sigmoid +from expyfun.stimuli import TrackerMHW # Make a callback function that prints to the console, rather than log file def callback(event_type, value=None, timestamp=None): - print((str(event_type) + ':').ljust(40) + str(value)) + print((str(event_type) + ":").ljust(40) + str(value)) # Define parameters for modeled human subject (sigmoid probability) @@ -34,12 +33,12 @@ def callback(event_type, value=None, timestamp=None): # Do the task until the tracker stops while not tr.stopped: - tr.respond(rng.rand() < sigmoid(tr.x_current - true_thresh, - lower=chance, slope=slope)) + tr.respond( + rng.rand() < sigmoid(tr.x_current - true_thresh, lower=chance, slope=slope) + ) # Plot the results fig, ax, lines = tr.plot() lines += tr.plot_thresh() -ax.set_title('Adaptive track of model human (true threshold is {})' - .format(true_thresh)) +ax.set_title(f"Adaptive track of model human (true threshold is {true_thresh})") diff --git a/examples/stimuli/vocoded_stimuli.py b/examples/stimuli/vocoded_stimuli.py index 116ed06a..8db0ab09 100644 --- a/examples/stimuli/vocoded_stimuli.py +++ b/examples/stimuli/vocoded_stimuli.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ ======================== Generate vocoded stimuli @@ -9,33 +8,33 @@ @author: larsoner """ -import numpy as np import matplotlib.pyplot as plt +import numpy as np -from expyfun.stimuli import vocode, play_sound, window_edges, read_wav, rms from expyfun import fetch_data_file +from expyfun.stimuli import play_sound, read_wav, rms, vocode, window_edges print(__doc__) -data, fs = read_wav(fetch_data_file('audio/dream.wav')) +data, fs = read_wav(fetch_data_file("audio/dream.wav")) data = window_edges(data[0], fs) t = np.arange(data.size) / float(fs) # noise vocoder -data_noise = vocode(data, fs, mode='noise') +data_noise = vocode(data, fs, mode="noise") data_noise = data_noise * 0.01 / rms(data_noise) # sinewave vocoder -data_tone = vocode(data, fs, mode='tone') +data_tone = vocode(data, fs, mode="tone") data_tone = data_tone * 0.01 / rms(data_tone) # poisson vocoder -data_click = vocode(data, fs, mode='poisson', rate=400) +data_click = vocode(data, fs, mode="poisson", rate=400) data_click = data_click * 0.01 / rms(data_click) # combine all three cutoff = data.shape[-1] // 3 data_allthree = data_noise.copy() -data_allthree[cutoff:2 * cutoff] = data_tone[cutoff:2 * cutoff] -data_allthree[2 * cutoff:] = data_click[2 * cutoff:] +data_allthree[cutoff : 2 * cutoff] = data_tone[cutoff : 2 * cutoff] +data_allthree[2 * cutoff :] = data_click[2 * cutoff :] snd = play_sound(data_allthree, fs, norm=False, wait=False) # Uncomment this to play the original, too: @@ -43,18 +42,18 @@ ax1 = plt.subplot(3, 1, 1) ax1.plot(t, data) -ax1.set_title('Original') -ax1.set_ylabel('Amplitude') +ax1.set_title("Original") +ax1.set_ylabel("Amplitude") ax2 = plt.subplot(3, 1, 2, sharex=ax1, sharey=ax1) ax2.plot(t, data_noise) -ax2.set_title('Vocoded') +ax2.set_title("Vocoded") ax3 = plt.subplot(3, 1, 3, sharex=ax1) -ax2.set_title('Spectrogram') -ax2.set_ylabel('Amplitude') +ax2.set_title("Spectrogram") +ax2.set_ylabel("Amplitude") ax3.specgram(data_noise, Fs=fs) ax3.set_xlim(t[[0, -1]]) -ax3.set_ylim([0, fs / 2.]) -ax3.set_ylabel('Frequency (hz)') -ax3.set_xlabel('Time (sec)') +ax3.set_ylim([0, fs / 2.0]) +ax3.set_ylabel("Frequency (hz)") +ax3.set_xlabel("Time (sec)") plt.tight_layout() plt.show() diff --git a/examples/sync/sample_rate_test.py b/examples/sync/sample_rate_test.py index 07c9a091..a13e2c70 100644 --- a/examples/sync/sample_rate_test.py +++ b/examples/sync/sample_rate_test.py @@ -44,18 +44,25 @@ print(__doc__) stim = np.zeros(int(1e6) + 1) -stim[[0, -1]] = 1. -with ExperimentController('FsTest', full_screen=False, noise_db=-np.inf, - participant='s', session='0', output_dir=None, - suppress_resamp=True, check_rms=None, - version='dev') as ec: - ec.identify_trial(ec_id='', ttl_id=[0]) +stim[[0, -1]] = 1.0 +with ExperimentController( + "FsTest", + full_screen=False, + noise_db=-np.inf, + participant="s", + session="0", + output_dir=None, + suppress_resamp=True, + check_rms=None, + version="dev", +) as ec: + ec.identify_trial(ec_id="", ttl_id=[0]) ec.load_buffer(stim) - print('Starting stimulus.') + print("Starting stimulus.") ec.start_stimulus() - wait_dur = len(stim) / ec.fs + 1. - print('Stimulus started. Please wait %d seconds.' % wait_dur) + wait_dur = len(stim) / ec.fs + 1.0 + print("Stimulus started. Please wait %d seconds." % wait_dur) if not building_doc: ec.wait_secs(wait_dur) ec.stop() - print('Stimulus done.') + print("Stimulus done.") diff --git a/examples/sync/sync_test.py b/examples/sync/sync_test.py index f245ad83..f59ef418 100644 --- a/examples/sync/sync_test.py +++ b/examples/sync/sync_test.py @@ -36,16 +36,24 @@ import numpy as np +import expyfun.analyze as ea from expyfun import ExperimentController, building_doc from expyfun.visual import Circle, Rectangle -import expyfun.analyze as ea n_channels = 2 click_idx = [0] -with ExperimentController('SyncTest', full_screen=True, noise_db=-np.inf, - participant='s', session='0', output_dir=None, - suppress_resamp=True, check_rms=None, - n_channels=n_channels, version='dev') as ec: +with ExperimentController( + "SyncTest", + full_screen=True, + noise_db=-np.inf, + participant="s", + session="0", + output_dir=None, + suppress_resamp=True, + check_rms=None, + n_channels=n_channels, + version="dev", +) as ec: click = np.r_[0.1, np.zeros(99)] # RMS = 0.01 data = np.zeros((n_channels, len(click))) data[click_idx] = click @@ -53,23 +61,23 @@ pressed = None screenshot = None # Make a circle so that the photodiode can be centered on the screen - circle = Circle(ec, 1, units='deg', fill_color='k', line_color='w') + circle = Circle(ec, 1, units="deg", fill_color="k", line_color="w") # Make a rectangle that is the standard credit card size - rect = Rectangle(ec, [0, 0, 8.56, 5.398], 'cm', None, '#AA3377') - while pressed != '8': # enable a clean quit if required - ec.set_background_color('white') + rect = Rectangle(ec, [0, 0, 8.56, 5.398], "cm", None, "#AA3377") + while pressed != "8": # enable a clean quit if required + ec.set_background_color("white") t1 = ec.start_stimulus(start_of_trial=False) # skip checks - ec.set_background_color('black') + ec.set_background_color("black") t2 = ec.flip() diff = round(1000 * (t2 - t1), 2) - ec.screen_text('IFI (ms): {}'.format(diff), wrap=True) + ec.screen_text(f"IFI (ms): {diff}", wrap=True) circle.draw() rect.draw() screenshot = ec.screenshot() if screenshot is None else screenshot ec.flip() ec.stamp_triggers([2, 4, 8]) ec.refocus() - pressed = ec.wait_one_press(0.5)[0] if not building_doc else '8' + pressed = ec.wait_one_press(0.5)[0] if not building_doc else "8" ec.stop() ea.plot_screen(screenshot) diff --git a/expyfun/__init__.py b/expyfun/__init__.py index 76ceb486..105a8ed9 100644 --- a/expyfun/__init__.py +++ b/expyfun/__init__.py @@ -8,16 +8,28 @@ from ._version import __version__ # have to import verbose first since it's needed by many things -from ._utils import (set_log_level, set_log_file, set_config, check_units, - get_config, get_config_path, fetch_data_file, - run_subprocess, verbose_dec as verbose, building_doc, - known_config_types) +from ._utils import ( + set_log_level, + set_log_file, + set_config, + check_units, + get_config, + get_config_path, + fetch_data_file, + run_subprocess, + verbose_dec as verbose, + building_doc, + known_config_types, +) from ._git import assert_version, download_version from ._experiment_controller import ExperimentController, get_keyboard_input from ._eyelink_controller import EyelinkController from ._sound_controllers import SoundCardController -from ._trigger_controllers import (decimals_to_binary, binary_to_decimals, - ParallelTrigger) +from ._trigger_controllers import ( + decimals_to_binary, + binary_to_decimals, + ParallelTrigger, +) from ._tdt_controller import TDTController from . import analyze from . import codeblocks diff --git a/expyfun/_experiment_controller.py b/expyfun/_experiment_controller.py index 396d3d94..2cbdb02c 100644 --- a/expyfun/_experiment_controller.py +++ b/expyfun/_experiment_controller.py @@ -6,31 +6,43 @@ # # License: BSD (3-clause) -from collections import OrderedDict import inspect import json import os +import string import sys +import traceback as tb import warnings -from os import path as op +from collections import OrderedDict from functools import partial -import traceback as tb -import string +from os import path as op import numpy as np -from ._utils import (get_config, verbose_dec, _check_pyglet_version, - running_rms, _sanitize, logger, ZeroClock, date_str, - check_units, set_log_file, flush_logger, _TempDir, - string_types, _fix_audio_dims, input, _get_args, - _get_display, _wait_secs) +from ._git import __version__, assert_version +from ._input_controllers import CedrusBox, Joystick, Keyboard, Mouse +from ._sound_controllers import _AUTO_BACKENDS, SoundCardController, SoundPlayer from ._tdt_controller import TDTController from ._trigger_controllers import ParallelTrigger -from ._sound_controllers import (SoundPlayer, SoundCardController, - _AUTO_BACKENDS) -from ._input_controllers import Keyboard, CedrusBox, Mouse, Joystick -from .visual import Text, Rectangle, Video, _convert_color -from ._git import assert_version, __version__ +from ._utils import ( + ZeroClock, + _check_pyglet_version, + _fix_audio_dims, + _get_args, + _get_display, + _sanitize, + _TempDir, + _wait_secs, + check_units, + date_str, + flush_logger, + get_config, + logger, + running_rms, + set_log_file, + verbose_dec, +) +from .visual import Rectangle, Text, Video, _convert_color # Note: ec._trial_progress has three values: # 1. 'stopped', which ec.identify_trial turns into... @@ -40,7 +52,7 @@ _SLOW_LIMIT = 10000000 -class ExperimentController(object): +class ExperimentController: """Interface for hardware control (audio, buttonbox, eye tracker, etc.) Parameters @@ -142,14 +154,33 @@ class ExperimentController(object): """ @verbose_dec - def __init__(self, exp_name, audio_controller=None, response_device=None, - stim_rms=0.01, stim_fs=24414, stim_db=65, noise_db=45, - output_dir='data', window_size=None, screen_num=None, - full_screen=True, force_quit=None, participant=None, - monitor=None, trigger_controller=None, session=None, - check_rms='windowed', suppress_resamp=False, version=None, - safe_flipping=None, n_channels=2, - trigger_duration=0.01, joystick=False, verbose=None): + def __init__( + self, + exp_name, + audio_controller=None, + response_device=None, + stim_rms=0.01, + stim_fs=24414, + stim_db=65, + noise_db=45, + output_dir="data", + window_size=None, + screen_num=None, + full_screen=True, + force_quit=None, + participant=None, + monitor=None, + trigger_controller=None, + session=None, + check_rms="windowed", + suppress_resamp=False, + version=None, + safe_flipping=None, + n_channels=2, + trigger_duration=0.01, + joystick=False, + verbose=None, + ): # initialize some values self._stim_fs = stim_fs self._stim_rms = stim_rms @@ -158,7 +189,7 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, self._stim_scaler = None self._suppress_resamp = suppress_resamp self.video = None - self._bgcolor = _convert_color('k') + self._bgcolor = _convert_color("k") # placeholder for extra actions to do on flip-and-play self._on_every_flip = [] self._on_next_flip = [] @@ -179,19 +210,23 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, _check_pyglet_version(raise_error=True) # assure proper formatting for force-quit keys if force_quit is None: - force_quit = ['lctrl', 'rctrl'] - elif isinstance(force_quit, (int, string_types)): + force_quit = ["lctrl", "rctrl"] + elif isinstance(force_quit, (int, str)): force_quit = [str(force_quit)] - if 'escape' in force_quit: - logger.warning('Expyfun: using "escape" as a force-quit key ' - 'is not recommended because it has special ' - 'status in pyglet.') + if "escape" in force_quit: + logger.warning( + 'Expyfun: using "escape" as a force-quit key ' + "is not recommended because it has special " + "status in pyglet." + ) # check expyfun version if version is None: - raise RuntimeError('You must specify an expyfun version string' - ' to use ExperimentController, or specify ' - 'version=\'dev\' to override.') - elif version.lower() != 'dev': + raise RuntimeError( + "You must specify an expyfun version string" + " to use ExperimentController, or specify " + "version='dev' to override." + ) + elif version.lower() != "dev": assert_version(version) # set up timing # Use ZeroClock, which uses the "clock" fn but starts at zero @@ -203,28 +238,28 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, self._exp_info = OrderedDict() for name in _get_args(self.__init__): - if name != 'self': + if name != "self": self._exp_info[name] = locals()[name] - self._exp_info['date'] = date_str() + self._exp_info["date"] = date_str() # skip verbose decorator frames - self._exp_info['file'] = \ - op.abspath(inspect.getfile(sys._getframe(3))) - self._exp_info['version_used'] = __version__ + self._exp_info["file"] = op.abspath(inspect.getfile(sys._getframe(3))) + self._exp_info["version_used"] = __version__ # session start dialog, if necessary - show_list = ['exp_name', 'date', 'file', 'participant', 'session'] - edit_list = ['participant', 'session'] # things editable in GUI + show_list = ["exp_name", "date", "file", "participant", "session"] + edit_list = ["participant", "session"] # things editable in GUI for key in show_list: value = self._exp_info[key] - if key in edit_list and value is not None and \ - not isinstance(value, string_types): - raise TypeError('{} must be string or None' - ''.format(value)) + if ( + key in edit_list + and value is not None + and not isinstance(value, str) + ): + raise TypeError(f"{value} must be string or None" "") if key in edit_list and value is None: - self._exp_info[key] = get_keyboard_input( - '{0}: '.format(key)) + self._exp_info[key] = get_keyboard_input(f"{key}: ") else: - print('{0}: {1}'.format(key, value)) + print(f"{key}: {value}") # # initialize log file @@ -235,92 +270,103 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, output_dir = op.abspath(output_dir) if not op.isdir(output_dir): os.mkdir(output_dir) - basename = op.join(output_dir, '{}_{}' - ''.format(self._exp_info['participant'], - self._exp_info['date'])) + basename = op.join( + output_dir, + "{}_{}" "".format( + self._exp_info["participant"], self._exp_info["date"] + ), + ) self._output_dir = basename - self._log_file = self._output_dir + '.log' + self._log_file = self._output_dir + ".log" set_log_file(self._log_file) closer = partial(set_log_file, None) # initialize data file - self._data_file = open(self._output_dir + '.tab', 'a') + self._data_file = open(self._output_dir + ".tab", "a") self._extra_cleanup_fun.append(self.flush) # flush self._extra_cleanup_fun.append(self._data_file.close) # close self._extra_cleanup_fun.append(closer) # un-set log file - self._data_file.write('# ' + json.dumps(self._exp_info) + '\n') - self.write_data_line('event', 'value', 'timestamp') - logger.info('Expyfun: Using version %s (requested %s)' - % (__version__, version)) + self._data_file.write("# " + json.dumps(self._exp_info) + "\n") + self.write_data_line("event", "value", "timestamp") + logger.info( + "Expyfun: Using version %s (requested %s)" % (__version__, version) + ) # # set up monitor # if safe_flipping is None: - safe_flipping = not (get_config('SAFE_FLIPPING', '').lower() == - 'false') + safe_flipping = not (get_config("SAFE_FLIPPING", "").lower() == "false") if not safe_flipping: - logger.warning('Expyfun: Unsafe flipping mode enabled, flip ' - 'timing not guaranteed') + logger.warning( + "Expyfun: Unsafe flipping mode enabled, flip " + "timing not guaranteed" + ) self.safe_flipping = safe_flipping if screen_num is None: - screen_num = int(get_config('SCREEN_NUM', '0')) + screen_num = int(get_config("SCREEN_NUM", "0")) display = _get_display() screen = display.get_screens()[screen_num] if monitor is None: mon_size = [screen.width, screen.height] - mon_size = ','.join([str(d) for d in mon_size]) + mon_size = ",".join([str(d) for d in mon_size]) monitor = dict() - width = float(get_config('SCREEN_WIDTH', '51.0')) - dist = float(get_config('SCREEN_DISTANCE', '48.0')) - monitor['SCREEN_WIDTH'] = width - monitor['SCREEN_DISTANCE'] = dist - mon_size = get_config('SCREEN_SIZE_PIX', mon_size).split(',') + width = float(get_config("SCREEN_WIDTH", "51.0")) + dist = float(get_config("SCREEN_DISTANCE", "48.0")) + monitor["SCREEN_WIDTH"] = width + monitor["SCREEN_DISTANCE"] = dist + mon_size = get_config("SCREEN_SIZE_PIX", mon_size).split(",") mon_size = [int(p) for p in mon_size] - monitor['SCREEN_SIZE_PIX'] = mon_size + monitor["SCREEN_SIZE_PIX"] = mon_size if not isinstance(monitor, dict): - raise TypeError('monitor must be a dict, got %r' % (monitor,)) - req_mon_keys = ['SCREEN_WIDTH', 'SCREEN_DISTANCE', - 'SCREEN_SIZE_PIX'] + raise TypeError("monitor must be a dict, got %r" % (monitor,)) + req_mon_keys = ["SCREEN_WIDTH", "SCREEN_DISTANCE", "SCREEN_SIZE_PIX"] missing_keys = [key for key in req_mon_keys if key not in monitor] if missing_keys: - raise KeyError('monitor is missing required keys {0}' - ''.format(missing_keys)) - mon_size = monitor['SCREEN_SIZE_PIX'] - monitor['SCREEN_DPI'] = (monitor['SCREEN_SIZE_PIX'][0] / - (monitor['SCREEN_WIDTH'] * 0.393701)) - monitor['SCREEN_HEIGHT'] = (monitor['SCREEN_WIDTH'] / - float(monitor['SCREEN_SIZE_PIX'][0]) * - float(monitor['SCREEN_SIZE_PIX'][1])) + raise KeyError(f"monitor is missing required keys {missing_keys}" "") + mon_size = monitor["SCREEN_SIZE_PIX"] + monitor["SCREEN_DPI"] = monitor["SCREEN_SIZE_PIX"][0] / ( + monitor["SCREEN_WIDTH"] * 0.393701 + ) + monitor["SCREEN_HEIGHT"] = ( + monitor["SCREEN_WIDTH"] + / float(monitor["SCREEN_SIZE_PIX"][0]) + * float(monitor["SCREEN_SIZE_PIX"][1]) + ) self._monitor = monitor # # parse audio controller # if audio_controller is None: - audio_controller = {'TYPE': get_config('AUDIO_CONTROLLER', - 'sound_card')} - elif isinstance(audio_controller, string_types): + audio_controller = { + "TYPE": get_config("AUDIO_CONTROLLER", "sound_card") + } + elif isinstance(audio_controller, str): # old option, backward compat / shortcut if audio_controller in _AUTO_BACKENDS: audio_controller = { - 'TYPE': 'sound_card', - 'SOUND_CARD_BACKEND': audio_controller} + "TYPE": "sound_card", + "SOUND_CARD_BACKEND": audio_controller, + } else: - audio_controller = {'TYPE': audio_controller.lower()} + audio_controller = {"TYPE": audio_controller.lower()} elif not isinstance(audio_controller, dict): - raise TypeError('audio_controller must be a str or dict, got ' - 'type %s' % (type(audio_controller),)) - audio_type = audio_controller['TYPE'].lower() + raise TypeError( + "audio_controller must be a str or dict, got " + "type %s" % (type(audio_controller),) + ) + audio_type = audio_controller["TYPE"].lower() # # parse response device # if response_device is None: - response_device = get_config('RESPONSE_DEVICE', 'keyboard') - if response_device not in ['keyboard', 'tdt', 'cedrus']: - raise ValueError('response_device must be "keyboard", "tdt", ' - '"cedrus", or None') + response_device = get_config("RESPONSE_DEVICE", "keyboard") + if response_device not in ["keyboard", "tdt", "cedrus"]: + raise ValueError( + 'response_device must be "keyboard", "tdt", ' '"cedrus", or None' + ) self._response_device = response_device # @@ -329,28 +375,40 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, trigger_duration = float(trigger_duration) if not 0.001 < trigger_duration <= 0.02: # probably an error - raise ValueError('high_duration must be between 0.001 and ' - '0.02 sec, got %s' % (trigger_duration,)) + raise ValueError( + "high_duration must be between 0.001 and " + "0.02 sec, got %s" % (trigger_duration,) + ) # Audio (and for TDT, potentially keyboard) - if audio_type == 'tdt': - logger.info('Expyfun: Setting up TDT') + if audio_type == "tdt": + logger.info("Expyfun: Setting up TDT") if n_channels != 2: - raise ValueError('n_channels must be equal to 2 for the ' - 'TDT backend, got %s' % (n_channels,)) + raise ValueError( + "n_channels must be equal to 2 for the " + "TDT backend, got %s" % (n_channels,) + ) if trigger_duration != 0.01: - raise ValueError('trigger_duration must be 0.01 for TDT, ' - 'got %s' % (trigger_duration,)) + raise ValueError( + "trigger_duration must be 0.01 for TDT, " + "got %s" % (trigger_duration,) + ) self._ac = TDTController(audio_controller, ec=self) self.audio_type = self._ac.model - elif audio_type == 'sound_card': + elif audio_type == "sound_card": self._ac = SoundCardController( - audio_controller, self.stim_fs, n_channels, - trigger_duration=trigger_duration, ec=self) + audio_controller, + self.stim_fs, + n_channels, + trigger_duration=trigger_duration, + ec=self, + ) self.audio_type = self._ac.backend_name else: - raise ValueError('audio_controller[\'TYPE\'] must be "tdt" ' - 'or "sound_card", got %r.' % (audio_type,)) + raise ValueError( + "audio_controller['TYPE'] must be \"tdt\" " + 'or "sound_card", got %r.' % (audio_type,) + ) del audio_type self._extra_cleanup_fun.insert(0, self._ac.halt) # audio scaling factor; ensure uniform intensity across devices @@ -358,46 +416,50 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, self.set_noise_db(self._noise_db) if self._fs_mismatch: - msg = ('Expyfun: Mismatch between reported stim sample ' - 'rate ({0}) and device sample rate ({1}). ' - .format(self.stim_fs, self.fs)) + msg = ( + "Expyfun: Mismatch between reported stim sample " + f"rate ({self.stim_fs}) and device sample rate ({self.fs}). " + ) if self._suppress_resamp: - msg += ('Nothing will be done about this because ' - 'suppress_resamp is "True"') + msg += ( + "Nothing will be done about this because " + 'suppress_resamp is "True"' + ) else: - msg += ('Experiment Controller will resample for you, but ' - 'this takes a non-trivial amount of processing ' - 'time and may compromise your experimental ' - 'timing and/or cause artifacts.') + msg += ( + "Experiment Controller will resample for you, but " + "this takes a non-trivial amount of processing " + "time and may compromise your experimental " + "timing and/or cause artifacts." + ) logger.warning(msg) # # set up visual window (must be done before keyboard and mouse) # - logger.info('Expyfun: Setting up screen') + logger.info("Expyfun: Setting up screen") if full_screen: if window_size is None: - window_size = monitor['SCREEN_SIZE_PIX'] + window_size = monitor["SCREEN_SIZE_PIX"] else: if window_size is None: - window_size = get_config('WINDOW_SIZE', - '800,600').split(',') + window_size = get_config("WINDOW_SIZE", "800,600").split(",") window_size = [int(w) for w in window_size] window_size = np.array(window_size) if window_size.ndim != 1 or window_size.size != 2: - raise ValueError('window_size must be 2-element array-like or ' - 'None') + raise ValueError("window_size must be 2-element array-like or " "None") # open window and setup GL config self._setup_window(window_size, exp_name, full_screen, screen) # Keyboard - if response_device == 'keyboard': + if response_device == "keyboard": self._response_handler = Keyboard(self, force_quit) - elif response_device == 'tdt': + elif response_device == "tdt": if not isinstance(self._ac, TDTController): - raise ValueError('response_device can only be "tdt" if ' - 'tdt is used for audio') + raise ValueError( + 'response_device can only be "tdt" if ' "tdt is used for audio" + ) self._response_handler = self._ac self._ac._add_keyboard_init(self, force_quit) else: # response_device == 'cedrus' @@ -415,67 +477,79 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, # self._ofp_critical_funs = list() if trigger_controller is None: - trigger_controller = get_config('TRIGGER_CONTROLLER', 'dummy') - if isinstance(trigger_controller, string_types): + trigger_controller = get_config("TRIGGER_CONTROLLER", "dummy") + if isinstance(trigger_controller, str): trigger_controller = dict(TYPE=trigger_controller) assert isinstance(trigger_controller, dict) trigger_controller = trigger_controller.copy() - known_keys = ('TYPE',) + known_keys = ("TYPE",) if set(trigger_controller) != set(known_keys): raise ValueError( - 'Unknown keys for trigger_controller, must be ' - f'{known_keys}, got {set(trigger_controller)}') - logger.info(f'Expyfun: Initializing {trigger_controller["TYPE"]} ' - 'triggering mode') - if trigger_controller['TYPE'] == 'tdt': + "Unknown keys for trigger_controller, must be " + f"{known_keys}, got {set(trigger_controller)}" + ) + logger.info( + f'Expyfun: Initializing {trigger_controller["TYPE"]} ' 'triggering mode' + ) + if trigger_controller["TYPE"] == "tdt": if not isinstance(self._ac, TDTController): - raise ValueError('trigger_controller can only be "tdt" if ' - 'tdt is used for audio') + raise ValueError( + 'trigger_controller can only be "tdt" if ' + "tdt is used for audio" + ) self._tc = self._ac - elif trigger_controller['TYPE'] == 'sound_card': + elif trigger_controller["TYPE"] == "sound_card": if not isinstance(self._ac, SoundCardController): - raise ValueError('trigger_controller can only be ' - '"sound_card" if the sound card is ' - 'used for audio') + raise ValueError( + "trigger_controller can only be " + '"sound_card" if the sound card is ' + "used for audio" + ) if self._ac._n_channels_stim == 0: - raise ValueError('cannot use sound card for triggering ' - 'when SOUND_CARD_TRIGGER_CHANNELS is ' - 'zero') + raise ValueError( + "cannot use sound card for triggering " + "when SOUND_CARD_TRIGGER_CHANNELS is " + "zero" + ) self._tc = self._ac - elif trigger_controller['TYPE'] in ['parallel', 'dummy']: + elif trigger_controller["TYPE"] in ["parallel", "dummy"]: addr = trigger_controller.get( - 'TRIGGER_ADDRESS', get_config('TRIGGER_ADDRESS', None)) + "TRIGGER_ADDRESS", get_config("TRIGGER_ADDRESS", None) + ) self._tc = ParallelTrigger( - trigger_controller['TYPE'], addr, - trigger_duration, ec=self) + trigger_controller["TYPE"], addr, trigger_duration, ec=self + ) self._extra_cleanup_fun.insert(0, self._tc.close) # The TDT always stamps "1" on stimulus onset. Here we need # to manually mimic that behavior. self._ofp_critical_funs.insert( - 0, lambda: self._stamp_ttl_triggers([1], False, False)) + 0, lambda: self._stamp_ttl_triggers([1], False, False) + ) else: - raise ValueError('trigger_controller type must be ' - '"parallel", "dummy", "sound_card", or "tdt",' - 'got {0}'.format(trigger_controller['TYPE'])) - self._id_call_dict['ttl_id'] = self._stamp_binary_id + raise ValueError( + "trigger_controller type must be " + '"parallel", "dummy", "sound_card", or "tdt",' + "got {0}".format(trigger_controller["TYPE"]) + ) + self._id_call_dict["ttl_id"] = self._stamp_binary_id # other basic components self._mouse_handler = Mouse(self) - t = np.arange(44100 // 3) / 44100. + t = np.arange(44100 // 3) / 44100.0 car = sum([np.sin(2 * np.pi * f * t) for f in [800, 1000, 1200]]) self._beep = None self._beep_data = np.tile(car * np.exp(-t * 10) / 4, (2, 3)) # finish initialization - logger.info('Expyfun: Initialization complete') - logger.exp('Expyfun: Participant: {0}' - ''.format(self._exp_info['participant'])) - logger.exp('Expyfun: Session: {0}' - ''.format(self._exp_info['session'])) - ok_log = partial(self.write_data_line, 'trial_ok', None) + logger.info("Expyfun: Initialization complete") + logger.exp( + "Expyfun: Participant: {0}" "".format(self._exp_info["participant"]) + ) + logger.exp("Expyfun: Session: {0}" "".format(self._exp_info["session"])) + ok_log = partial(self.write_data_line, "trial_ok", None) self._on_trial_ok.append(ok_log) self._on_trial_ok.append(self.flush) - self._trial_progress = 'stopped' + self._trial_progress = "stopped" except Exception: self.close() raise @@ -483,19 +557,28 @@ def __init__(self, exp_name, audio_controller=None, response_device=None, self.flip() def __repr__(self): - """Return a useful string representation of the experiment - """ - string = ('' - ''.format(self._exp_info['exp_name'], - self._exp_info['participant'], - self._exp_info['session'], - self.audio_type)) + """Return a useful string representation of the experiment""" + string = '' "".format( + self._exp_info["exp_name"], + self._exp_info["participant"], + self._exp_info["session"], + self.audio_type, + ) return string -# ############################### SCREEN METHODS ############################## - def screen_text(self, text, pos=[0, 0], color='white', font_name='Arial', - font_size=24, wrap=True, units='norm', attr=True, - log_data=True): + # ############################### SCREEN METHODS ############################## + def screen_text( + self, + text, + pos=(0, 0), + color="white", + font_name="Arial", + font_size=24, + wrap=True, + units="norm", + attr=True, + log_data=True, + ): """Show some text on the screen. Parameters @@ -534,18 +617,39 @@ def screen_text(self, text, pos=[0, 0], color='white', font_name='Arial', ExperimentController.screen_prompt """ check_units(units) - scr_txt = Text(self, text, pos, color, font_name, font_size, - wrap=wrap, units=units, attr=attr) + scr_txt = Text( + self, + text, + pos, + color, + font_name, + font_size, + wrap=wrap, + units=units, + attr=attr, + ) scr_txt.draw() if log_data: - self.call_on_next_flip(partial(self.write_data_line, 'screen_text', - text)) + self.call_on_next_flip(partial(self.write_data_line, "screen_text", text)) return scr_txt - def screen_prompt(self, text, max_wait=np.inf, min_wait=0, live_keys=None, - timestamp=False, clear_after=True, pos=[0, 0], - color='white', font_name='Arial', font_size=24, - wrap=True, units='norm', attr=True, click=False): + def screen_prompt( + self, + text, + max_wait=np.inf, + min_wait=0, + live_keys=None, + timestamp=False, + clear_after=True, + pos=(0, 0), + color="white", + font_name="Arial", + font_size=24, + wrap=True, + units="norm", + attr=True, + click=False, + ): """Display text and (optionally) wait for user continuation Parameters @@ -605,12 +709,19 @@ def screen_prompt(self, text, max_wait=np.inf, min_wait=0, live_keys=None, """ if not isinstance(text, list): text = [text] - if not all([isinstance(t, string_types) for t in text]): - raise TypeError('text must be a string or list of strings') + if not all(isinstance(t, str) for t in text): + raise TypeError("text must be a string or list of strings") for t in text: - self.screen_text(t, pos=pos, color=color, font_name=font_name, - font_size=font_size, wrap=wrap, units=units, - attr=attr) + self.screen_text( + t, + pos=pos, + color=color, + font_name=font_name, + font_size=font_size, + wrap=wrap, + units=units, + attr=attr, + ) self.flip() fun = self.wait_one_click if click else self.wait_one_press out = fun(max_wait, min_wait, live_keys, timestamp) @@ -618,7 +729,7 @@ def screen_prompt(self, text, max_wait=np.inf, min_wait=0, live_keys=None, self.flip() return out - def set_background_color(self, color='black'): + def set_background_color(self, color="black"): """Set and draw a solid background color Parameters @@ -635,10 +746,11 @@ def set_background_color(self, color='black'): appropriate background color. """ from pyglet import gl + # we go a little over here to be safe from round-off errors Rectangle(self, pos=[0, 0, 2.1, 2.1], fill_color=color).draw() self._bgcolor = _convert_color(color) - gl.glClearColor(*[c / 255. for c in self._bgcolor]) + gl.glClearColor(*[c / 255.0 for c in self._bgcolor]) def start_stimulus(self, start_of_trial=True, flip=True, when=None): """Play audio, (optionally) flip screen, run any "on_flip" functions. @@ -679,17 +791,19 @@ def start_stimulus(self, start_of_trial=True, flip=True, when=None): `call_on_next_flip` and `call_on_every_flip`. """ if start_of_trial: - if self._trial_progress != 'identified': - raise RuntimeError('Trial ID must be stamped before starting ' - 'the trial') - self._trial_progress = 'started' - extra = 'flipping screen and ' if flip else '' - logger.exp('Expyfun: Starting stimuli: {0}playing audio'.format(extra)) + if self._trial_progress != "identified": + raise RuntimeError( + "Trial ID must be stamped before starting " "the trial" + ) + self._trial_progress = "started" + extra = "flipping screen and " if flip else "" + logger.exp(f"Expyfun: Starting stimuli: {extra}playing audio") # ensure self._play comes first in list, followed by other critical # private functions (e.g., EL stamping), then user functions: if flip: - self._on_next_flip = ([self._play] + self._ofp_critical_funs + - self._on_next_flip) + self._on_next_flip = ( + [self._play] + self._ofp_critical_funs + self._on_next_flip + ) stimulus_time = self.flip(when) else: if when is not None: @@ -757,51 +871,56 @@ def _convert_units(self, verts, fro, to): check_units(fro) verts = np.array(np.atleast_2d(verts), dtype=float) if verts.shape[0] != 2: - raise RuntimeError('verts must have 2 rows') + raise RuntimeError("verts must have 2 rows") if fro == to: return verts # simplify by using two if neither is in normalized (native) units - if 'norm' not in [to, fro]: + if "norm" not in [to, fro]: # convert to normal - verts = self._convert_units(verts, fro, 'norm') + verts = self._convert_units(verts, fro, "norm") # convert from normal to dest - verts = self._convert_units(verts, 'norm', to) + verts = self._convert_units(verts, "norm", to) return verts # figure out our actual transition, knowing one is 'norm' win_w_pix, win_h_pix = self.window_size_pix mon_w_pix, mon_h_pix = self.monitor_size_pix - wh_cm = np.array([self._monitor['SCREEN_WIDTH'], - self._monitor['SCREEN_HEIGHT']], float) - d_cm = self._monitor['SCREEN_DISTANCE'] - cm_factors = (self.window_size_pix / self.monitor_size_pix * - wh_cm / 2.)[:, np.newaxis] - if 'pix' in [to, fro]: - if 'pix' == to: + wh_cm = np.array( + [self._monitor["SCREEN_WIDTH"], self._monitor["SCREEN_HEIGHT"]], float + ) + d_cm = self._monitor["SCREEN_DISTANCE"] + cm_factors = (self.window_size_pix / self.monitor_size_pix * wh_cm / 2.0)[ + :, np.newaxis + ] + if "pix" in [to, fro]: + if "pix" == to: # norm to pixels - x = np.array([[win_w_pix / 2., 0, win_w_pix / 2.], - [0, win_h_pix / 2., win_h_pix / 2.]]) + x = np.array( + [ + [win_w_pix / 2.0, 0, win_w_pix / 2.0], + [0, win_h_pix / 2.0, win_h_pix / 2.0], + ] + ) else: # pixels to norm - x = np.array([[2. / win_w_pix, 0, -1.], - [0, 2. / win_h_pix, -1.]]) + x = np.array([[2.0 / win_w_pix, 0, -1.0], [0, 2.0 / win_h_pix, -1.0]]) verts = np.dot(x, np.r_[verts, np.ones((1, verts.shape[1]))]) - elif 'deg' in [to, fro]: - if 'deg' == to: + elif "deg" in [to, fro]: + if "deg" == to: # norm (window) to norm (whole screen), then to deg verts = np.rad2deg(np.arctan2(verts * cm_factors, d_cm)) else: # deg to norm (whole screen), to norm (window) verts = (d_cm * np.tan(np.deg2rad(verts))) / cm_factors - elif 'cm' in [to, fro]: - if 'cm' == to: + elif "cm" in [to, fro]: + if "cm" == to: verts = verts * cm_factors else: verts = verts / cm_factors else: - raise KeyError('unknown conversion "{}" to "{}"'.format(fro, to)) + raise KeyError(f'unknown conversion "{fro}" to "{to}"') return verts def screenshot(self): @@ -817,9 +936,10 @@ def screenshot(self): """ import pyglet from PIL import Image + tempdir = _TempDir() - fname = op.join(tempdir, 'tmp.png') - with open(fname, 'wb') as fid: + fname = op.join(tempdir, "tmp.png") + with open(fname, "wb") as fid: pyglet.image.get_buffer_manager().get_color_buffer().save(file=fid) with Image.open(fname) as img: data = np.array(img) @@ -844,7 +964,7 @@ def window(self): @property def dpi(self): - return self._monitor['SCREEN_DPI'] + return self._monitor["SCREEN_DPI"] @property def window_size_pix(self): @@ -852,10 +972,10 @@ def window_size_pix(self): @property def monitor_size_pix(self): - return np.array(self._monitor['SCREEN_SIZE_PIX']) + return np.array(self._monitor["SCREEN_SIZE_PIX"]) -# ############################### VIDEO METHODS ############################### - def load_video(self, file_name, pos=(0, 0), units='norm', center=True): + # ############################### VIDEO METHODS ############################### + def load_video(self, file_name, pos=(0, 0), units="norm", center=True): """Load a video. Parameters @@ -877,26 +997,27 @@ def load_video(self, file_name, pos=(0, 0), units='norm', center=True): self.video = Video(self, file_name, pos, units) except MediaFormatException as exp: raise RuntimeError( - 'Something is wrong; probably you tried to load a ' - 'compressed video file but you do not have FFmpeg/Avbin ' - 'installed. Download and install it; if you are on Windows, ' - 'you may also need to manually copy the .dll file(s) ' - 'from C:\\Windows\\system32 to C:\\Windows\\SysWOW64.:\n%s' - % (exp,)) + "Something is wrong; probably you tried to load a " + "compressed video file but you do not have FFmpeg/Avbin " + "installed. Download and install it; if you are on Windows, " + "you may also need to manually copy the .dll file(s) " + "from C:\\Windows\\system32 to C:\\Windows\\SysWOW64.:\n%s" % (exp,) + ) def delete_video(self): """Delete the video.""" self.video._delete() self.video = None -# ############################### PYGLET EVENTS ############################### -# https://pyglet.readthedocs.io/en/latest/programming_guide/eventloop.html#dispatching-events-manually # noqa + # ############################### PYGLET EVENTS ############################### + # https://pyglet.readthedocs.io/en/latest/programming_guide/eventloop.html#dispatching-events-manually # noqa def _setup_event_loop(self): - from pyglet.app import platform_event_loop, event_loop + from pyglet.app import event_loop, platform_event_loop + event_loop.has_exit = False platform_event_loop.start() - event_loop.dispatch_event('on_enter') + event_loop.dispatch_event("on_enter") event_loop.is_running = True self._extra_cleanup_fun.append(self._end_event_loop) # This is when Pyglet calls: @@ -905,6 +1026,7 @@ def _setup_event_loop(self): def _dispatch_events(self): import pyglet + pyglet.clock.tick() self._win.dispatch_events() # timeout = self._event_loop.idle() @@ -912,27 +1034,40 @@ def _dispatch_events(self): pyglet.app.platform_event_loop.step(timeout) def _end_event_loop(self): - from pyglet.app import platform_event_loop, event_loop + from pyglet.app import event_loop, platform_event_loop + event_loop.is_running = False - event_loop.dispatch_event('on_exit') + event_loop.dispatch_event("on_exit") platform_event_loop.stop() -# ############################### OPENGL METHODS ############################## + # ############################### OPENGL METHODS ############################## def _setup_window(self, window_size, exp_name, full_screen, screen): import pyglet from pyglet import gl + # Use 16x sampling here - config_kwargs = dict(depth_size=8, double_buffer=True, stereo=False, - stencil_size=0, samples=0, sample_buffers=0) + config_kwargs = dict( + depth_size=8, + double_buffer=True, + stereo=False, + stencil_size=0, + samples=0, + sample_buffers=0, + ) # Travis can't handle multi-sampling, but our production machines must - if os.getenv('TRAVIS') == 'true': - del config_kwargs['samples'], config_kwargs['sample_buffers'] + if os.getenv("TRAVIS") == "true": + del config_kwargs["samples"], config_kwargs["sample_buffers"] self._full_screen = full_screen - win_kwargs = dict(width=int(window_size[0]), - height=int(window_size[1]), - caption=exp_name, fullscreen=False, - screen=screen, style='borderless', visible=False, - config=pyglet.gl.Config(**config_kwargs)) + win_kwargs = dict( + width=int(window_size[0]), + height=int(window_size[1]), + caption=exp_name, + fullscreen=False, + screen=screen, + style="borderless", + visible=False, + config=pyglet.gl.Config(**config_kwargs), + ) max_try = 5 # sometimes it fails for unknown reasons for ii in range(max_try): @@ -944,16 +1079,15 @@ def _setup_window(self, window_size, exp_name, full_screen, screen): else: break if not full_screen: - x = int(win.screen.width / 2. - win.width / 2.) - y = int(win.screen.height / 2. - win.height / 2.) + x = int(win.screen.width / 2.0 - win.width / 2.0) + y = int(win.screen.height / 2.0 - win.height / 2.0) win.set_location(x, y) self._win = win # with the context set up, do basic GL initialization gl.glClearColor(0.0, 0.0, 0.0, 1.0) # set the color to clear to gl.glClearDepth(1.0) # clear value for the depth buffer # set the viewport size - gl.glViewport(0, 0, int(self.window_size_pix[0]), - int(self.window_size_pix[1])) + gl.glViewport(0, 0, int(self.window_size_pix[0]), int(self.window_size_pix[1])) # set the projection matrix gl.glMatrixMode(gl.GL_PROJECTION) gl.glLoadIdentity() @@ -968,18 +1102,21 @@ def _setup_window(self, window_size, exp_name, full_screen, screen): gl.glBlendFunc(gl.GL_SRC_ALPHA, gl.GL_ONE_MINUS_SRC_ALPHA) gl.glShadeModel(gl.GL_SMOOTH) gl.glClear(gl.GL_COLOR_BUFFER_BIT) - v_ = False if os.getenv('_EXPYFUN_WIN_INVISIBLE') == 'true' else True + v_ = False if os.getenv("_EXPYFUN_WIN_INVISIBLE") == "true" else True self.set_visible(v_) # this is when we set fullscreen # ensure we got the correct window size got_size = win.get_size() if not np.array_equal(got_size, window_size): - raise RuntimeError('Window size requested by config (%s) does not ' - 'match obtained window size (%s), is the ' - 'screen resolution set incorrectly?' - % (window_size, got_size)) + raise RuntimeError( + "Window size requested by config (%s) does not " + "match obtained window size (%s), is the " + "screen resolution set incorrectly?" % (window_size, got_size) + ) self._dispatch_events() - logger.info('Initialized %s window on screen %s with DPI %0.2f' - % (window_size, screen, self.dpi)) + logger.info( + "Initialized %s window on screen %s with DPI %0.2f" + % (window_size, screen, self.dpi) + ) def flip(self, when=None): """Flip screen, then run any "on-flip" functions. @@ -1013,6 +1150,7 @@ def flip(self, when=None): `call_on_every_flip`. """ from pyglet import gl + if when is not None: self.wait_until(when) call_list = self._on_next_flip + self._on_every_flip @@ -1033,7 +1171,7 @@ def flip(self, when=None): flip_time = self.get_time() for function in call_list: function() - self.write_data_line('flip', flip_time) + self.write_data_line("flip", flip_time) self._on_next_flip = [] return flip_time @@ -1056,7 +1194,7 @@ def estimate_screen_fs(self, n_rep=10): """ n_rep = int(n_rep) times = [self.flip() for _ in range(n_rep)] - return 1. / np.median(np.diff(times[1:])) + return 1.0 / np.median(np.diff(times[1:])) def set_visible(self, visible=True, flip=True): """Set the window visibility @@ -1075,13 +1213,13 @@ def set_visible(self, visible=True, flip=True): """ self._win.set_fullscreen(self._full_screen) self._win.set_visible(visible) - logger.exp('Expyfun: Set screen visibility {0}'.format(visible)) + logger.exp(f"Expyfun: Set screen visibility {visible}") if visible and flip: self.flip() # it seems like newer Pyglet sometimes messes up without two flips self.flip() -# ############################## KEYPRESS METHODS ############################# + # ############################## KEYPRESS METHODS ############################# def listen_presses(self): """Start listening for keypresses. @@ -1093,8 +1231,14 @@ def listen_presses(self): """ self._response_handler.listen_presses() - def get_presses(self, live_keys=None, timestamp=True, relative_to=None, - kind='presses', return_kinds=False): + def get_presses( + self, + live_keys=None, + timestamp=True, + relative_to=None, + kind="presses", + return_kinds=False, + ): """Get the entire keyboard / button box buffer. This will also clear events that are not requested per ``type``. @@ -1140,10 +1284,17 @@ def get_presses(self, live_keys=None, timestamp=True, relative_to=None, ExperimentController.wait_for_presses """ return self._response_handler.get_presses( - live_keys, timestamp, relative_to, kind, return_kinds) - - def wait_one_press(self, max_wait=np.inf, min_wait=0.0, live_keys=None, - timestamp=True, relative_to=None): + live_keys, timestamp, relative_to, kind, return_kinds + ) + + def wait_one_press( + self, + max_wait=np.inf, + min_wait=0.0, + live_keys=None, + timestamp=True, + relative_to=None, + ): """Returns only the first button pressed after min_wait. Parameters @@ -1182,10 +1333,12 @@ def wait_one_press(self, max_wait=np.inf, min_wait=0.0, live_keys=None, ExperimentController.wait_for_presses """ return self._response_handler.wait_one_press( - max_wait, min_wait, live_keys, timestamp, relative_to) + max_wait, min_wait, live_keys, timestamp, relative_to + ) - def wait_for_presses(self, max_wait, min_wait=0.0, live_keys=None, - timestamp=True, relative_to=None): + def wait_for_presses( + self, max_wait, min_wait=0.0, live_keys=None, timestamp=True, relative_to=None + ): """Returns all button presses between min_wait and max_wait. Parameters @@ -1222,9 +1375,10 @@ def wait_for_presses(self, max_wait, min_wait=0.0, live_keys=None, ExperimentController.wait_one_press """ return self._response_handler.wait_for_presses( - max_wait, min_wait, live_keys, timestamp, relative_to) + max_wait, min_wait, live_keys, timestamp, relative_to + ) - def _log_presses(self, pressed, kind='key'): + def _log_presses(self, pressed, kind="key"): """Write key presses to data file.""" # This function will typically be called by self._response_handler # after it retrieves some button presses @@ -1235,10 +1389,18 @@ def check_force_quit(self): """Check to see if any force quit keys were pressed.""" self._response_handler.check_force_quit() - def text_input(self, stop_key='return', instruction_string='Type' - ' response below', pos=[0, 0], color='white', - font_name='Arial', font_size=24, wrap=True, units='norm', - all_caps=True): + def text_input( + self, + stop_key="return", + instruction_string="Type" " response below", + pos=(0, 0), + color="white", + font_name="Arial", + font_size=24, + wrap=True, + units="norm", + all_caps=True, + ): """Allows participant to input text and view on the screen. Parameters @@ -1271,28 +1433,34 @@ def text_input(self, stop_key='return', instruction_string='Type' text : str The final input string. """ - letters = string.ascii_letters + ' ' - text = str() + letters = string.ascii_letters + " " + text = "" while True: - self.screen_text(instruction_string + '\n\n' + text + '|', - pos=pos, color=color, - font_name=font_name, font_size=font_size, - wrap=wrap, units=units, log_data=False) + self.screen_text( + instruction_string + "\n\n" + text + "|", + pos=pos, + color=color, + font_name=font_name, + font_size=font_size, + wrap=wrap, + units=units, + log_data=False, + ) self.flip() letter = self.wait_one_press(timestamp=False) if letter == stop_key: self.flip() break - if letter == 'backspace': + if letter == "backspace": text = text[:-1] else: - letter = ' ' if letter == 'space' else letter + letter = " " if letter == "space" else letter letter = letter.upper() if all_caps else letter - text += letter if letter in letters else '' - self.write_data_line('text_input', text) + text += letter if letter in letters else "" + self.write_data_line("text_input", text) return text -# ############################## KEYPRESS METHODS ############################# + # ############################## KEYPRESS METHODS ############################# def listen_joystick_button_presses(self): """Start listening for joystick buttons. @@ -1302,8 +1470,9 @@ def listen_joystick_button_presses(self): """ self._joystick_handler.listen_presses() - def get_joystick_button_presses(self, timestamp=True, relative_to=None, - kind='presses', return_kinds=False): + def get_joystick_button_presses( + self, timestamp=True, relative_to=None, kind="presses", return_kinds=False + ): """Get the entire joystick buffer. This will also clear events that are not requested per ``type``. @@ -1338,7 +1507,8 @@ def get_joystick_button_presses(self, timestamp=True, relative_to=None, """ self._dispatch_events() return self._joystick_handler.get_presses( - None, timestamp, relative_to, kind, return_kinds) + None, timestamp, relative_to, kind, return_kinds + ) def get_joystick_value(self, kind): """Get the current joystick x direction. @@ -1355,7 +1525,7 @@ def get_joystick_value(self, kind): """ return getattr(self._joystick_handler, kind) -# ############################## MOUSE METHODS ################################ + # ############################## MOUSE METHODS ################################ def listen_clicks(self): """Start listening for mouse clicks. @@ -1401,10 +1571,9 @@ def get_clicks(self, live_buttons=None, timestamp=True, relative_to=None): ExperimentController.wait_one_click ExperimentController.wait_for_clicks """ - return self._mouse_handler.get_clicks(live_buttons, timestamp, - relative_to) + return self._mouse_handler.get_clicks(live_buttons, timestamp, relative_to) - def get_mouse_position(self, units='pix'): + def get_mouse_position(self, units="pix"): """Mouse position in screen coordinates Parameters @@ -1427,7 +1596,7 @@ def get_mouse_position(self, units='pix'): """ check_units(units) pos = np.array(self._mouse_handler.pos) - pos = self._convert_units(pos[:, np.newaxis], 'norm', units)[:, 0] + pos = self._convert_units(pos[:, np.newaxis], "norm", units)[:, 0] return pos def toggle_cursor(self, visibility, flip=False): @@ -1456,8 +1625,15 @@ def toggle_cursor(self, visibility, flip=False): if flip: self.flip() - def wait_one_click(self, max_wait=np.inf, min_wait=0.0, live_buttons=None, - timestamp=True, relative_to=None, visible=None): + def wait_one_click( + self, + max_wait=np.inf, + min_wait=0.0, + live_buttons=None, + timestamp=True, + relative_to=None, + visible=None, + ): """Returns only the first mouse button clicked after min_wait. Parameters @@ -1501,11 +1677,11 @@ def wait_one_click(self, max_wait=np.inf, min_wait=0.0, live_buttons=None, ExperimentController.toggle_cursor ExperimentController.wait_for_clicks """ - return self._mouse_handler.wait_one_click(max_wait, min_wait, - live_buttons, timestamp, - relative_to, visible) + return self._mouse_handler.wait_one_click( + max_wait, min_wait, live_buttons, timestamp, relative_to, visible + ) - def move_mouse_to(self, pos, units='norm'): + def move_mouse_to(self, pos, units="norm"): """Move the mouse position to the specified position. Parameters @@ -1517,8 +1693,15 @@ def move_mouse_to(self, pos, units='norm'): """ self._mouse_handler._move_to(pos, units) - def wait_for_clicks(self, max_wait=np.inf, min_wait=0.0, live_buttons=None, - timestamp=True, relative_to=None, visible=None): + def wait_for_clicks( + self, + max_wait=np.inf, + min_wait=0.0, + live_buttons=None, + timestamp=True, + relative_to=None, + visible=None, + ): """Returns all clicks between min_wait and max_wait. Parameters @@ -1561,12 +1744,19 @@ def wait_for_clicks(self, max_wait=np.inf, min_wait=0.0, live_buttons=None, ExperimentController.toggle_cursor ExperimentController.wait_one_click """ - return self._mouse_handler.wait_for_clicks(max_wait, min_wait, - live_buttons, timestamp, - relative_to, visible) - - def wait_for_click_on(self, objects, max_wait=np.inf, min_wait=0.0, - live_buttons=None, timestamp=True, relative_to=None): + return self._mouse_handler.wait_for_clicks( + max_wait, min_wait, live_buttons, timestamp, relative_to, visible + ) + + def wait_for_click_on( + self, + objects, + max_wait=np.inf, + min_wait=0.0, + live_buttons=None, + timestamp=True, + relative_to=None, + ): """Returns the first click after min_wait over a visual object. Parameters @@ -1608,21 +1798,19 @@ def wait_for_click_on(self, objects, max_wait=np.inf, min_wait=0.0, if isinstance(objects, legal_types): objects = [objects] elif not isinstance(objects, list): - raise TypeError('objects must be a list or one of: %s' % - (legal_types,)) + raise TypeError("objects must be a list or one of: %s" % (legal_types,)) return self._mouse_handler.wait_for_click_on( - objects, max_wait, min_wait, live_buttons, timestamp, relative_to) + objects, max_wait, min_wait, live_buttons, timestamp, relative_to + ) def _log_clicks(self, clicked): - """Write mouse clicks to data file. - """ + """Write mouse clicks to data file.""" # This function will typically be called by self._response_handler # after it retrieves some mouse clicks for button, x, y, stamp in clicked: - self.write_data_line('mouseclick', '%s,%i,%i' % (button, x, y), - stamp) + self.write_data_line("mouseclick", "%s,%i,%i" % (button, x, y), stamp) -# ############################## AUDIO METHODS ################################ + # ############################## AUDIO METHODS ################################ def system_beep(self): """Play a system beep @@ -1674,13 +1862,13 @@ def load_buffer(self, samples): ExperimentController.stop """ if self._playing: - raise RuntimeError('Previous audio must be stopped before loading ' - 'the buffer') + raise RuntimeError( + "Previous audio must be stopped before loading " "the buffer" + ) samples = self._validate_audio(samples) - if not np.isclose(self._stim_scaler, 1.): + if not np.isclose(self._stim_scaler, 1.0): samples = samples * self._stim_scaler - logger.exp('Expyfun: Loading {} samples to buffer' - ''.format(samples.size)) + logger.exp(f"Expyfun: Loading {samples.size} samples to buffer" "") self._ac.load_buffer(samples) def play(self): @@ -1698,28 +1886,33 @@ def play(self): ExperimentController.start_stimulus ExperimentController.stop """ - logger.exp('Expyfun: Playing audio') + logger.exp("Expyfun: Playing audio") # ensure self._play comes first in list: self._play() return self.get_time() def _play(self): - """Play the audio buffer. - """ + """Play the audio buffer.""" if self._playing: - raise RuntimeError('Previous audio must be stopped before playing') + raise RuntimeError("Previous audio must be stopped before playing") self._ac.play() - logger.debug('Expyfun: started audio') - self.write_data_line('play') + logger.debug("Expyfun: started audio") + self.write_data_line("play") @property def _playing(self): """Whether or not a stimulus is currently playing""" return self._ac.playing - def stop(self): + def stop(self, wait=False): """Stop audio buffer playback and reset cursor to beginning of buffer + Parameters + ---------- + wait : bool + If True, try to wait until the end of the sound stimulus + (not guaranteed to yield accurate timings!). + See Also -------- ExperimentController.load_buffer @@ -1728,9 +1921,9 @@ def stop(self): ExperimentController.start_stimulus """ if self._ac is not None: # need to check b/c used in __exit__ - self._ac.stop() - self.write_data_line('stop') - logger.exp('Expyfun: Audio stopped and reset.') + self._ac.stop(wait=wait) + self.write_data_line("stop") + logger.exp("Expyfun: Audio stopped and reset.") def set_noise_db(self, new_db): """Set the level of the background noise. @@ -1769,10 +1962,9 @@ def set_stim_db(self, new_db): # not immediate: new value is applied on the next load_buffer call def _update_sound_scaler(self, desired_db, orig_rms): - """Calcs coefficient ensuring stim ampl equivalence across devices. - """ - exponent = (-(_get_dev_db(self.audio_type) - desired_db) / 20.0) - return (10 ** exponent) / float(orig_rms) + """Calcs coefficient ensuring stim ampl equivalence across devices.""" + exponent = -(_get_dev_db(self.audio_type) - desired_db) / 20.0 + return (10**exponent) / float(orig_rms) def _validate_audio(self, samples): """Converts audio sample data to the required format. @@ -1793,7 +1985,7 @@ def _validate_audio(self, samples): # check values if samples.size and np.max(np.abs(samples)) > 1: - raise ValueError('Sound data exceeds +/- 1.') + raise ValueError("Sound data exceeds +/- 1.") # samples /= np.max(np.abs(samples),axis=0) # check shape and dimensions, make stereo @@ -1804,52 +1996,60 @@ def _validate_audio(self, samples): if np.isclose(self.stim_fs, 24414, atol=1): max_samples = 4000000 - 1 if samples.shape[0] > max_samples: - raise RuntimeError('Sample too long {0} > {1}' - ''.format(samples.shape[0], max_samples)) + raise RuntimeError( + f"Sample too long {samples.shape[0]} > {max_samples}" "" + ) # resample if needed if self._fs_mismatch and not self._suppress_resamp: - logger.warning('Expyfun: Resampling {} seconds of audio' - ''.format(round(len(samples) / self.stim_fs, 2))) + logger.warning( + f"Expyfun: Resampling {round(len(samples) / self.stim_fs, 2)} " + "seconds of audio" + ) with warnings.catch_warnings(record=True): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") from mne.filter import resample if samples.size: samples = resample( - samples.astype(np.float64), self.fs, self.stim_fs, - axis=0).astype(np.float32) + samples.astype(np.float64), self.fs, self.stim_fs, axis=0 + ).astype(np.float32) # check RMS if self._check_rms is not None and samples.size: chans = [samples[:, x] for x in range(samples.shape[1])] - if self._check_rms == 'wholefile': - chan_rms = [np.sqrt(np.mean(x ** 2)) for x in chans] + if self._check_rms == "wholefile": + chan_rms = [np.sqrt(np.mean(x**2)) for x in chans] max_rms = max(chan_rms) else: # 'windowed' # ~226 sec at 44100 Hz if samples.size >= _SLOW_LIMIT and not self._slow_rms_warned: warnings.warn( - 'Checking RMS with a 10 ms window and many samples is ' - 'slow, consider using None or "wholefile" modes.') + "Checking RMS with a 10 ms window and many samples is " + 'slow, consider using None or "wholefile" modes.' + ) self._slow_rms_warned = True win_length = int(self.fs * 0.01) # 10ms running window max_rms = [running_rms(x, win_length).max() for x in chans] max_rms = max(max_rms) if max_rms > 2 * self._stim_rms: - warn_string = ('Expyfun: Stimulus max RMS ({}) exceeds stated ' - 'RMS ({}) by more than 6 dB.' - ''.format(max_rms, self._stim_rms)) + warn_string = ( + f"Expyfun: Stimulus max RMS ({max_rms}) exceeds stated " + f"RMS ({self._stim_rms}) by more than 6 dB." + "" + ) logger.warning(warn_string) warnings.warn(warn_string) elif max_rms < 0.5 * self._stim_rms: - warn_string = ('Expyfun: Stimulus max RMS ({}) is less than ' - 'stated RMS ({}) by more than 6 dB.' - ''.format(max_rms, self._stim_rms)) + warn_string = ( + f"Expyfun: Stimulus max RMS ({max_rms}) is less than " + f"stated RMS ({self._stim_rms}) by more than 6 dB." + "" + ) logger.warning(warn_string) # let's make sure we don't change our version of this array later samples = samples.view() - samples.flags['WRITEABLE'] = False + samples.flags["WRITEABLE"] = False return samples def set_rms_checking(self, check_rms): @@ -1864,24 +2064,25 @@ def set_rms_checking(self, check_rms): ``stim_rms``. ``'wholefile'`` checks the RMS of the stimulus as a whole, while ``None`` disables RMS checking. """ - if check_rms not in [None, 'wholefile', 'windowed']: - raise ValueError('check_rms must be one of "wholefile", "windowed"' - ', or None.') + if check_rms not in [None, "wholefile", "windowed"]: + raise ValueError( + 'check_rms must be one of "wholefile", "windowed"' ", or None." + ) self._slow_rms_warned = False self._check_rms = check_rms -# ############################## OTHER METHODS ################################ + # ############################## OTHER METHODS ################################ @property def participant(self): - return self._exp_info['participant'] + return self._exp_info["participant"] @property def session(self): - return self._exp_info['session'] + return self._exp_info["session"] @property def exp_name(self): - return self._exp_info['exp_name'] + return self._exp_info["exp_name"] @property def data_fname(self): @@ -1917,35 +2118,38 @@ def write_data_line(self, event_type, value=None, timestamp=None): """ if timestamp is None: timestamp = self._master_clock() - ll = '\t'.join(_sanitize(x) for x in [timestamp, event_type, - value]) + '\n' + ll = "\t".join(_sanitize(x) for x in [timestamp, event_type, value]) + "\n" if self._data_file is not None: if self._data_file.closed: - logger.warning('Data line not written due to closed file %s:\n' - '%s' % (self.data_fname, ll[:-1])) + logger.warning( + "Data line not written due to closed file %s:\n" + "%s" % (self.data_fname, ll[:-1]) + ) else: self._data_file.write(ll) self.flush() def _get_time_correction(self, clock_type): - """Clock correction (sec) for different devices (screen, bbox, etc.) - """ - time_correction = (self._master_clock() - - self._time_correction_fxns[clock_type]()) + """Clock correction (sec) for different devices (screen, bbox, etc.)""" + time_correction = ( + self._master_clock() - self._time_correction_fxns[clock_type]() + ) if clock_type not in self._time_corrections: self._time_corrections[clock_type] = time_correction diff = time_correction - self._time_corrections[clock_type] max_dt = self._time_correction_maxs.get(clock_type, 50e-6) if np.abs(diff) > max_dt: - logger.warning('Expyfun: drift of > {} microseconds ({}) ' - 'between {} clock and EC master clock.' - ''.format(max_dt * 1e6, int(round(diff * 1e6)), - clock_type)) - logger.debug('Expyfun: time correction between {} clock and EC ' - 'master clock is {}. This is a change of {}.' - ''.format(clock_type, time_correction, time_correction - - self._time_corrections[clock_type])) + logger.warning( + f"Expyfun: drift of > {max_dt * 1e6} microseconds " + f"({int(round(diff * 1e6))}) " + f"between {clock_type} clock and EC master clock." + ) + logger.debug( + f"Expyfun: time correction between {clock_type} clock and EC " + f"master clock is {time_correction}. This is a change of " + f"{time_correction - self._time_corrections[clock_type]}." + ) return time_correction def wait_secs(self, secs): @@ -1991,9 +2195,11 @@ def wait_until(self, timestamp): """ time_left = timestamp - self._master_clock() if time_left < 0: - logger.warning('Expyfun: wait_until was called with a timestamp ' - '({}) that had already passed {} seconds prior.' - ''.format(timestamp, -time_left)) + logger.warning( + "Expyfun: wait_until was called with a timestamp " + f"({timestamp}) that had already passed {-time_left} seconds prior." + "" + ) else: self.wait_secs(time_left) return time_left @@ -2018,22 +2224,23 @@ def identify_trial(self, **ids): ExperimentController.stop ExperimentController.trial_ok """ - if self._trial_progress != 'stopped': - raise RuntimeError('Cannot identify a trial twice') + if self._trial_progress != "stopped": + raise RuntimeError("Cannot identify a trial twice") call_set = set(self._id_call_dict.keys()) passed_set = set(ids.keys()) if not call_set == passed_set: - raise KeyError('All keys passed in {0} must match the set of ' - 'keys required {1}'.format(passed_set, call_set)) + raise KeyError( + f"All keys passed in {passed_set} must match the set of " + f"keys required {call_set}" + ) ll = max([len(key) for key in ids.keys()]) for key, id_ in ids.items(): - logger.exp('Expyfun: Stamp trial ID to {0} : {1}' - ''.format(key.ljust(ll), str(id_))) + logger.exp(f"Expyfun: Stamp trial ID to {key.ljust(ll)} : {str(id_)}" "") if isinstance(id_, dict): self._id_call_dict[key](**id_) else: self._id_call_dict[key](id_) - self._trial_progress = 'identified' + self._trial_progress = "identified" def trial_ok(self): """Report that the trial was okay and do post-trial tasks. @@ -2047,19 +2254,21 @@ def trial_ok(self): ExperimentController.start_stimulus ExperimentController.stop """ - if self._trial_progress != 'started': - raise RuntimeError('trial cannot be okay unless it was started, ' - 'did you call ec.start_stimulus?') + if self._trial_progress != "started": + raise RuntimeError( + "trial cannot be okay unless it was started, " + "did you call ec.start_stimulus?" + ) if self._playing: - logger.warning('ec.trial_ok called before stimulus had stopped') + logger.warning("ec.trial_ok called before stimulus had stopped") for func in self._on_trial_ok: func() - logger.exp('Expyfun: Trial OK') - self._trial_progress = 'stopped' + logger.exp("Expyfun: Trial OK") + self._trial_progress = "stopped" def _stamp_ec_id(self, id_): """Stamp id -- currently anything allowed""" - self.write_data_line('trial_id', id_) + self.write_data_line("trial_id", id_) def _stamp_binary_id(self, id_, wait_for_last=True): """Helper for ec to stamp a set of IDs using binary controller @@ -2069,14 +2278,14 @@ def _stamp_binary_id(self, id_, wait_for_last=True): but for now it's unified. ``delay`` is the inter-trigger delay. """ if not isinstance(id_, (list, tuple, np.ndarray)): - raise TypeError('id must be array-like') + raise TypeError("id must be array-like") id_ = np.array(id_) - if not np.all(np.in1d(id_, [0, 1])): - raise ValueError('All values of id must be 0 or 1') + if not np.all(np.isin(id_, [0, 1])): + raise ValueError("All values of id must be 0 or 1") id_ = (id_.astype(int) + 1) << 2 # 0, 1 -> 4, 8 self._stamp_ttl_triggers(id_, wait_for_last, True) - def stamp_triggers(self, ids, check='binary', wait_for_last=True): + def stamp_triggers(self, ids, check="binary", wait_for_last=True): """Stamp binary values Parameters @@ -2100,22 +2309,24 @@ def stamp_triggers(self, ids, check='binary', wait_for_last=True): -------- ExperimentController.identify_trial """ - if check not in ('int4', 'binary'): + if check not in ("int4", "binary"): raise ValueError('Check must be either "int4" or "binary"') ids = [ids] if not isinstance(ids, list) else ids if not all(isinstance(id_, int) and 1 <= id_ <= 15 for id_ in ids): - raise ValueError('ids must all be integers between 1 and 15') - if check == 'binary': + raise ValueError("ids must all be integers between 1 and 15") + if check == "binary": _vals = [1, 2, 4, 8] if not all(id_ in _vals for id_ in ids): - raise ValueError('with check="binary", ids must all be ' - '1, 2, 4, or 8: {0}'.format(ids)) + raise ValueError( + 'with check="binary", ids must all be ' f"1, 2, 4, or 8: {ids}" + ) self._stamp_ttl_triggers(ids, wait_for_last, False) def _stamp_ttl_triggers(self, ids, wait_for_last, is_trial_id): - logger.exp('Stamping TTL triggers: %s', ids) + logger.exp("Stamping TTL triggers: %s", ids) self._tc.stamp_triggers( - ids, wait_for_last=wait_for_last, is_trial_id=is_trial_id) + ids, wait_for_last=wait_for_last, is_trial_id=is_trial_id + ) self.flush() def flush(self): @@ -2125,24 +2336,22 @@ def flush(self): self._data_file.flush() def close(self): - """Close all connections in experiment controller. - """ + """Close all connections in experiment controller.""" self.__exit__(None, None, None) def __enter__(self): - logger.debug('Expyfun: Entering') + logger.debug("Expyfun: Entering") return self def __exit__(self, err_type, value, traceback): - """ - Notes - ----- + """Exit cleanly. + err_type, value and traceback will be None when called by self.close() """ - logger.info('Expyfun: Exiting') + logger.info("Expyfun: Exiting") # do external cleanups cleanup_actions = [] - if hasattr(self, '_win'): + if hasattr(self, "_win"): cleanup_actions.append(self._win.close) cleanup_actions.extend([self.stop_noise, self.stop]) cleanup_actions.extend(self._extra_cleanup_fun) @@ -2171,66 +2380,63 @@ def refocus(self): This function currently does nothing on Linux and OSX. """ # noqa: E501 - if sys.platform == 'win32': + if sys.platform == "win32": from pyglet.libs.win32 import _user32 + m_hWnd = self._win._hwnd hCurWnd = _user32.GetForegroundWindow() if hCurWnd != m_hWnd: + # m_hWnd, HWND_TOPMOST, ..., SWP_NOSIZE | SWP_NOMOVE + _user32.SetWindowPos(m_hWnd, -1, 0, 0, 0, 0, 0x0001 | 0x0002) dwMyID = _user32.GetWindowThreadProcessId(m_hWnd, 0) dwCurID = _user32.GetWindowThreadProcessId(hCurWnd, 0) _user32.AttachThreadInput(dwCurID, dwMyID, True) - # _user32.SetWindowPos(m_hWnd, HWND_TOPMOST, 0, 0, 0, 0, - # SWP_NOSIZE | SWP_NOMOVE) - # _user32.SetWindowPos(m_hWnd, HWND_NOTOPMOST, 0, 0, 0, 0, - # SWP_NOSIZE | SWP_NOMOVE) self._win.activate() # _user32.SetForegroundWindow(m_hWnd) _user32.AttachThreadInput(dwCurID, dwMyID, False) _user32.SetFocus(m_hWnd) _user32.SetActiveWindow(m_hWnd) -# ############################## READ-ONLY PROPERTIES ######################### + # ############################## READ-ONLY PROPERTIES ######################### @property def id_types(self): - """Trial ID types needed for each trial. - """ + """Trial ID types needed for each trial.""" return sorted(self._id_call_dict.keys()) @property def fs(self): - """Playback frequency of the audio controller (samples / second). - """ + """Playback frequency of the audio controller (samples / second).""" return self._ac.fs # not user-settable @property def stim_fs(self): - """Sampling rate at which the stimuli were generated. - """ + """Sampling rate at which the stimuli were generated.""" return self._stim_fs # not user-settable @property def stim_db(self): - """Sound power in dB of the stimuli. - """ + """Sound power in dB of the stimuli.""" return self._stim_db # not user-settable @property def noise_db(self): - """Sound power in dB of the background noise. - """ + """Sound power in dB of the background noise.""" return self._noise_db # not user-settable @property def current_time(self): - """Timestamp from the experiment master clock. - """ + """Timestamp from the experiment master clock.""" return self._master_clock() @property def _fs_mismatch(self): - """Quantify if sample rates substantively differ. - """ + """Quantify if sample rates substantively differ.""" return not np.allclose(self.stim_fs, self.fs, rtol=0, atol=0.5) + # Testing cruft to work around "queue full" errors on Windows + def _ac_flush(self): + if isinstance(getattr(self, "_ac", None), SoundCardController): + self._ac.halt() + def get_keyboard_input(prompt, default=None, out_type=str, valid=None): """Get keyboard input of a specific type @@ -2258,11 +2464,11 @@ def get_keyboard_input(prompt, default=None, out_type=str, valid=None): # pass a lambda, e.g., that made sure a float was in a given range # TODO: add tests if not isinstance(out_type, type): - raise TypeError('out_type must be a type') + raise TypeError("out_type must be a type") good = False while not good: response = input(prompt) - if response == '' and default is not None: + if response == "" and default is not None: response = default try: response = out_type(response) @@ -2276,27 +2482,28 @@ def get_keyboard_input(prompt, default=None, out_type=str, valid=None): def _get_dev_db(audio_controller): - """Selects device-specific amplitude to ensure equivalence across devices. - """ + """Selects device-specific amplitude to ensure equivalence across devices.""" # First try to get the level from the expyfun.json file. - level = get_config('DB_OF_SINE_AT_1KHZ_1RMS') + level = get_config("DB_OF_SINE_AT_1KHZ_1RMS") if level is None: level = dict( - RM1=108., # approx w/ knob @ 12 o'clock (knob not detented) - RP2=108., - RP2legacy=108., - RZ6=114., + RM1=108.0, # approx w/ knob @ 12 o'clock (knob not detented) + RP2=108.0, + RP2legacy=108.0, + RZ6=114.0, # TODO: these values not calibrated, system-dependent - pyglet=100., - rtmixer=100., - dummy=100., # only used for testing + pyglet=100.0, + rtmixer=100.0, + dummy=100.0, # only used for testing ).get(audio_controller, None) else: level = float(level) if level is None: - logger.warning('Expyfun: Unknown audio controller %s: stim scaler may ' - 'not work correctly. You may want to remove your ' - 'headphones if this is the first run of your ' - 'experiment.' % (audio_controller,)) + logger.warning( + "Expyfun: Unknown audio controller %s: stim scaler may " + "not work correctly. You may want to remove your " + "headphones if this is the first run of your " + "experiment." % (audio_controller,) + ) level = 100 # for untested TDT models return level diff --git a/expyfun/_externals/__init__.py b/expyfun/_externals/__init__.py deleted file mode 100644 index f6748740..00000000 --- a/expyfun/_externals/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- coding: utf-8 -*- - -from .decorator import decorator # noqa -from ._h5io import read_hdf5, write_hdf5 # noqa, analysis:ignore diff --git a/expyfun/_externals/_h5io.py b/expyfun/_externals/_h5io.py deleted file mode 100644 index 0130dff3..00000000 --- a/expyfun/_externals/_h5io.py +++ /dev/null @@ -1,425 +0,0 @@ -# -*- coding: utf-8 -*- -# Authors: Eric Larson -# -# License: BSD (3-clause) - -import sys -import tempfile -from shutil import rmtree -from os import path as op - -import numpy as np -try: - from scipy import sparse -except ImportError: - sparse = None - -# Adapted from six -PY3 = sys.version_info[0] == 3 -text_type = str if PY3 else unicode # noqa -string_types = str if PY3 else basestring # noqa - -special_chars = {'{FWDSLASH}': '/'} - - -############################################################################## -# WRITING - -def _check_h5py(): - """Helper to check if h5py is installed""" - try: - import h5py - except ImportError: - raise ImportError('the h5py module is required to use HDF5 I/O') - return h5py - - -def _create_titled_group(root, key, title): - """Helper to create a titled group in h5py""" - out = root.create_group(key) - out.attrs['TITLE'] = title - return out - - -def _create_titled_dataset(root, key, title, data, comp_kw=None): - """Helper to create a titled dataset in h5py""" - comp_kw = {} if comp_kw is None else comp_kw - out = root.create_dataset(key, data=data, **comp_kw) - out.attrs['TITLE'] = title - return out - - -def _create_pandas_dataset(fname, root, key, title, data): - h5py = _check_h5py() - rootpath = '/'.join([root, key]) - data.to_hdf(fname, rootpath) - with h5py.File(fname, mode='a') as fid: - fid[rootpath].attrs['TITLE'] = 'pd_dataframe' - - -def write_hdf5(fname, data, overwrite=False, compression=4, - title='h5io', slash='error'): - """Write python object to HDF5 format using h5py - - Parameters - ---------- - fname : str - Filename to use. - data : object - Object to write. Can be of any of these types: - {ndarray, dict, list, tuple, int, float, str} - Note that dict objects must only have ``str`` keys. It is recommended - to use ndarrays where possible, as it is handled most efficiently. - overwrite : True | False | 'update' - If True, overwrite file (if it exists). If 'update', appends the title - to the file (or replace value if title exists). - compression : int - Compression level to use (0-9) to compress data using gzip. - title : str - The top-level directory name to use. Typically it is useful to make - this your package name, e.g. ``'mnepython'``. - slash : 'error' | 'replace' - Whether to replace forward-slashes ('/') in any key found nested within - keys in data. This does not apply to the top level name (title). - If 'error', '/' is not allowed in any lower-level keys. - """ - h5py = _check_h5py() - mode = 'w' - if op.isfile(fname): - if isinstance(overwrite, string_types): - if overwrite != 'update': - raise ValueError('overwrite must be "update" or a bool') - mode = 'a' - elif not overwrite: - raise IOError('file "%s" exists, use overwrite=True to overwrite' - % fname) - if not isinstance(title, string_types): - raise ValueError('title must be a string') - comp_kw = dict() - if compression > 0: - comp_kw = dict(compression='gzip', compression_opts=compression) - with h5py.File(fname, mode=mode) as fid: - if title in fid: - del fid[title] - cleanup_data = [] - _triage_write(title, data, fid, comp_kw, str(type(data)), - cleanup_data=cleanup_data, slash=slash, title=title) - - # Will not be empty if any extra data to be written - for data in cleanup_data: - # In case different extra I/O needs different inputs - title = list(data.keys())[0] - if title in ['pd_dataframe', 'pd_series']: - rootname, key, value = data[title] - _create_pandas_dataset(fname, rootname, key, title, value) - - -def _triage_write(key, value, root, comp_kw, where, - cleanup_data=[], slash='error', title=None): - if key != title and '/' in key: - if slash == 'error': - raise ValueError('Found a key with "/", ' - 'this is not allowed if slash == error') - elif slash == 'replace': - # Auto-replace keys with proper values - for key_spec, val_spec in special_chars.items(): - key = key.replace(val_spec, key_spec) - else: - raise ValueError("slash must be one of ['error', 'replace'") - - if isinstance(value, dict): - sub_root = _create_titled_group(root, key, 'dict') - for key, sub_value in value.items(): - if not isinstance(key, string_types): - raise TypeError('All dict keys must be strings') - _triage_write( - 'key_{0}'.format(key), sub_value, sub_root, comp_kw, - where + '["%s"]' % key, cleanup_data=cleanup_data, slash=slash) - elif isinstance(value, (list, tuple)): - title = 'list' if isinstance(value, list) else 'tuple' - sub_root = _create_titled_group(root, key, title) - for vi, sub_value in enumerate(value): - _triage_write( - 'idx_{0}'.format(vi), sub_value, sub_root, comp_kw, - where + '[%s]' % vi, cleanup_data=cleanup_data, slash=slash) - elif isinstance(value, type(None)): - _create_titled_dataset(root, key, 'None', [False]) - elif isinstance(value, (int, float)): - if isinstance(value, int): - title = 'int' - else: # isinstance(value, float): - title = 'float' - _create_titled_dataset(root, key, title, np.atleast_1d(value)) - elif isinstance(value, np.bool_): - _create_titled_dataset(root, key, 'np_bool_', np.atleast_1d(value)) - elif isinstance(value, string_types): - if isinstance(value, text_type): # unicode - value = np.fromstring(value.encode('utf-8'), np.uint8) - title = 'unicode' - else: - value = np.fromstring(value.encode('ASCII'), np.uint8) - title = 'ascii' - _create_titled_dataset(root, key, title, value, comp_kw) - elif isinstance(value, np.ndarray): - _create_titled_dataset(root, key, 'ndarray', value) - elif sparse is not None and isinstance(value, sparse.csc_matrix): - sub_root = _create_titled_group(root, key, 'csc_matrix') - _triage_write('data', value.data, sub_root, comp_kw, - where + '.csc_matrix_data', cleanup_data=cleanup_data, - slash=slash) - _triage_write('indices', value.indices, sub_root, comp_kw, - where + '.csc_matrix_indices', cleanup_data=cleanup_data, - slash=slash) - _triage_write('indptr', value.indptr, sub_root, comp_kw, - where + '.csc_matrix_indptr', cleanup_data=cleanup_data, - slash=slash) - elif sparse is not None and isinstance(value, sparse.csr_matrix): - sub_root = _create_titled_group(root, key, 'csr_matrix') - _triage_write('data', value.data, sub_root, comp_kw, - where + '.csr_matrix_data', cleanup_data=cleanup_data, - slash=slash) - _triage_write('indices', value.indices, sub_root, comp_kw, - where + '.csr_matrix_indices', cleanup_data=cleanup_data, - slash=slash) - _triage_write('indptr', value.indptr, sub_root, comp_kw, - where + '.csr_matrix_indptr', cleanup_data=cleanup_data, - slash=slash) - _triage_write('shape', value.shape, sub_root, comp_kw, - where + '.csr_matrix_shape', cleanup_data=cleanup_data, - slash=slash) - else: - try: - from pandas import DataFrame, Series - except ImportError: - pass - else: - if isinstance(value, (DataFrame, Series)): - if isinstance(value, DataFrame): - title = 'pd_dataframe' - else: - title = 'pd_series' - rootname = root.name - cleanup_data.append({title: (rootname, key, value)}) - return - - err_str = 'unsupported type %s (in %s)' % (type(value), where) - raise TypeError(err_str) - -############################################################################## -# READING - - -def read_hdf5(fname, title='h5io', slash='ignore'): - """Read python object from HDF5 format using h5py - - Parameters - ---------- - fname : str - File to load. - title : str - The top-level directory name to use. Typically it is useful to make - this your package name, e.g. ``'mnepython'``. - slash : 'ignore' | 'replace' - Whether to replace the string {FWDSLASH} with the value /. This does - not apply to the top level name (title). If 'ignore', nothing will be - replaced. - - Returns - ------- - data : object - The loaded data. Can be of any type supported by ``write_hdf5``. - """ - h5py = _check_h5py() - if not op.isfile(fname): - raise IOError('file "%s" not found' % fname) - if not isinstance(title, string_types): - raise ValueError('title must be a string') - with h5py.File(fname, mode='r') as fid: - if title not in fid: - raise ValueError('no "%s" data found' % title) - if isinstance(fid[title], h5py.Group): - if 'TITLE' not in fid[title].attrs: - raise ValueError('no "%s" data found' % title) - data = _triage_read(fid[title], slash=slash) - return data - - -def _triage_read(node, slash='ignore'): - if slash not in ['ignore', 'replace']: - raise ValueError("slash must be one of 'replace', 'ignore'") - h5py = _check_h5py() - type_str = node.attrs['TITLE'] - if isinstance(type_str, bytes): - type_str = type_str.decode() - if isinstance(node, h5py.Group): - if type_str == 'dict': - data = dict() - for key, subnode in node.items(): - if slash == 'replace': - for key_spec, val_spec in special_chars.items(): - key = key.replace(key_spec, val_spec) - data[key[4:]] = _triage_read(subnode, slash=slash) - elif type_str in ['list', 'tuple']: - data = list() - ii = 0 - while True: - subnode = node.get('idx_{0}'.format(ii), None) - if subnode is None: - break - data.append(_triage_read(subnode, slash=slash)) - ii += 1 - assert len(data) == ii - data = tuple(data) if type_str == 'tuple' else data - return data - elif type_str == 'csc_matrix': - if sparse is None: - raise RuntimeError('scipy must be installed to read this data') - data = sparse.csc_matrix((_triage_read(node['data'], slash=slash), - _triage_read(node['indices'], - slash=slash), - _triage_read(node['indptr'], - slash=slash))) - elif type_str == 'csr_matrix': - if sparse is None: - raise RuntimeError('scipy must be installed to read this data') - data = sparse.csr_matrix((_triage_read(node['data'], slash=slash), - _triage_read(node['indices'], - slash=slash), - _triage_read(node['indptr'], - slash=slash)), - shape=_triage_read(node['shape'])) - elif type_str in ['pd_dataframe', 'pd_series']: - from pandas import read_hdf - rootname = node.name - filename = node.file.filename - data = read_hdf(filename, rootname, mode='r') - else: - raise NotImplementedError('Unknown group type: {0}' - ''.format(type_str)) - elif type_str == 'ndarray': - data = np.array(node) - elif type_str in ('int', 'float'): - cast = int if type_str == 'int' else float - data = cast(np.array(node)[0]) - elif type_str == 'np_bool_': - data = np.bool_(np.array(node)[0]) - elif type_str in ('unicode', 'ascii', 'str'): # 'str' for backward compat - decoder = 'utf-8' if type_str == 'unicode' else 'ASCII' - cast = text_type if type_str == 'unicode' else str - data = cast(np.array(node).tostring().decode(decoder)) - elif type_str == 'None': - data = None - else: - raise TypeError('Unknown node type: {0}'.format(type_str)) - return data - - -# ############################################################################ -# UTILITIES - -def _sort_keys(x): - """Sort and return keys of dict""" - keys = list(x.keys()) # note: not thread-safe - idx = np.argsort([str(k) for k in keys]) - keys = [keys[ii] for ii in idx] - return keys - - -def object_diff(a, b, pre=''): - """Compute all differences between two python variables - - Parameters - ---------- - a : object - Currently supported: dict, list, tuple, ndarray, int, str, bytes, - float. - b : object - Must be same type as x1. - pre : str - String to prepend to each line. - - Returns - ------- - diffs : str - A string representation of the differences. - """ - - try: - from pandas import DataFrame, Series - except ImportError: - DataFrame = Series = type(None) - - out = '' - if type(a) != type(b): - out += pre + ' type mismatch (%s, %s)\n' % (type(a), type(b)) - elif isinstance(a, dict): - k1s = _sort_keys(a) - k2s = _sort_keys(b) - m1 = set(k2s) - set(k1s) - if len(m1): - out += pre + ' x1 missing keys %s\n' % (m1) - for key in k1s: - if key not in k2s: - out += pre + ' x2 missing key %s\n' % key - else: - out += object_diff(a[key], b[key], pre + 'd1[%s]' % repr(key)) - elif isinstance(a, (list, tuple)): - if len(a) != len(b): - out += pre + ' length mismatch (%s, %s)\n' % (len(a), len(b)) - else: - for xx1, xx2 in zip(a, b): - out += object_diff(xx1, xx2, pre='') - elif isinstance(a, (string_types, int, float, bytes)): - if a != b: - out += pre + ' value mismatch (%s, %s)\n' % (a, b) - elif a is None: - pass # b must be None due to our type checking - elif isinstance(a, np.ndarray): - if not np.array_equal(a, b): - out += pre + ' array mismatch\n' - elif sparse is not None and sparse.isspmatrix(a): - # sparsity and sparse type of b vs a already checked above by type() - if b.shape != a.shape: - out += pre + (' sparse matrix a and b shape mismatch' - '(%s vs %s)' % (a.shape, b.shape)) - else: - c = a - b - c.eliminate_zeros() - if c.nnz > 0: - out += pre + (' sparse matrix a and b differ on %s ' - 'elements' % c.nnz) - elif isinstance(a, (DataFrame, Series)): - if b.shape != a.shape: - out += pre + (' pandas values a and b shape mismatch' - '(%s vs %s)' % (a.shape, b.shape)) - else: - c = a.values - b.values - nzeros = np.sum(c != 0) - if nzeros > 0: - out += pre + (' pandas values a and b differ on %s ' - 'elements' % nzeros) - else: - raise RuntimeError(pre + ': unsupported type %s (%s)' % (type(a), a)) - return out - - -class _TempDir(str): - """Class for creating and auto-destroying temp dir - - This is designed to be used with testing modules. Instances should be - defined inside test functions. Instances defined at module level can not - guarantee proper destruction of the temporary directory. - - When used at module level, the current use of the __del__() method for - cleanup can fail because the rmtree function may be cleaned up before this - object (an alternative could be using the atexit module instead). - """ - def __new__(self): - new = str.__new__(self, tempfile.mkdtemp()) - return new - - def __init__(self): - self._path = self.__str__() - - def __del__(self): - rmtree(self._path, ignore_errors=True) diff --git a/expyfun/_externals/decorator.py b/expyfun/_externals/decorator.py deleted file mode 100644 index 0836d2b3..00000000 --- a/expyfun/_externals/decorator.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -########################## LICENCE ############################### - -# Copyright (c) 2005-2012, Michele Simionato -# All rights reserved. - -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: - -# Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# Redistributions in bytecode form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in -# the documentation and/or other materials provided with the -# distribution. - -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, -# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, -# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS -# OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND -# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR -# TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE -# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH -# DAMAGE. - -""" -Decorator module, see http://pypi.python.org/pypi/decorator -for the documentation. -""" -from __future__ import print_function - -__version__ = '3.4.0' - -__all__ = ["decorator", "FunctionMaker", "contextmanager"] - - -import sys, re, inspect -if sys.version >= '3': - from inspect import getfullargspec - def get_init(cls): - return cls.__init__ -else: - class getfullargspec(object): - "A quick and dirty replacement for getfullargspec for Python 2.X" - def __init__(self, f): - self.args, self.varargs, self.varkw, self.defaults = \ - inspect.getargspec(f) - self.kwonlyargs = [] - self.kwonlydefaults = None - def __iter__(self): - yield self.args - yield self.varargs - yield self.varkw - yield self.defaults - def get_init(cls): - return cls.__init__.__func__ - -DEF = re.compile(r'\s*def\s*([_\w][_\w\d]*)\s*\(') - -# basic functionality -class FunctionMaker(object): - """ - An object with the ability to create functions with a given signature. - It has attributes name, doc, module, signature, defaults, dict and - methods update and make. - """ - def __init__(self, func=None, name=None, signature=None, - defaults=None, doc=None, module=None, funcdict=None): - self.shortsignature = signature - if func: - # func can be a class or a callable, but not an instance method - self.name = func.__name__ - if self.name == '': # small hack for lambda functions - self.name = '_lambda_' - self.doc = func.__doc__ - self.module = func.__module__ - if inspect.isfunction(func): - argspec = getfullargspec(func) - self.annotations = getattr(func, '__annotations__', {}) - for a in ('args', 'varargs', 'varkw', 'defaults', 'kwonlyargs', - 'kwonlydefaults'): - setattr(self, a, getattr(argspec, a)) - for i, arg in enumerate(self.args): - setattr(self, 'arg%d' % i, arg) - if sys.version < '3': # easy way - self.shortsignature = self.signature = \ - inspect.formatargspec( - formatvalue=lambda val: "", *argspec)[1:-1] - else: # Python 3 way - allargs = list(self.args) - allshortargs = list(self.args) - if self.varargs: - allargs.append('*' + self.varargs) - allshortargs.append('*' + self.varargs) - elif self.kwonlyargs: - allargs.append('*') # single star syntax - for a in self.kwonlyargs: - allargs.append('%s=None' % a) - allshortargs.append('%s=%s' % (a, a)) - if self.varkw: - allargs.append('**' + self.varkw) - allshortargs.append('**' + self.varkw) - self.signature = ', '.join(allargs) - self.shortsignature = ', '.join(allshortargs) - self.dict = func.__dict__.copy() - # func=None happens when decorating a caller - if name: - self.name = name - if signature is not None: - self.signature = signature - if defaults: - self.defaults = defaults - if doc: - self.doc = doc - if module: - self.module = module - if funcdict: - self.dict = funcdict - # check existence required attributes - assert hasattr(self, 'name') - if not hasattr(self, 'signature'): - raise TypeError('You are decorating a non function: %s' % func) - - def update(self, func, **kw): - "Update the signature of func with the data in self" - func.__name__ = self.name - func.__doc__ = getattr(self, 'doc', None) - func.__dict__ = getattr(self, 'dict', {}) - func.__defaults__ = getattr(self, 'defaults', ()) - func.__kwdefaults__ = getattr(self, 'kwonlydefaults', None) - func.__annotations__ = getattr(self, 'annotations', None) - callermodule = sys._getframe(3).f_globals.get('__name__', '?') - func.__module__ = getattr(self, 'module', callermodule) - func.__dict__.update(kw) - - def make(self, src_templ, evaldict=None, addsource=False, **attrs): - "Make a new function from a given template and update the signature" - src = src_templ % vars(self) # expand name and signature - evaldict = evaldict or {} - mo = DEF.match(src) - if mo is None: - raise SyntaxError('not a valid function template\n%s' % src) - name = mo.group(1) # extract the function name - names = set([name] + [arg.strip(' *') for arg in - self.shortsignature.split(',')]) - for n in names: - if n in ('_func_', '_call_'): - raise NameError('%s is overridden in\n%s' % (n, src)) - if not src.endswith('\n'): # add a newline just for safety - src += '\n' # this is needed in old versions of Python - try: - code = compile(src, '', 'single') - # print >> sys.stderr, 'Compiling %s' % src - exec(code, evaldict) - except: - print('Error in generated code:', file=sys.stderr) - print(src, file=sys.stderr) - raise - func = evaldict[name] - if addsource: - attrs['__source__'] = src - self.update(func, **attrs) - return func - - @classmethod - def create(cls, obj, body, evaldict, defaults=None, - doc=None, module=None, addsource=True, **attrs): - """ - Create a function from the strings name, signature and body. - evaldict is the evaluation dictionary. If addsource is true an attribute - __source__ is added to the result. The attributes attrs are added, - if any. - """ - if isinstance(obj, str): # "name(signature)" - name, rest = obj.strip().split('(', 1) - signature = rest[:-1] #strip a right parens - func = None - else: # a function - name = None - signature = None - func = obj - self = cls(func, name, signature, defaults, doc, module) - ibody = '\n'.join(' ' + line for line in body.splitlines()) - return self.make('def %(name)s(%(signature)s):\n' + ibody, - evaldict, addsource, **attrs) - -def decorator(caller, func=None): - """ - decorator(caller) converts a caller function into a decorator; - decorator(caller, func) decorates a function using a caller. - """ - if func is not None: # returns a decorated function - evaldict = func.__globals__.copy() - evaldict['_call_'] = caller - evaldict['_func_'] = func - return FunctionMaker.create( - func, "return _call_(_func_, %(shortsignature)s)", - evaldict, undecorated=func, __wrapped__=func) - else: # returns a decorator - if inspect.isclass(caller): - name = caller.__name__.lower() - callerfunc = get_init(caller) - doc = 'decorator(%s) converts functions/generators into ' \ - 'factories of %s objects' % (caller.__name__, caller.__name__) - fun = getfullargspec(callerfunc).args[1] # second arg - elif inspect.isfunction(caller): - name = '_lambda_' if caller.__name__ == '' \ - else caller.__name__ - callerfunc = caller - doc = caller.__doc__ - fun = getfullargspec(callerfunc).args[0] # first arg - else: # assume caller is an object with a __call__ method - name = caller.__class__.__name__.lower() - callerfunc = caller.__call__.__func__ - doc = caller.__call__.__doc__ - fun = getfullargspec(callerfunc).args[1] # second arg - evaldict = callerfunc.__globals__.copy() - evaldict['_call_'] = caller - evaldict['decorator'] = decorator - return FunctionMaker.create( - '%s(%s)' % (name, fun), - 'return decorator(_call_, %s)' % fun, - evaldict, undecorated=caller, __wrapped__=caller, - doc=doc, module=caller.__module__) - -######################### contextmanager ######################## - -def __call__(self, func): - 'Context manager decorator' - return FunctionMaker.create( - func, "with _self_: return _func_(%(shortsignature)s)", - dict(_self_=self, _func_=func), __wrapped__=func) - -try: # Python >= 3.2 - - from contextlib import _GeneratorContextManager - ContextManager = type( - 'ContextManager', (_GeneratorContextManager,), dict(__call__=__call__)) - -except ImportError: # Python >= 2.5 - - from contextlib import GeneratorContextManager - def __init__(self, f, *a, **k): - return GeneratorContextManager.__init__(self, f(*a, **k)) - ContextManager = type( - 'ContextManager', (GeneratorContextManager,), - dict(__call__=__call__, __init__=__init__)) - -contextmanager = decorator(ContextManager) diff --git a/expyfun/_eyelink_controller.py b/expyfun/_eyelink_controller.py index eb257888..cfd88dcc 100644 --- a/expyfun/_eyelink_controller.py +++ b/expyfun/_eyelink_controller.py @@ -5,16 +5,17 @@ # # License: BSD (3-clause) -import numpy as np import datetime import os -from os import path as op -import sys import subprocess +import sys import time +from os import path as op -from .visual import FixationDot, Circle, RawImage, Line, Text -from ._utils import get_config, verbose_dec, logger, string_types +import numpy as np + +from ._utils import get_config, logger, verbose_dec +from .visual import Circle, FixationDot, Line, RawImage, Text # Constants TRIAL_OK = 0 @@ -35,6 +36,7 @@ def dummy_fun(*args, **kwargs): # don't prevent basic functionality for folks who don't use EL try: import pylink + cal_super_class = pylink.EyeLinkCustomDisplay openGraphicsEx = pylink.openGraphicsEx except ImportError: @@ -43,87 +45,105 @@ def dummy_fun(*args, **kwargs): openGraphicsEx = dummy_fun -eye_list = ['LEFT_EYE', 'RIGHT_EYE', 'BINOCULAR'] # Used by eyeAvailable +eye_list = ["LEFT_EYE", "RIGHT_EYE", "BINOCULAR"] # Used by eyeAvailable def _get_key_trans_dict(): """Helper to translate pyglet keys to pylink codes""" from pyglet.window import key - key_trans_dict = {str(key.F1): pylink.F1_KEY, - str(key.F2): pylink.F2_KEY, - str(key.F3): pylink.F3_KEY, - str(key.F4): pylink.F4_KEY, - str(key.F5): pylink.F5_KEY, - str(key.F6): pylink.F6_KEY, - str(key.F7): pylink.F7_KEY, - str(key.F8): pylink.F8_KEY, - str(key.F9): pylink.F9_KEY, - str(key.F10): pylink.F10_KEY, - str(key.PAGEUP): pylink.PAGE_UP, - str(key.PAGEDOWN): pylink.PAGE_DOWN, - str(key.UP): pylink.CURS_UP, - str(key.DOWN): pylink.CURS_DOWN, - str(key.LEFT): pylink.CURS_LEFT, - str(key.RIGHT): pylink.CURS_RIGHT, - str(key.BACKSPACE): '\b', - str(key.RETURN): pylink.ENTER_KEY, - str(key.ESCAPE): pylink.ESC_KEY, - str(key.NUM_ADD): key.PLUS, - str(key.NUM_SUBTRACT): key.MINUS, - } + + key_trans_dict = { + str(key.F1): pylink.F1_KEY, + str(key.F2): pylink.F2_KEY, + str(key.F3): pylink.F3_KEY, + str(key.F4): pylink.F4_KEY, + str(key.F5): pylink.F5_KEY, + str(key.F6): pylink.F6_KEY, + str(key.F7): pylink.F7_KEY, + str(key.F8): pylink.F8_KEY, + str(key.F9): pylink.F9_KEY, + str(key.F10): pylink.F10_KEY, + str(key.PAGEUP): pylink.PAGE_UP, + str(key.PAGEDOWN): pylink.PAGE_DOWN, + str(key.UP): pylink.CURS_UP, + str(key.DOWN): pylink.CURS_DOWN, + str(key.LEFT): pylink.CURS_LEFT, + str(key.RIGHT): pylink.CURS_RIGHT, + str(key.BACKSPACE): "\b", + str(key.RETURN): pylink.ENTER_KEY, + str(key.ESCAPE): pylink.ESC_KEY, + str(key.NUM_ADD): key.PLUS, + str(key.NUM_SUBTRACT): key.MINUS, + } return key_trans_dict def _get_color_dict(): """Helper to translate pylink colors to pyglet""" - color_dict = {str(CR_HAIR_COLOR): (1.0, 1.0, 1.0), - str(PUPIL_HAIR_COLOR): (1.0, 1.0, 1.0), - str(PUPIL_BOX_COLOR): (0.0, 1.0, 0.0), - str(SEARCH_LIMIT_BOX_COLOR): (1.0, 0.0, 0.0), - str(MOUSE_CURSOR_COLOR): (1.0, 0.0, 0.0)} + color_dict = { + str(CR_HAIR_COLOR): (1.0, 1.0, 1.0), + str(PUPIL_HAIR_COLOR): (1.0, 1.0, 1.0), + str(PUPIL_BOX_COLOR): (0.0, 1.0, 0.0), + str(SEARCH_LIMIT_BOX_COLOR): (1.0, 0.0, 0.0), + str(MOUSE_CURSOR_COLOR): (1.0, 0.0, 0.0), + } return color_dict -def _check(val, msg, out='error'): +def _check(val, msg, out="error"): """Helper to check output""" if val != TRIAL_OK: msg = msg.format(val) - if out == 'warn': + if out == "warn": logger.warning(msg) else: raise RuntimeError(msg) _dummy_names = [ - 'setSaccadeVelocityThreshold', 'setAccelerationThreshold', - 'setUpdateInterval', 'setFixationUpdateAccumulate', 'setFileEventFilter', - 'setLinkEventFilter', 'setFileSampleFilter', 'setLinkSampleFilter', - 'setPupilSizeDiameter', 'setAcceptTargetFixationButton', - 'openDataFile', 'startRecording', 'waitForModeReady', - 'isRecording', 'stopRecording', 'closeDataFile', 'doTrackerSetup', - 'receiveDataFile', 'close', 'eyeAvailable', 'sendCommand', + "setSaccadeVelocityThreshold", + "setAccelerationThreshold", + "setUpdateInterval", + "setFixationUpdateAccumulate", + "setFileEventFilter", + "setLinkEventFilter", + "setFileSampleFilter", + "setLinkSampleFilter", + "setPupilSizeDiameter", + "setAcceptTargetFixationButton", + "openDataFile", + "startRecording", + "waitForModeReady", + "isRecording", + "stopRecording", + "closeDataFile", + "doTrackerSetup", + "receiveDataFile", + "close", + "eyeAvailable", + "sendCommand", ] -class DummyEl(object): +class DummyEl: """Dummy EyeLink controller.""" def __init__(self): for name in _dummy_names: setattr(self, name, dummy_fun) - self.getTrackerVersion = lambda: 'Dummy' + self.getTrackerVersion = lambda: "Dummy" self.getDummyMode = lambda: True self.getCurrentMode = lambda: IN_RECORD_MODE self.waitForBlockStart = lambda a, b, c: 1 def sendMessage(self, msg): """Send a message.""" - if not isinstance(msg, string_types): - raise TypeError('msg must be str') + if not isinstance(msg, str): + raise TypeError("msg must be str") return TRIAL_OK -class EyelinkController(object): +class EyelinkController: """Eyelink communication and control methods. Parameters @@ -147,50 +167,53 @@ class EyelinkController(object): """ @verbose_dec - def __init__(self, ec, link='default', fs=1000, verbose=None): - if link == 'default': - link = get_config('EXPYFUN_EYELINK', None) + def __init__(self, ec, link="default", fs=1000, verbose=None): + if link == "default": + link = get_config("EXPYFUN_EYELINK", None) if link is not None and pylink is None: - raise ImportError('Could not import pylink, please ensure it ' - 'is installed correctly to use the EyeLink') + raise ImportError( + "Could not import pylink, please ensure it " + "is installed correctly to use the EyeLink" + ) valid_fs = (250, 500, 1000, 2000) if fs not in valid_fs: - raise ValueError('fs must be one of {0}'.format(list(valid_fs))) + raise ValueError(f"fs must be one of {list(valid_fs)}") output_dir = ec._output_dir if output_dir is None: output_dir = os.getcwd() - if not isinstance(output_dir, string_types): - raise TypeError('output_dir must be a string') + if not isinstance(output_dir, str): + raise TypeError("output_dir must be a string") if not op.isdir(output_dir): os.mkdir(output_dir) self._output_dir = output_dir self._ec = ec - if 'el_id' in self._ec._id_call_dict: - raise RuntimeError('Cannot use initialize EL twice') - logger.info('EyeLink: Initializing on {}'.format(link)) + if "el_id" in self._ec._id_call_dict: + raise RuntimeError("Cannot use initialize EL twice") + logger.info(f"EyeLink: Initializing on {link}") ec.flush() if link is not None: - iswin = (sys.platform == 'win32') - cmd = 'ping -n 1 -w 100' if iswin else 'fping -c 1 -t100' - cmd = subprocess.Popen('%s %s' % (cmd, link), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) + iswin = sys.platform == "win32" + cmd = "ping -n 1 -w 100" if iswin else "fping -c 1 -t100" + cmd = subprocess.Popen( + "%s %s" % (cmd, link), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) if cmd.returncode: - raise RuntimeError('could not connect to Eyelink @ %s, ' - 'is it turned on?' % link) + raise RuntimeError( + "could not connect to Eyelink @ %s, " "is it turned on?" % link + ) self._eyelink = DummyEl() if link is None else pylink.EyeLink(link) self._file_list = [] self._size = np.array(self._ec.window_size_pix) self._ec._extra_cleanup_fun += [self._close] self._ec.flush() self._setup(fs) - self._ec._id_call_dict['el_id'] = self._stamp_trial_id + self._ec._id_call_dict["el_id"] = self._stamp_trial_id self._ec._ofp_critical_funs.append(self._stamp_trial_start) self._ec._on_trial_ok.append(self._stamp_trial_ok) self._fake_calibration = False # Only used for testing self._closed = False # to prevent double-closing self._current_open_file = None - logger.debug('EyeLink: Setup complete') + logger.debug("EyeLink: Setup complete") self._ec.flush() def _setup(self, fs=1000): @@ -206,10 +229,10 @@ def _setup(self, fs=1000): """ # map the gaze positions from the tracker to screen pixel positions res = self._size - res_str = '0 0 {0} {1}'.format(res[0] - 1, res[1] - 1) - logger.debug('EyeLink: Setting display coordinates and saccade levels') - self._command('screen_pixel_coords = ' + res_str) - self._message('DISPLAY_COORDS ' + res_str) + res_str = f"0 0 {res[0] - 1} {res[1] - 1}" + logger.debug("EyeLink: Setting display coordinates and saccade levels") + self._command("screen_pixel_coords = " + res_str) + self._message("DISPLAY_COORDS " + res_str) # set calibration parameters self.custom_calibration() @@ -219,33 +242,31 @@ def _setup(self, fs=1000): self._eyelink.setAccelerationThreshold(9500) self._eyelink.setUpdateInterval(50) self._eyelink.setFixationUpdateAccumulate(50) - self._command('sample_rate = {0}'.format(fs)) + self._command(f"sample_rate = {fs}") # retrieve tracker version and tracker software version v = str(self._eyelink.getTrackerVersion()).strip() - logger.info('Eyelink: Running experiment on a version ''{0}'' ' - 'tracker.'.format(v)) - v = v.split('.') + logger.info("Eyelink: Running experiment on a version " f"{v}" " " "tracker.") + v = v.split(".") # set EDF file contents - logger.debug('EyeLink: Setting file and event filters') - fef = 'LEFT,RIGHT,FIXATION,SACCADE,BLINK,MESSAGE,BUTTON,INPUT' + logger.debug("EyeLink: Setting file and event filters") + fef = "LEFT,RIGHT,FIXATION,SACCADE,BLINK,MESSAGE,BUTTON,INPUT" self._eyelink.setFileEventFilter(fef) - lef = ('LEFT,RIGHT,FIXATION,SACCADE,BLINK,MESSAGE,' - 'BUTTON,FIXUPDATE,INPUT') + lef = "LEFT,RIGHT,FIXATION,SACCADE,BLINK,MESSAGE," "BUTTON,FIXUPDATE,INPUT" self._eyelink.setLinkEventFilter(lef) - fsf = 'LEFT,RIGHT,GAZE,HREF,AREA,GAZERES,STATUS,INPUT' - lsf = 'LEFT,RIGHT,GAZE,GAZERES,AREA,STATUS,INPUT' - if len(v) > 1 and v[0] == '3' and v[1] == '4': + fsf = "LEFT,RIGHT,GAZE,HREF,AREA,GAZERES,STATUS,INPUT" + lsf = "LEFT,RIGHT,GAZE,GAZERES,AREA,STATUS,INPUT" + if len(v) > 1 and v[0] == "3" and v[1] == "4": # remote mode possible add HTARGET ( head target) - fsf += ',HTARGET' + fsf += ",HTARGET" # set link data (used for gaze cursor) - lsf += ',HTARGET' + lsf += ",HTARGET" self._eyelink.setFileSampleFilter(fsf) self._eyelink.setLinkSampleFilter(lsf) # Ensure that we get areas - self._eyelink.setPupilSizeDiameter('NO') + self._eyelink.setPupilSizeDiameter("NO") # calibration/drift cordisp.rection target self._eyelink.setAcceptTargetFixationButton(5) @@ -266,23 +287,24 @@ def fs(self): @property def _is_file_open(self): - return (self._current_open_file is not None) + return self._current_open_file is not None def _open_file(self): """Opens a new file on the Eyelink""" if self._is_file_open: - raise RuntimeError('Cannot start new file, old must be closed') - file_name = datetime.datetime.now().strftime('%H%M%S') + raise RuntimeError("Cannot start new file, old must be closed") + file_name = datetime.datetime.now().strftime("%H%M%S") while file_name in self._file_list: # This should succeed in under 1 second - file_name = datetime.datetime.now().strftime('%H%M%S') + file_name = datetime.datetime.now().strftime("%H%M%S") # make absolutely sure we don't break this, but it shouldn't ever # be wrong assert len(file_name) <= 8 - logger.info('Eyelink: Opening remote file with filename {}' - ''.format(file_name)) - _check(self._eyelink.openDataFile(file_name), - 'Remote file "' + file_name + '" could not be opened: {0}') + logger.info(f"Eyelink: Opening remote file with filename {file_name}" "") + _check( + self._eyelink.openDataFile(file_name), + 'Remote file "' + file_name + '" could not be opened: {0}', + ) self._current_open_file = file_name self._file_list.append(file_name) return self._current_open_file @@ -290,28 +312,33 @@ def _open_file(self): def _start_recording(self): """Start Eyelink recording""" if not self._is_file_open: - raise RuntimeError('cannot start recording without file open') + raise RuntimeError("cannot start recording without file open") for ii in range(5): self._ec.wait_secs(0.1) - out = 'check' if ii < 4 else 'error' - _check(self._eyelink.startRecording(1, 1, 1, 1), - 'Recording could not be started: {0}', out) + out = "check" if ii < 4 else "error" + _check( + self._eyelink.startRecording(1, 1, 1, 1), + "Recording could not be started: {0}", + out, + ) # self._eyelink.waitForModeReady(100) # doesn't work - _check(not self._eyelink.waitForBlockStart(100, 1, 0), - 'No link samples received: {0}') + _check( + not self._eyelink.waitForBlockStart(100, 1, 0), + "No link samples received: {0}", + ) if not self.recording: - raise RuntimeError('Eyelink is not recording') + raise RuntimeError("Eyelink is not recording") # double-check mode = self._eyelink.getCurrentMode() if mode != IN_RECORD_MODE: - raise RuntimeError('Eyelink is not recording: {0}'.format(mode)) + raise RuntimeError(f"Eyelink is not recording: {mode}") self._ec.flush() self._toggle_dummy_cursor(True) @property def recording(self): """Returns boolean for whether or not the Eyelink is recording""" - return (self._eyelink.isRecording() == TRIAL_OK) + return self._eyelink.isRecording() == TRIAL_OK def stop(self): """Stop Eyelink recording and close current file @@ -322,12 +349,11 @@ def stop(self): EyelinkController.transfer_remote_file """ if not self.recording: - raise RuntimeError('Cannot stop, not currently recording') - logger.info('Eyelink: Stopping recording') + raise RuntimeError("Cannot stop, not currently recording") + logger.info("Eyelink: Stopping recording") self._eyelink.stopRecording() - logger.info('Eyelink: Closing file') - _check(self._eyelink.closeDataFile(), - 'File could not be closed: {0}', 'warn') + logger.info("Eyelink: Closing file") + _check(self._eyelink.closeDataFile(), "File could not be closed: {0}", "warn") self._current_open_file = None self._toggle_dummy_cursor(False) @@ -363,9 +389,11 @@ def calibrate(self, beep=False, prompt=True): # open file to record *before* running calibration so it gets saved! fname = self._open_file() if prompt: - self._ec.screen_prompt('We will now perform a screen calibration.' - '\n\nPress a button to continue.') - logger.info('EyeLink: Entering calibration') + self._ec.screen_prompt( + "We will now perform a screen calibration." + "\n\nPress a button to continue." + ) + logger.info("EyeLink: Entering calibration") self._ec.flush() # enter Eyetracker camera setup mode, calibration and validation self._ec.flip() @@ -377,7 +405,7 @@ def calibrate(self, beep=False, prompt=True): self._eyelink.doTrackerSetup() cal.release_event_handlers() self._ec.flip() - logger.info('EyeLink: Completed calibration') + logger.info("EyeLink: Completed calibration") self._ec.flush() self._start_recording() return fname @@ -401,13 +429,13 @@ def _stamp_trial_id(self, ids): # such as one number for each trial independent variable. # Here we just force up to 12 integers for simplicity. if not isinstance(ids, (list, tuple)): - raise TypeError('ids must be a list (or tuple)') + raise TypeError("ids must be a list (or tuple)") if not all([np.isscalar(x) for x in ids]): - raise ValueError('All ids must be numeric') + raise ValueError("All ids must be numeric") if len(ids) > 12: - raise ValueError('ids must not have more than 12 entries') - ids = ' '.join([str(int(ii)) for ii in ids]) - msg = 'TRIALID {}'.format(ids) + raise ValueError("ids must not have more than 12 entries") + ids = " ".join([str(int(ii)) for ii in ids]) + msg = f"TRIALID {ids}" self._message(msg) def _stamp_trial_start(self): @@ -416,17 +444,16 @@ def _stamp_trial_start(self): This is a timing-critical operation used to synchronize the recording to stimulus presentation. """ - self._eyelink.sendMessage('SYNCTIME') + self._eyelink.sendMessage("SYNCTIME") def _stamp_trial_ok(self): - """Signal the end of a trial - """ - self._eyelink.sendMessage('TRIAL OK') + """Signal the end of a trial""" + self._eyelink.sendMessage("TRIAL OK") def _message(self, msg): """Send message to eyelink, must be a string""" self._eyelink.sendMessage(msg) - self._command('record_status_message "{0}"'.format(msg)) + self._command(f'record_status_message "{msg}"') def _command(self, cmd): """Send Eyelink a command, must be a string""" @@ -449,11 +476,10 @@ def transfer_remote_file(self, remote_name): -------- EyelinkController.stop """ - fname = op.join(self._output_dir, '{0}.edf'.format(remote_name)) - logger.info('Eyelink: saving Eyelink file: {0} ...' - ''.format(remote_name)) + fname = op.join(self._output_dir, f"{remote_name}.edf") + logger.info(f"Eyelink: saving Eyelink file: {remote_name} ..." "") status = self._eyelink.receiveDataFile(remote_name, fname) - logger.info('Eyelink: transferred {0} bytes'.format(status)) + logger.info(f"Eyelink: transferred {status} bytes") return fname def _close(self): @@ -463,21 +489,30 @@ def _close(self): if self.recording: self.stop() # make sure files get transferred - fnames = [self.transfer_remote_file(remote_name) - for remote_name in self._file_list] + fnames = [ + self.transfer_remote_file(remote_name) + for remote_name in self._file_list + ] self._file_list = list() self._eyelink.close() self._closed = True - assert 'el_id' in self._ec._id_call_dict - del self._ec._id_call_dict['el_id'] + assert "el_id" in self._ec._id_call_dict + del self._ec._id_call_dict["el_id"] idx = self._ec._ofp_critical_funs.index(self._stamp_trial_start) self._ec._ofp_critical_funs.pop(idx) idx = self._ec._on_trial_ok.index(self._stamp_trial_ok) self._ec._on_trial_ok.pop(idx) return fnames - def wait_for_fix(self, fix_pos, fix_time=0., tol=100., max_wait=np.inf, - check_interval=0.001, units='norm'): + def wait_for_fix( + self, + fix_pos, + fix_time=0.0, + tol=100.0, + max_wait=np.inf, + check_interval=0.001, + units="norm", + ): """Wait for gaze to settle within a defined region Parameters @@ -512,11 +547,12 @@ def wait_for_fix(self, fix_pos, fix_time=0., tol=100., max_wait=np.inf, time_out = time_in + max_wait fix_pos = np.array(fix_pos) if not (fix_pos.ndim == 1 and fix_pos.size == 2): - raise ValueError('fix_pos must be a 2-element array-like vector') - fix_pos = self._ec._convert_units(fix_pos[:, np.newaxis], units, 'pix') + raise ValueError("fix_pos must be a 2-element array-like vector") + fix_pos = self._ec._convert_units(fix_pos[:, np.newaxis], units, "pix") fix_pos = fix_pos[:, 0] - while (time.time() < time_out and not - (fix_success and time.time() - time_in >= fix_time)): + while time.time() < time_out and not ( + fix_success and time.time() - time_in >= fix_time + ): # sample eye position eye_pos = self.get_eye_position() # in pixels if _within_distance(eye_pos, fix_pos, tol): @@ -529,8 +565,16 @@ def wait_for_fix(self, fix_pos, fix_time=0., tol=100., max_wait=np.inf, return fix_success - def maintain_fix(self, fix_pos, check_duration, tol=100., period=.250, - check_interval=0.001, units='norm', stop_early=False): + def maintain_fix( + self, + fix_pos, + check_duration, + tol=100.0, + period=0.250, + check_interval=0.001, + units="norm", + stop_early=False, + ): """Check to see if subject is fixating in a region. This checks to make sure that the subjects gaze falls within a region @@ -569,12 +613,15 @@ def maintain_fix(self, fix_pos, check_duration, tol=100., period=.250, fix_pos = np.array(fix_pos) if not (fix_pos.ndim == 1 and fix_pos.size == 2): - raise ValueError('fix_pos must be a 2-element array-like vector') - fix_pos = self._ec._convert_units(fix_pos[:, np.newaxis], units, 'pix') + raise ValueError("fix_pos must be a 2-element array-like vector") + fix_pos = self._ec._convert_units(fix_pos[:, np.newaxis], units, "pix") fix_pos = fix_pos[:, 0] check = [] - while ((fix_success and time.time() < time_end) if stop_early else - time.time() < time_end): + while ( + (fix_success and time.time() < time_end) + if stop_early + else time.time() < time_end + ): if fix_success: # sample eye position eye_pos = self.get_eye_position() # in pixels @@ -588,8 +635,14 @@ def maintain_fix(self, fix_pos, check_duration, tol=100., period=.250, self._ec.wait_secs(check_interval) return fix_success - def custom_calibration(self, ctype='HV5', horiz=2./3., vert=2./3., - coordinates=None, units='norm'): + def custom_calibration( + self, + ctype="HV5", + horiz=2.0 / 3.0, + vert=2.0 / 3.0, + coordinates=None, + units="norm", + ): """Set Eyetracker to use a custom calibration sequence Parameters @@ -611,54 +664,79 @@ def custom_calibration(self, ctype='HV5', horiz=2./3., vert=2./3., -------- EyelinkController.calibrate """ - allowed_types = ['H3', 'HV5', 'HV9', 'HV13', 'custom'] + allowed_types = ["H3", "HV5", "HV9", "HV13", "custom"] if ctype not in allowed_types: - raise ValueError('ctype cannot be "{0}", but must be one of {1}' - ''.format(ctype, allowed_types)) - if ctype != 'custom': + raise ValueError( + f'ctype cannot be "{ctype}", but must be one of {allowed_types}' "" + ) + if ctype != "custom": if coordinates is not None: - raise ValueError('If ctype is not \'custom\' coordinates canno' - 't be used to generate calibration pattern.') + raise ValueError( + "If ctype is not 'custom' coordinates canno" + "t be used to generate calibration pattern." + ) horiz, vert = float(horiz), float(vert) - xx = np.array(([0., horiz], [0., vert])) - h_pix, v_pix = np.diff(self._ec._convert_units(xx, units, 'pix'), - axis=1)[:, 0] - h_max, v_max = self._size[0] / 2., self._size[1] / 2. - for p, m, s in zip((h_pix, v_pix), (h_max, v_max), ('horiz', 'vert')): + xx = np.array(([0.0, horiz], [0.0, vert])) + h_pix, v_pix = np.diff(self._ec._convert_units(xx, units, "pix"), axis=1)[:, 0] + h_max, v_max = self._size[0] / 2.0, self._size[1] / 2.0 + for p, m, s in zip((h_pix, v_pix), (h_max, v_max), ("horiz", "vert")): if p > m: - raise ValueError('{0} too large ({1} > {2})' - ''.format(s, p, m)) + raise ValueError(f"{s} too large ({p} > {m})" "") # make the locations - if ctype == 'HV5': + if ctype == "HV5": mat = np.array([[0, 0], [1, 0], [-1, 0], [0, 1], [0, -1]]) - elif ctype == 'HV9': - mat = np.array([[0, 0], [1, 0], [-1, 0], [0, 1], [0, -1], [1, 1], - [-1, -1], [1, -1], [-1, 1]]) - elif ctype == 'H3': + elif ctype == "HV9": + mat = np.array( + [ + [0, 0], + [1, 0], + [-1, 0], + [0, 1], + [0, -1], + [1, 1], + [-1, -1], + [1, -1], + [-1, 1], + ] + ) + elif ctype == "H3": mat = np.array([[0, 0], [1, 0], [-1, 0]]) - elif ctype == 'HV13': - mat = np.array([[0, 0], [1, 0], [-1, 0], [0, 1], [0, -1], [1, 1], - [-1, -1], [1, -1], [-1, 1], [.5, .5], [-.5, -.5], - [.5, -.5], [-.5, .5]]) - elif ctype == 'custom': + elif ctype == "HV13": + mat = np.array( + [ + [0, 0], + [1, 0], + [-1, 0], + [0, 1], + [0, -1], + [1, 1], + [-1, -1], + [1, -1], + [-1, 1], + [0.5, 0.5], + [-0.5, -0.5], + [0.5, -0.5], + [-0.5, 0.5], + ] + ) + elif ctype == "custom": mat = np.array(coordinates, float) if mat.ndim != 2 or mat.shape[-1] != 2: - raise ValueError('Each coordinate must be a list with length 2' - '.') + raise ValueError("Each coordinate must be a list with length 2" ".") offsets = mat * np.array([h_pix, v_pix]) - coords = (self._size / 2. + offsets) + coords = self._size / 2.0 + offsets n_samples = coords.shape[0] - targs = ' '.join(['{0},{1}'.format(*c) for c in coords]) - seq = ','.join([str(x) for x in range(n_samples + 1)]) - self._command('calibration_type = {0}'.format(ctype)) - self._command('generate_default_targets = NO') - self._command('calibration_samples = {0}'.format(n_samples)) - self._command('calibration_sequence = ' + seq) - self._command('calibration_targets = ' + targs) - self._command('validation_samples = {0}'.format(n_samples)) - self._command('validation_sequence = ' + seq) - self._command('validation_targets = ' + targs) + targs = " ".join(["{0},{1}".format(*c) for c in coords]) + seq = ",".join([str(x) for x in range(n_samples + 1)]) + self._command(f"calibration_type = {ctype}") + self._command("generate_default_targets = NO") + self._command(f"calibration_samples = {n_samples}") + self._command("calibration_sequence = " + seq) + self._command("calibration_targets = " + targs) + self._command(f"validation_samples = {n_samples}") + self._command("validation_sequence = " + seq) + self._command("validation_targets = " + targs) def get_eye_position(self): """The current eye position in pixels @@ -676,11 +754,14 @@ def get_eye_position(self): if not self.dummy_mode: sample = self._eyelink.getNewestSample() if sample is None: - raise RuntimeError('No sample data, consider starting a ' - 'recording using el.start()') + raise RuntimeError( + "No sample data, consider starting a " "recording using el.start()" + ) if sample.isBinocular(): - eye_pos = (np.array(sample.getLeftEye().getGaze()) + - np.array(sample.getRightEye().getGaze())) / 2. + eye_pos = ( + np.array(sample.getLeftEye().getGaze()) + + np.array(sample.getRightEye().getGaze()) + ) / 2.0 elif sample.isLeftSample(): eye_pos = np.array(sample.getLeftEye().getGaze()) elif sample.isRightSample(): @@ -699,8 +780,7 @@ def _toggle_dummy_cursor(self, visibility): @property def file_list(self): - """The list of files started on the EyeLink - """ + """The list of files started on the EyeLink""" return self._file_list @property @@ -730,7 +810,7 @@ def __init__(self, ec, beep=False): self.ec = ec self.keys = [] ws = np.array(ec.window_size_pix) - self.img_span = 1.5 * np.array((float(ws[0]) / ws[1], 1.)) + self.img_span = 1.5 * np.array((float(ws[0]) / ws[1], 1.0)) # set up reusable objects self.targ_circ = FixationDot(self.ec) @@ -749,11 +829,15 @@ def __init__(self, ec, beep=False): self.img_size = (0, 0) def setup_event_handlers(self): - self.label = Text(self.ec, 'Eye Label', units='norm', - pos=(0, -self.img_span[1] / 2.), - anchor_y='top', color='white') - self.img = RawImage(self.ec, np.zeros((1, 2, 3)), - pos=(0, 0), units='norm') + self.label = Text( + self.ec, + "Eye Label", + units="norm", + pos=(0, -self.img_span[1] / 2.0), + anchor_y="top", + color="white", + ) + self.img = RawImage(self.ec, np.zeros((1, 2, 3)), pos=(0, 0), units="norm") def on_mouse_press(x, y, button, modifiers): self.state = 1 @@ -773,9 +857,13 @@ def on_key_press(symbol, modifiers): self.keys += [pylink.KeyInput(key, modifiers)] # create new handler at top of handling stack - self.ec.window.push_handlers(on_key_press, on_mouse_press, - on_mouse_motion, on_mouse_release, - on_mouse_drag) + self.ec.window.push_handlers( + on_key_press, + on_mouse_press, + on_mouse_motion, + on_mouse_release, + on_mouse_drag, + ) def release_event_handlers(self): try: @@ -794,7 +882,7 @@ def record_abort_hide(self): pass def draw_cal_target(self, x, y): - self.targ_circ.set_pos((x, y), units='pix') + self.targ_circ.set_pos((x, y), units="pix") self.targ_circ.draw() self.ec.flip() @@ -829,19 +917,25 @@ def _img2win(self, x, y): return x, y def alert_printf(self, msg): - logger.warning('EyeLink: alert_printf {}'.format(msg)) + logger.warning(f"EyeLink: alert_printf {msg}") def setup_image_display(self, w, h): # convert w, h from pixels to relative units x = np.array([[0, 0], [0, self.img_span[1]]], float) - x = np.diff(self.ec._convert_units(x, 'norm', 'pix')[1]) / h + x = np.diff(self.ec._convert_units(x, "norm", "pix")[1]) / h self.img.set_scale(x) self.clear_display() def image_title(self, text): - text = "
{0}
".format(text) - self.label = Text(self.ec, text, units='norm', anchor_y='top', - color='white', pos=(0, -self.img_span[1] / 2.)) + text = f"
{text}
" + self.label = Text( + self.ec, + text, + units="norm", + anchor_y="top", + color="white", + pos=(0, -self.img_span[1] / 2.0), + ) def set_image_palette(self, r, g, b): self.palette = np.array([r, g, b], np.uint8).T @@ -850,7 +944,7 @@ def draw_image_line(self, width, line, totlines, buff): if self.image_buffer is None: self.img_size = (width, totlines) self.image_buffer = np.empty((totlines, width, 3), float) - self.image_buffer[line - 1, :, :] = self.palette[buff, :] / 255. + self.image_buffer[line - 1, :, :] = self.palette[buff, :] / 255.0 if line == totlines: self.img.set_image(self.image_buffer) self.img.draw() @@ -862,18 +956,18 @@ def draw_line(self, x1, y1, x2, y2, colorindex): color = _get_color_dict()[str(colorindex)] x1, y1 = self._img2win(x1, y1) x2, y2 = self._img2win(x2, y2) - Line(self.ec, ((x1, x2), (y1, y2)), 'pix', color).draw() + Line(self.ec, ((x1, x2), (y1, y2)), "pix", color).draw() def draw_lozenge(self, x, y, width, height, colorindex): - coords = self._img2win(x + width / 2., y + width / 2.) - width = width * self.img.scale / 2. - height = height * self.img.scale / 2. + coords = self._img2win(x + width / 2.0, y + width / 2.0) + width = width * self.img.scale / 2.0 + height = height * self.img.scale / 2.0 self.loz_circ.set_line_color(_get_color_dict()[str(colorindex)]) - self.loz_circ.set_pos(coords, units='pix') - self.loz_circ.set_radius((width, height), units='pix') + self.loz_circ.set_pos(coords, units="pix") + self.loz_circ.set_radius((width, height), units="pix") self.loz_circ.draw() def _within_distance(pos_1, pos_2, radius): """Helper for checking eye position""" - return np.sum((pos_1 - pos_2) ** 2) <= radius ** 2 + return np.sum((pos_1 - pos_2) ** 2) <= radius**2 diff --git a/expyfun/_git.py b/expyfun/_git.py index 5c66a215..fb8d1908 100644 --- a/expyfun/_git.py +++ b/expyfun/_git.py @@ -1,16 +1,17 @@ -# -*- coding: utf-8 -*- import os -from os import path as op import sys import warnings +from importlib import reload +from io import StringIO +from os import path as op -from ._utils import _TempDir, string_types, run_subprocess, StringIO, reload +from ._utils import _TempDir, run_subprocess from ._version import __version__ this_version = __version__[-7:] try: - run_subprocess(['git', '--help']) + run_subprocess(["git", "--help"]) except Exception as exp: _has_git, why_not = False, str(exp) else: @@ -20,22 +21,21 @@ def _check_git(): """Helper to check the expyfun version""" if not _has_git: - raise RuntimeError('git not found: {0}'.format(why_not)) + raise RuntimeError(f"git not found: {why_not}") def _check_version_format(version): """Helper to ensure version is of proper format""" - if not isinstance(version, string_types) or len(version) != 7: - raise TypeError('version must be a string of length 7, got {0}' - ''.format(version)) + if not isinstance(version, str) or len(version) != 7: + raise TypeError(f"version must be a string of length 7, got {version}" "") def _active_version(wd): """Helper to get the currently active version""" - return run_subprocess(['git', 'rev-parse', 'HEAD'], cwd=wd)[0][:7] + return run_subprocess(["git", "rev-parse", "HEAD"], cwd=wd)[0][:7] -def download_version(version='current', dest_dir=None): +def download_version(version="current", dest_dir=None): """Download specific expyfun version Parameters @@ -56,24 +56,28 @@ def download_version(version='current', dest_dir=None): _check_version_format(version) if dest_dir is None: dest_dir = os.getcwd() - if not isinstance(dest_dir, string_types) or not op.isdir(dest_dir): - raise IOError('Destination directory {0} does not exist' - ''.format(dest_dir)) - if op.isdir(op.join(dest_dir, 'expyfun')): - raise IOError('Destination directory {0} already has "expyfun" ' - 'subdirectory'.format(dest_dir)) + if not isinstance(dest_dir, str) or not op.isdir(dest_dir): + raise OSError(f"Destination directory {dest_dir} does not exist" "") + if op.isdir(op.join(dest_dir, "expyfun")): + raise OSError( + f'Destination directory {dest_dir} already has "expyfun" ' "subdirectory" + ) # fetch locally and get the proper version tempdir = _TempDir() - expyfun_dir = op.join(tempdir, 'expyfun') # git will auto-create this dir - repo_url = 'git://github.com/LABSN/expyfun.git' - run_subprocess(['git', 'clone', repo_url, expyfun_dir]) - version = _active_version(expyfun_dir) if version == 'current' else version + expyfun_dir = op.join(tempdir, "expyfun") # git will auto-create this dir + repo_url = "https://github.com/LABSN/expyfun.git" + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" # do not prompt for credentials + run_subprocess( + ["git", "clone", repo_url, expyfun_dir, "--single-branch", "--branch", "main"], + env=env, + ) + version = _active_version(expyfun_dir) if version == "current" else version try: - run_subprocess(['git', 'checkout', version], cwd=expyfun_dir) + run_subprocess(["git", "checkout", version], cwd=expyfun_dir, env=env) except Exception as exp: - raise RuntimeError('Could not check out version {0}: {1}' - ''.format(version, str(exp))) + raise RuntimeError(f"Could not check out version {version}: {str(exp)}" "") assert _active_version(expyfun_dir) == version # install @@ -82,29 +86,53 @@ def download_version(version='current', dest_dir=None): # ensure our version-specific "setup" is imported sys.path.insert(0, expyfun_dir) orig_stdout = sys.stdout + # numpy.distutils is gone, but all we use is setup from it. Let's use the one + # from setuptools instead. + orig_numpy_distutils_core = None + if "numpy.distutils.core" in sys.modules: + orig_numpy_distutils_core = sys.modules["numpy.distutils.core"] + import setuptools + + sys.modules["numpy.distutils.core"] = setuptools try: # on pytest with Py3k this can be problematic - if 'setup' in sys.modules: - del sys.modules['setup'] + if "setup" in sys.modules: + del sys.modules["setup"] import setup + reload(setup) setup_version = setup.git_version() # This is necessary because for a while git_version returned # a tuple of (version, fork) - if not isinstance(setup_version, string_types): + if not isinstance(setup_version, str): setup_version = setup_version[0] assert version.lower() == setup_version[:7].lower() del setup_version + # Now we need to monkey-patch to change FULL_VERSION, which can be for example: + # 2.0.0.dev-090948e + # to + # 2.0.0.dev0+090948e + if "-" in setup.FULL_VERSION: + setup.FULL_VERSION = setup.FULL_VERSION.replace("-", "0+") # PEP440 sys.stdout = StringIO() with warnings.catch_warnings(record=True): # PEP440 - setup.setup_package( - script_args=['build', '--build-purelib', dest_dir]) + setup.setup_package(script_args=["build", "--build-purelib", dest_dir]) finally: sys.stdout = orig_stdout sys.path.pop(sys.path.index(expyfun_dir)) os.chdir(orig_dir) - print('\n'.join(['Successfully checked out expyfun version:', version, - 'into destination directory:', op.join(dest_dir)])) + if orig_numpy_distutils_core is not None: + sys.modules["numpy.distutils.core"] = orig_numpy_distutils_core + print( + "\n".join( + [ + "Successfully checked out expyfun version:", + version, + "into destination directory:", + op.join(dest_dir), + ] + ) + ) def assert_version(version): @@ -117,5 +145,7 @@ def assert_version(version): """ _check_version_format(version) if this_version.lower() != version.lower(): - raise AssertionError('Requested version {0} does not match current ' - 'version {1}'.format(version, this_version)) + raise AssertionError( + f"Requested version {version} does not match current " + f"version {this_version}" + ) diff --git a/expyfun/_input_controllers.py b/expyfun/_input_controllers.py index 5fb7a082..57631a09 100644 --- a/expyfun/_input_controllers.py +++ b/expyfun/_input_controllers.py @@ -7,17 +7,16 @@ # # License: BSD (3-clause) -from functools import partial import sys +from functools import partial import numpy as np -from .visual import (Triangle, Rectangle, Circle, Diamond, ConcentricCircles, - FixationDot) -from ._utils import clock, string_types, logger +from ._utils import clock, logger +from .visual import Circle, ConcentricCircles, Diamond, FixationDot, Rectangle, Triangle -class Keyboard(object): +class Keyboard: """Retrieve presses from various devices. Public metohds: @@ -34,16 +33,19 @@ class Keyboard(object): _retrieve_events """ - key_event_types = {'presses': ['press'], 'releases': ['release'], - 'both': ['press', 'release']} + key_event_types = { + "presses": ["press"], + "releases": ["release"], + "both": ["press", "release"], + } def __init__(self, ec, force_quit_keys): self.master_clock = ec._master_clock self.log_presses = ec._log_presses self.force_quit_keys = force_quit_keys self.listen_start = None - ec._time_correction_fxns['keypress'] = self._get_timebase - self.get_time_corr = partial(ec._get_time_correction, 'keypress') + ec._time_correction_fxns["keypress"] = self._get_timebase + self.get_time_corr = partial(ec._get_time_correction, "keypress") self.time_correction = self.get_time_corr() self.ec = ec # always init pyglet response handler for error (and non-error) keys @@ -57,19 +59,18 @@ def __init__(self, ec, force_quit_keys): def _clear_events(self): self._clear_keyboard_events() - def _retrieve_events(self, live_keys, kind='presses'): + def _retrieve_events(self, live_keys, kind="presses"): return self._retrieve_keyboard_events(live_keys, kind) def _get_timebase(self): - """Get keyboard time reference (in seconds) - """ + """Get keyboard time reference (in seconds)""" return clock() def _clear_keyboard_events(self): self.ec._dispatch_events() self._keyboard_buffer = [] - def _retrieve_keyboard_events(self, live_keys, kind='presses'): + def _retrieve_keyboard_events(self, live_keys, kind="presses"): # add escape keys if live_keys is not None: live_keys = [str(x) for x in live_keys] # accept ints @@ -83,22 +84,21 @@ def _retrieve_keyboard_events(self, live_keys, kind='presses'): targets.append(key) return targets - def _on_pyglet_keypress(self, symbol, modifiers, emulated=False, - isPress=True): + def _on_pyglet_keypress(self, symbol, modifiers, emulated=False, isPress=True): """Handler for on_key_press pyglet events""" key_time = clock() if emulated: this_key = str(symbol) else: from pyglet.window import key + this_key = key.symbol_string(symbol).lower() - this_key = this_key.lstrip('_').lstrip('NUM_') - press_or_release = {True: 'press', False: 'release'}[isPress] + this_key = this_key.lstrip("_").lstrip("NUM_") + press_or_release = {True: "press", False: "release"}[isPress] self._keyboard_buffer.append((this_key, key_time, press_or_release)) def _on_pyglet_keyrelease(self, symbol, modifiers, emulated=False): - self._on_pyglet_keypress(symbol, modifiers, emulated=emulated, - isPress=False) + self._on_pyglet_keypress(symbol, modifiers, emulated=emulated, isPress=False) def listen_presses(self): """Start listening for keypresses.""" @@ -106,8 +106,9 @@ def listen_presses(self): self.listen_start = self.master_clock() self._clear_events() - def get_presses(self, live_keys, timestamp, relative_to, kind='presses', - return_kinds=False): + def get_presses( + self, live_keys, timestamp, relative_to, kind="presses", return_kinds=False + ): """Get the current entire keyboard / button box buffer. Parameters @@ -129,27 +130,30 @@ def get_presses(self, live_keys, timestamp, relative_to, kind='presses', The presses (and possibly timestamps and/or types). """ if kind not in self.key_event_types.keys(): - raise ValueError('Kind argument must be one of: '+', '.join( - self.key_event_types.keys())) + raise ValueError( + "Kind argument must be one of: " + + ", ".join(self.key_event_types.keys()) + ) events = [] if timestamp and relative_to is None: if self.listen_start is None: - raise ValueError('I cannot timestamp: relative_to is None and ' - 'you have not yet called listen_presses.') + raise ValueError( + "I cannot timestamp: relative_to is None and " + "you have not yet called listen_presses." + ) relative_to = self.listen_start events = self._retrieve_events(live_keys, kind) events = self._correct_presses(events, timestamp, relative_to, kind) events = [e[:-1] for e in events] if not return_kinds else events return events - def wait_one_press(self, max_wait, min_wait, live_keys, timestamp, - relative_to): + def wait_one_press(self, max_wait, min_wait, live_keys, timestamp, relative_to): """Return the first button pressed after min_wait. Parameters ---------- max_wait : float - Maxmimum time to wait. + Maximum time to wait. min_wait : float Minimum time to wait. live_keys : list | None @@ -165,10 +169,10 @@ def wait_one_press(self, max_wait, min_wait, live_keys, timestamp, The press. Will be tuple if timestamp is True. """ relative_to, start_time = self._init_wait_press( - max_wait, min_wait, live_keys, relative_to) + max_wait, min_wait, live_keys, relative_to + ) pressed = [] - while (not len(pressed) and - self.master_clock() - start_time < max_wait): + while not len(pressed) and self.master_clock() - start_time < max_wait: pressed = self._retrieve_events(live_keys) # handle non-presses @@ -179,14 +183,13 @@ def wait_one_press(self, max_wait, min_wait, live_keys, timestamp, pressed = (None, None) if timestamp else None return pressed - def wait_for_presses(self, max_wait, min_wait, live_keys, - timestamp, relative_to): + def wait_for_presses(self, max_wait, min_wait, live_keys, timestamp, relative_to): """Return all button presses between min_wait and max_wait. Parameters ---------- max_wait : float - Maxmimum time to wait. + Maximum time to wait. min_wait : float Minimum time to wait. live_keys : list | None @@ -202,9 +205,10 @@ def wait_for_presses(self, max_wait, min_wait, live_keys, The list of presses (and possibly timestamps). """ relative_to, start_time = self._init_wait_press( - max_wait, min_wait, live_keys, relative_to) + max_wait, min_wait, live_keys, relative_to + ) pressed = [] - while (self.master_clock() - start_time < max_wait): + while self.master_clock() - start_time < max_wait: pressed = self._retrieve_events(live_keys) pressed = self._correct_presses(pressed, timestamp, relative_to) pressed = [p[:2] if timestamp else p[0] for p in pressed] @@ -224,18 +228,20 @@ def check_force_quit(self, keys=None): # only grab the force-quit keys keys = self._retrieve_keyboard_events([]) else: - if isinstance(keys, string_types): + if isinstance(keys, str): keys = [keys] if isinstance(keys, list): keys = [k for k in keys if k in self.force_quit_keys] else: - raise TypeError('Force quit checking requires a string or ' - ' list of strings, not a {}.' - ''.format(type(keys))) + raise TypeError( + "Force quit checking requires a string or " + f" list of strings, not a {type(keys)}." + "" + ) if len(keys): - raise RuntimeError('Quit key pressed') + raise RuntimeError("Quit key pressed") - def _correct_presses(self, events, timestamp, relative_to, kind='presses'): + def _correct_presses(self, events, timestamp, relative_to, kind="presses"): """Correct timing of presses and check for quit press.""" events = [(k, s + self.time_correction, r) for k, s, r in events] self.log_presses(events) @@ -251,10 +257,11 @@ def _correct_presses(self, events, timestamp, relative_to, kind='presses'): def _init_wait_press(self, max_wait, min_wait, live_keys, relative_to): """Prepare for ``wait_one_press`` and ``wait_for_presses``.""" if np.isinf(max_wait) and live_keys == []: - raise ValueError('max_wait cannot be infinite if there are no live' - ' keys.') + raise ValueError( + "max_wait cannot be infinite if there are no live" " keys." + ) if not min_wait <= max_wait: - raise ValueError('min_wait must be less than max_wait') + raise ValueError("min_wait must be less than max_wait") start_time = self.master_clock() relative_to = start_time if relative_to is None else relative_to self.ec.wait_secs(min_wait) @@ -263,7 +270,7 @@ def _init_wait_press(self, max_wait, min_wait, live_keys, relative_to): return relative_to, start_time -class Mouse(object): +class Mouse: """Class to track mouse properties and events Parameters @@ -289,21 +296,28 @@ class Mouse(object): def __init__(self, ec, visible=False): from pyglet.window import mouse + self.ec = ec self.set_visible(visible) self.master_clock = ec._master_clock self.log_clicks = ec._log_clicks self.listen_start = None - ec._time_correction_fxns['mouseclick'] = self._get_timebase - self.get_time_corr = partial(ec._get_time_correction, 'mouseclick') + ec._time_correction_fxns["mouseclick"] = self._get_timebase + self.get_time_corr = partial(ec._get_time_correction, "mouseclick") self.time_correction = self.get_time_corr() self._check_force_quit = ec.check_force_quit self.ec._win.on_mouse_press = self._on_pyglet_mouse_click self._mouse_buffer = [] - self._button_names = {mouse.LEFT: 'left', mouse.MIDDLE: 'middle', - mouse.RIGHT: 'right'} - self._button_ids = {'left': mouse.LEFT, 'middle': mouse.MIDDLE, - 'right': mouse.RIGHT} + self._button_names = { + mouse.LEFT: "left", + mouse.MIDDLE: "middle", + mouse.RIGHT: "right", + } + self._button_ids = { + "left": mouse.LEFT, + "middle": mouse.MIDDLE, + "right": mouse.RIGHT, + } self._legal_types = (Rectangle, Circle) def set_visible(self, visible): @@ -326,10 +340,12 @@ def visible(self): @property def pos(self): """The current position of the mouse in normalized units""" - x = (self.ec._win._mouse_x - - self.ec._win.width / 2.) / (self.ec._win.width / 2.) - y = (self.ec._win._mouse_y - - self.ec._win.height / 2.) / (self.ec._win.height / 2.) + x = (self.ec._win._mouse_x - self.ec._win.width / 2.0) / ( + self.ec._win.width / 2.0 + ) + y = (self.ec._win._mouse_y - self.ec._win.height / 2.0) / ( + self.ec._win.height / 2.0 + ) return np.array([x, y]) ########################################################################### @@ -342,8 +358,7 @@ def _retrieve_events(self, live_buttons): return self._retrieve_mouse_events(live_buttons) def _get_timebase(self): - """Get mouse time reference (in seconds) - """ + """Get mouse time reference (in seconds)""" return clock() def _clear_mouse_events(self): @@ -365,35 +380,35 @@ def _on_pyglet_mouse_click(self, x, y, button, modifiers): self._mouse_buffer.append((this_button, x, y, button_time)) def listen_clicks(self): - """Start listening for mouse clicks. - """ + """Start listening for mouse clicks.""" self.time_correction = self.get_time_corr() self.listen_start = self.master_clock() self._clear_events() def get_clicks(self, live_buttons, timestamp, relative_to): - """Get the current entire mouse buffer. - """ + """Get the current entire mouse buffer.""" clicked = [] if timestamp and relative_to is None: if self.listen_start is None: - raise ValueError('I cannot timestamp: relative_to is None and ' - 'you have not yet called listen_clicks.') + raise ValueError( + "I cannot timestamp: relative_to is None and " + "you have not yet called listen_clicks." + ) else: relative_to = self.listen_start clicked = self._retrieve_events(live_buttons) return self._correct_clicks(clicked, timestamp, relative_to) - def wait_one_click(self, max_wait, min_wait, live_buttons, - timestamp, relative_to, visible): - """Returns only the first button clicked after min_wait. - """ + def wait_one_click( + self, max_wait, min_wait, live_buttons, timestamp, relative_to, visible + ): + """Returns only the first button clicked after min_wait.""" relative_to, start_time, was_visible = self._init_wait_click( - max_wait, min_wait, live_buttons, timestamp, relative_to, visible) + max_wait, min_wait, live_buttons, timestamp, relative_to, visible + ) clicked = [] - while (not len(clicked) and - self.master_clock() - start_time < max_wait): + while not len(clicked) and self.master_clock() - start_time < max_wait: clicked = self._retrieve_events(live_buttons) # handle non-clicks @@ -405,29 +420,30 @@ def wait_one_click(self, max_wait, min_wait, live_buttons, clicked = None return clicked - def wait_for_clicks(self, max_wait, min_wait, live_buttons, - timestamp, relative_to, visible=None): - """Returns all clicks between min_wait and max_wait. - """ + def wait_for_clicks( + self, max_wait, min_wait, live_buttons, timestamp, relative_to, visible=None + ): + """Returns all clicks between min_wait and max_wait.""" relative_to, start_time, was_visible = self._init_wait_click( - max_wait, min_wait, live_buttons, timestamp, relative_to, visible) + max_wait, min_wait, live_buttons, timestamp, relative_to, visible + ) clicked = [] - while (self.master_clock() - start_time < max_wait): + while self.master_clock() - start_time < max_wait: clicked = self._retrieve_events(live_buttons) return self._correct_clicks(clicked, timestamp, relative_to) - def wait_for_click_on(self, objects, max_wait, min_wait, - live_buttons, timestamp, relative_to): - """Waits for a click on one of the supplied window objects - """ + def wait_for_click_on( + self, objects, max_wait, min_wait, live_buttons, timestamp, relative_to + ): + """Waits for a click on one of the supplied window objects""" relative_to, start_time, was_visible = self._init_wait_click( - max_wait, min_wait, live_buttons, timestamp, relative_to, True) + max_wait, min_wait, live_buttons, timestamp, relative_to, True + ) index = None ci = 0 - while (self.master_clock() - start_time < max_wait and - index is None): + while self.master_clock() - start_time < max_wait and index is None: clicked = self._retrieve_events(live_buttons) self._check_force_quit() while ci < len(clicked) and index is None: # clicks first @@ -453,29 +469,28 @@ def wait_for_click_on(self, objects, max_wait, min_wait, def _correct_clicks(self, clicked, timestamp, relative_to): """Correct timing of clicks""" if len(clicked): - clicked = [(b, x, y, s + self.time_correction) for - b, x, y, s in clicked] + clicked = [(b, x, y, s + self.time_correction) for b, x, y, s in clicked] self.log_clicks(clicked) buttons = [(b, x, y) for b, x, y, _ in clicked] self._check_force_quit() if timestamp: - clicked = [(b, x, y, s - relative_to) for - b, x, y, s in clicked] + clicked = [(b, x, y, s - relative_to) for b, x, y, s in clicked] else: clicked = buttons return clicked - def _init_wait_click(self, max_wait, min_wait, live_buttons, timestamp, - relative_to, visible): - """Actions common to ``wait_one_click`` and ``wait_for_clicks`` - """ + def _init_wait_click( + self, max_wait, min_wait, live_buttons, timestamp, relative_to, visible + ): + """Actions common to ``wait_one_click`` and ``wait_for_clicks``""" if np.isinf(max_wait) and live_buttons == []: - raise ValueError('max_wait cannot be infinite if there are no live' - ' mouse buttons.') + raise ValueError( + "max_wait cannot be infinite if there are no live" " mouse buttons." + ) if not min_wait <= max_wait: - raise ValueError('min_wait must be less than max_wait') + raise ValueError("min_wait must be less than max_wait") if visible not in [True, False, None]: - raise ValueError('set_visible must be one of (True, False, None)') + raise ValueError("set_visible must be one of (True, False, None)") start_time = self.master_clock() if timestamp and relative_to is None: relative_to = start_time @@ -489,27 +504,25 @@ def _init_wait_click(self, max_wait, min_wait, live_buttons, timestamp, # Define some functions for determining if a click point is in an object def _point_in_object(self, pos, obj): - """Determine if a point is within a visual object - """ + """Determine if a point is within a visual object""" if isinstance(obj, (Rectangle, Circle, Diamond, Triangle)): return self._point_in_tris(pos, obj) elif isinstance(obj, (ConcentricCircles, FixationDot)): return np.any([self._point_in_tris(pos, c) for c in obj._circles]) def _point_in_tris(self, pos, obj): - """Check to see if a point is in any of the triangles - """ - these_tris = obj._tris['fill'].reshape(-1, 3) + """Check to see if a point is in any of the triangles""" + these_tris = obj._tris["fill"].reshape(-1, 3) for tri in these_tris: - if self._point_in_tri(pos, obj._points['fill'][tri]): + if self._point_in_tri(pos, obj._points["fill"][tri]): return True return False def _point_in_tri(self, pos, tri): - """Check to see if a point is in a single triangle - """ - signs = np.sign([np.cross(tri[np.mod(i + 1, 3)] - tri[i], - pos - tri[i]) for i in range(3)]) + """Check to see if a point is in a single triangle""" + signs = np.sign( + [_cross_2d(tri[np.mod(i + 1, 3)] - tri[i], pos - tri[i]) for i in range(3)] + ) if np.all(signs[1:] == signs[0]): return True else: @@ -517,12 +530,16 @@ def _point_in_tri(self, pos, tri): def _move_to(self, pos, units): # adapted from pyautogui (BSD) - x, y = self.ec._convert_units(np.array( - [pos]).T, units, 'pix')[:, 0].round().astype(int) + x, y = ( + self.ec._convert_units(np.array([pos]).T, units, "pix")[:, 0] + .round() + .astype(int) + ) # The "y" we use is inverted relative to the OSes y = self.ec.window.height - y - if sys.platform == 'darwin': - from pyglet.libs.darwin.cocoapy import quartz, CGPoint, CGRect + if sys.platform == "darwin": + from pyglet.libs.darwin.cocoapy import CGPoint, CGRect, quartz + # Convert from window to global view, window = self.ec.window._nsview, self.ec.window._nswindow point = CGPoint() @@ -548,9 +565,10 @@ def _move_to(self, pos, units): # func(kCGHIDEventTap, event) # time.sleep(0.001) # quartz.CFRelease(event) - elif sys.platform.startswith('win'): + elif sys.platform.startswith("win"): # Convert from window to global from pyglet.window.win32 import POINT, _user32, byref + point = POINT() point.x = x point.y = y @@ -559,8 +577,13 @@ def _move_to(self, pos, units): _user32.SetCursorPos(point.x, point.y) else: # https://stackoverflow.com/questions/2433447 - from pyglet.libs.x11.xlib import (XWarpPointer, XFlush, - XSelectInput, KeyReleaseMask) + from pyglet.libs.x11.xlib import ( + KeyReleaseMask, + XFlush, + XSelectInput, + XWarpPointer, + ) + display, window = self.ec.window._x_display, self.ec.window._window XSelectInput(display, window, KeyReleaseMask) XWarpPointer(display, 0, window, 0, 0, 0, 0, x, y) @@ -576,13 +599,14 @@ class CedrusBox(Keyboard): def __init__(self, ec, force_quit_keys): import pyxid + pyxid.use_response_pad_timer = True dev = pyxid.get_xid_devices()[0] dev.reset_base_timer() assert dev.is_response_device() self._dev = dev - super(CedrusBox, self).__init__(ec, force_quit_keys) - ec._time_correction_maxs['keypress'] = 1e-3 # higher tolerance + super().__init__(ec, force_quit_keys) + ec._time_correction_maxs["keypress"] = 1e-3 # higher tolerance def _get_timebase(self): """WARNING: For now this will clear the event queue!""" @@ -599,7 +623,7 @@ def _clear_events(self): self._retrieve_events(None) self._keyboard_buffer = [] - def _retrieve_events(self, live_keys, kind='presses'): + def _retrieve_events(self, live_keys, kind="presses"): # add escape keys if live_keys is not None: live_keys = [str(x) for x in live_keys] # accept ints @@ -608,9 +632,8 @@ def _retrieve_events(self, live_keys, kind='presses'): self._dev.poll_for_response() while self._dev.response_queue_size() > 0: key = self._dev.get_next_response() - press_or_release = {True: 'press', - False: 'release'}[key['pressed']] - key = [str(key['key'] + 1), key['time'] / 1000., press_or_release] + press_or_release = {True: "press", False: "release"}[key["pressed"]] + key = [str(key["key"] + 1), key["time"] / 1000.0, press_or_release] self._keyboard_buffer.append(key) self._dev.poll_for_response() # check to see if we have matches @@ -632,39 +655,47 @@ class Joystick(Keyboard): def __init__(self, ec): import pyglet.input + self.ec = ec self.master_clock = ec._master_clock - self.log_presses = partial(ec._log_presses, kind='joy') + self.log_presses = partial(ec._log_presses, kind="joy") self.force_quit_keys = [] self.listen_start = None - ec._time_correction_fxns['joystick'] = self._get_timebase - self.get_time_corr = partial(ec._get_time_correction, 'joystick') + ec._time_correction_fxns["joystick"] = self._get_timebase + self.get_time_corr = partial(ec._get_time_correction, "joystick") self.time_correction = self.get_time_corr() self._keyboard_buffer = [] self._dev = pyglet.input.get_joysticks()[0] - logger.info('Expyfun: Initializing joystick %s' % (self._dev.device,)) + logger.info("Expyfun: Initializing joystick %s" % (self._dev.device,)) self._dev.open(window=ec._win, exclusive=True) - assert hasattr(self._dev, 'on_joybutton_press') - self._dev.on_joybutton_press = partial( - self._on_pyglet_joybutton, kind='press') + assert hasattr(self._dev, "on_joybutton_press") + self._dev.on_joybutton_press = partial(self._on_pyglet_joybutton, kind="press") self._dev.on_joybutton_release = partial( - self._on_pyglet_joybutton, kind='release') + self._on_pyglet_joybutton, kind="release" + ) - def _on_pyglet_joybutton(self, joystick, button='foo', kind='press'): + def _on_pyglet_joybutton(self, joystick, button="foo", kind="press"): """Handler for on_joybutton_press events.""" key_time = clock() self._keyboard_buffer.append((str(button), key_time, kind)) def _close(self): - dev = getattr(self, '_dev', None) + dev = getattr(self, "_dev", None) if dev is not None: dev.close() -for key in ('x', 'y', 'hat_x', 'hat_y', 'z', 'rz', 'rx', 'ry'): +for key in ("x", "y", "hat_x", "hat_y", "z", "rz", "rx", "ry"): + def _wrap(key=key): - sign = -1 if key in ('rz',) else 1 + sign = -1 if key in ("rz",) else 1 return property(lambda self: sign * getattr(self._dev, key)) + setattr(Joystick, key, _wrap()) del _wrap del key + + +# https://github.com/numpy/numpy/pull/26694/files +def _cross_2d(x, y): + return x[..., 0] * y[..., 1] - x[..., 1] * y[..., 0] diff --git a/expyfun/_parallel.py b/expyfun/_parallel.py index 127ab079..6150947b 100644 --- a/expyfun/_parallel.py +++ b/expyfun/_parallel.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- -"""Parallel util functions -""" +"""Parallel util functions""" # Adapted from mne-python with permission @@ -62,9 +60,10 @@ def _check_n_jobs(n_jobs): The checked number of jobs. Always positive. """ if not isinstance(n_jobs, int): - raise TypeError('n_jobs must be an integer') + raise TypeError("n_jobs must be an integer") if n_jobs <= 0: import multiprocessing + n_cores = multiprocessing.cpu_count() n_jobs = max(min(n_cores + n_jobs + 1, n_cores), 1) return n_jobs diff --git a/expyfun/_sound_controllers/__init__.py b/expyfun/_sound_controllers/__init__.py index 3d779b52..809bd6c7 100644 --- a/expyfun/_sound_controllers/__init__.py +++ b/expyfun/_sound_controllers/__init__.py @@ -1,3 +1,8 @@ -from ._sound_controller import (SoundCardController, SoundPlayer, _BACKENDS, - _import_backend) -_AUTO_BACKENDS = ('auto',) + _BACKENDS +from ._sound_controller import ( + SoundCardController, + SoundPlayer, + _BACKENDS, + _import_backend, +) + +_AUTO_BACKENDS = ("auto",) + _BACKENDS diff --git a/expyfun/_sound_controllers/_pyglet.py b/expyfun/_sound_controllers/_pyglet.py index e4e8a1f8..0fc00864 100644 --- a/expyfun/_sound_controllers/_pyglet.py +++ b/expyfun/_sound_controllers/_pyglet.py @@ -10,21 +10,18 @@ import warnings import numpy as np - import pyglet from .._utils import _new_pyglet _PRIORITY = 200 -_use_silent = (os.getenv('_EXPYFUN_SILENT', '') == 'true') -_opts_dict = dict(linux2=('pulse',), - win32=('directsound',), - darwin=('openal',)) -_opts_dict['linux'] = _opts_dict['linux2'] # new name on Py3k -_driver = _opts_dict[sys.platform] if not _use_silent else ('silent',) +_use_silent = os.getenv("_EXPYFUN_SILENT", "") == "true" +_opts_dict = dict(linux2=("pulse",), win32=("directsound",), darwin=("openal",)) +_opts_dict["linux"] = _opts_dict["linux2"] # new name on Py3k +_driver = _opts_dict[sys.platform] if not _use_silent else ("silent",) -pyglet.options['audio'] = _driver +pyglet.options["audio"] = _driver # We might also want this at some point if we hit OSX problems: # pyglet.options['shadow_window'] = False @@ -35,6 +32,7 @@ except ImportError: from pyglet.media import AudioFormat from pyglet.media import Player, SourceGroup # noqa + try: from pyglet.media.codecs import StaticMemorySource except ImportError: @@ -43,30 +41,42 @@ except ImportError: from pyglet.media.sources.base import StaticMemorySource # noqa except Exception as exp: - warnings.warn('Pyglet could not be imported:\n%s' % exp) + warnings.warn("Pyglet could not be imported:\n%s" % exp) Player = AudioFormat = SourceGroup = StaticMemorySource = object def _check_pyglet_audio(): - if pyglet.media.get_audio_driver() is None and \ - not (_new_pyglet() and _driver == ('silent',)): - raise SystemError('pyglet audio ("%s") could not be initialized' - % pyglet.options['audio'][0]) + if pyglet.media.get_audio_driver() is None and not ( + _new_pyglet() and _driver == ("silent",) + ): + raise SystemError( + 'pyglet audio ("%s") could not be initialized' % pyglet.options["audio"][0] + ) class SoundPlayer(Player): """SoundPlayer based on Pyglet.""" - def __init__(self, data, fs=None, loop=False, api=None, name=None, - fixed_delay=None, api_options=None): + def __init__( + self, + data, + fs=None, + loop=False, + api=None, + name=None, + fixed_delay=None, + api_options=None, + ): assert AudioFormat is not None if any(x is not None for x in (api, name, fixed_delay, api_options)): - raise ValueError('The Pyglet backend does not support specifying ' - 'api, name, fixed_delay, or api_options') + raise ValueError( + "The Pyglet backend does not support specifying " + "api, name, fixed_delay, or api_options" + ) # We could maybe let Pyglet make this decision, but hopefully # people won't need to tweak the Pyglet backend anyway self.fs = 44100 if fs is None else fs - super(SoundPlayer, self).__init__() + super().__init__() _check_pyglet_audio() sms = _as_static(data, self.fs) if _new_pyglet(): @@ -79,30 +89,32 @@ def __init__(self, data, fs=None, loop=False, api=None, name=None, self.queue(group) self._ec_duration = sms._duration - def stop(self, wait=True, extra_delay=0.): + def stop(self, wait=True, extra_delay=0.0): """Stop.""" - self.pause() - self.seek(0.) + try: + self.pause() + # assert timestamp >= 0, 'Timestamp beyond dequeued source memory' + except AssertionError: + pass + self.seek(0.0) @property def playing(self): # Pyglet has this, but it doesn't notice when it's finished on its own - return (super(SoundPlayer, self).playing and not - np.isclose(self.time, self._ec_duration)) + return super().playing and not np.isclose(self.time, self._ec_duration) def _as_static(data, fs): """Get data into the Pyglet audio format.""" fs = int(fs) if data.ndim not in (1, 2): - raise ValueError('Data must have one or two dimensions') + raise ValueError("Data must have one or two dimensions") n_ch = data.shape[0] if data.ndim == 2 else 1 - audio_format = AudioFormat(channels=n_ch, sample_size=16, - sample_rate=fs) - data = data.T.ravel('C') + audio_format = AudioFormat(channels=n_ch, sample_size=16, sample_rate=fs) + data = data.T.ravel("C") data[data < -1] = -1 data[data > 1] = 1 - data = (data * (2 ** 15)).astype('int16').tobytes() + data = (data * (2**15)).astype("int16").tobytes() return StaticMemorySourceFixed(data, audio_format) diff --git a/expyfun/_sound_controllers/_rtmixer.py b/expyfun/_sound_controllers/_rtmixer.py index 233f6b60..71673853 100644 --- a/expyfun/_sound_controllers/_rtmixer.py +++ b/expyfun/_sound_controllers/_rtmixer.py @@ -7,10 +7,10 @@ import sys import numpy as np - -from rtmixer import Mixer, RingBuffer import sounddevice -from .._utils import logger, get_config +from rtmixer import Mixer, RingBuffer + +from .._utils import get_config, logger _PRIORITY = 100 _DEFAULT_NAME = None @@ -19,11 +19,11 @@ # only initialize each mixer once and reuse it until this gets garbage # collected -class _MixerRegistry(dict): +class _MixerRegistry(dict): def __del__(self): for mixer in self.values(): - print(f'Closing {mixer}') + print(f"Closing {mixer}") mixer.abort() mixer.close() self.clear() @@ -32,15 +32,15 @@ def _get_mixer(self, fs, n_channels, api, name, api_options): """Select the API and device.""" # API if api is None: - api = get_config('SOUND_CARD_API', None) + api = get_config("SOUND_CARD_API", None) if api is None: # Eventually we should maybe allow 'Windows WDM-KS', # 'Windows DirectSound', or 'MME' api = dict( - darwin='Core Audio', - win32='Windows WASAPI', - linux='ALSA', - linux2='ALSA', + darwin="Core Audio", + win32="Windows WASAPI", + linux="ALSA", + linux2="ALSA", )[sys.platform] key = (fs, n_channels, api, name) if key not in self: @@ -54,83 +54,97 @@ def _get_mixer(self, fs, n_channels, api, name, api_options): def _init_mixer(fs, n_channels, api, name, api_options=None): devices = sounddevice.query_devices() if len(devices) == 0: - raise OSError('No sound devices found!') + raise OSError("No sound devices found!") apis = sounddevice.query_hostapis() valid_apis = [] for ai, this_api in enumerate(apis): - if this_api['name'] == api: + if this_api["name"] == api: api = this_api break else: - valid_apis.append(this_api['name']) + valid_apis.append(this_api["name"]) else: m = 'Could not find host API %s. Valid choices are "%s"' - raise RuntimeError(m % (api, ', '.join(valid_apis))) + raise RuntimeError(m % (api, ", ".join(valid_apis))) del this_api # Name if name is None: - name = get_config('SOUND_CARD_NAME', None) + name = get_config("SOUND_CARD_NAME", None) if name is None: global _DEFAULT_NAME if _DEFAULT_NAME is None: - di = api['default_output_device'] - _DEFAULT_NAME = devices[di]['name'] - logger.exp('Selected default sound device: %r' % (_DEFAULT_NAME,)) + di = api["default_output_device"] + _DEFAULT_NAME = devices[di]["name"] + logger.exp("Selected default sound device: %r" % (_DEFAULT_NAME,)) name = _DEFAULT_NAME possible = list() for di, device in enumerate(devices): - if device['hostapi'] == ai: - possible.append(device['name']) - if name in device['name']: + if device["hostapi"] == ai: + possible.append(device["name"]) + if name in device["name"]: break else: - raise RuntimeError('Could not find device on API %r with name ' - 'containing %r, found:\n%s' - % (api['name'], name, '\n'.join(possible))) - param_str = ('sound card %r (devices[%d]) via %r' - % (device['name'], di, api['name'])) + raise RuntimeError( + "Could not find device on API %r with name " + "containing %r, found:\n%s" % (api["name"], name, "\n".join(possible)) + ) + param_str = "sound card %r (devices[%d]) via %r" % (device["name"], di, api["name"]) extra_settings = None if api_options is not None: - if api['name'] == 'Windows WASAPI': + if api["name"] == "Windows WASAPI": # exclusive mode is needed for zero jitter on Windows in testing extra_settings = sounddevice.WasapiSettings(**api_options) else: raise ValueError( 'api_options only supported for "Windows WASAPI" backend, ' - 'using %s backend got api_options=%s' - % (api['name'], api_options)) - param_str += ' with options %s' % (api_options,) - param_str += ', %d channels' % (n_channels,) + "using %s backend got api_options=%s" % (api["name"], api_options) + ) + param_str += " with options %s" % (api_options,) + param_str += ", %d channels" % (n_channels,) if fs is not None: - param_str += ' @ %d Hz' % (fs,) + param_str += " @ %d Hz" % (fs,) try: mixer = Mixer( - samplerate=fs, latency='low', channels=n_channels, - dither_off=True, device=di, - extra_settings=extra_settings) + samplerate=fs, + latency="low", + channels=n_channels, + dither_off=True, + device=di, + extra_settings=extra_settings, + ) except Exception as exp: - raise RuntimeError('Could not set up %s:\n%s' % (param_str, exp)) + raise RuntimeError(f"Could not set up {param_str}:\n{exp}") from None assert mixer.channels == n_channels if fs is None: - param_str += ' @ %d Hz' % (mixer.samplerate,) + param_str += " @ %d Hz" % (mixer.samplerate,) else: assert mixer.samplerate == fs mixer.start() assert mixer.active - logger.info('Expyfun: using %s, %0.1f ms nominal latency' - % (param_str, 1000 * device['default_low_output_latency'])) + logger.info( + "Expyfun: using %s, %0.1f ms nominal latency" + % (param_str, 1000 * device["default_low_output_latency"]) + ) return mixer -class SoundPlayer(object): +class SoundPlayer: """SoundPlayer based on rtmixer.""" - def __init__(self, data, fs=None, loop=False, api=None, name=None, - fixed_delay=None, api_options=None): + def __init__( + self, + data, + fs=None, + loop=False, + api=None, + name=None, + fixed_delay=None, + api_options=None, + ): data = np.atleast_2d(data).T - data = np.asarray(data, np.float32, 'C') + data = np.asarray(data, np.float32, "C") self._data = data self.loop = bool(loop) self._n_samples, n_channels = self._data.shape @@ -138,20 +152,23 @@ def __init__(self, data, fs=None, loop=False, api=None, name=None, self._n_channels = n_channels self._mixer = None # in case the next line crashes, __del__ works self._mixer = _mixer_registry._get_mixer( - fs, self._n_channels, api, name, api_options) + fs, self._n_channels, api, name, api_options + ) if loop: - self._ring = RingBuffer(self._data.itemsize * self._n_channels, - self._data.size) + self._ring = RingBuffer( + self._data.itemsize * self._n_channels, self._data.size + ) self._ring.write(self._data) self._fs = float(self._mixer.samplerate) self._ec_duration = self._n_samples / self._fs self._action = None self._fixed_delay = fixed_delay if fixed_delay is not None: - logger.info('Expyfun: Using fixed audio delay %0.1f ms' - % (1000 * fixed_delay,)) + logger.info( + "Expyfun: Using fixed audio delay %0.1f ms" % (1000 * fixed_delay,) + ) else: - logger.info('Expyfun: Variable audio delay') + logger.info("Expyfun: Variable audio delay") @property def fs(self): @@ -166,25 +183,28 @@ def _start_time(self): if self._fixed_delay is not None: return self._mixer.time + self._fixed_delay else: - return 0. + return 0.0 def play(self): """Play.""" if not self.playing and self._mixer is not None: if self.loop: self._action = self._mixer.play_ringbuffer( - self._ring, start=self._start_time) + self._ring, start=self._start_time + ) else: self._action = self._mixer.play_buffer( - self._data, self._data.shape[1], start=self._start_time) + self._data, self._data.shape[1], start=self._start_time + ) - def stop(self, wait=True, extra_delay=0.): + def stop(self, wait=True, extra_delay=0.0): """Stop.""" if self.playing: action, self._action = self._action, None # Impose the same delay here that we imposed on the stim start cancel_action = self._mixer.cancel( - action, time=self._start_time + extra_delay) + action, time=self._start_time + extra_delay + ) if wait: self._mixer.wait(cancel_action) else: @@ -192,12 +212,33 @@ def stop(self, wait=True, extra_delay=0.): def delete(self): """Delete.""" - if getattr(self, '_mixer', None) is not None: + if getattr(self, "_mixer", None) is not None: self.stop(wait=False) mixer, self._mixer = self._mixer, None - stats = mixer.fetch_and_reset_stats().stats - logger.exp('%d underflows %d blocks' - % (stats.output_underflows, stats.blocks)) + try: + stats = mixer.fetch_and_reset_stats().stats + except RuntimeError as exc: # action queue is full + logger.exp(f"Could not fetch mixer stats ({exc})") + else: + logger.exp( + f"{stats.output_underflows} underflows " f"{stats.blocks} blocks" + ) def __del__(self): # noqa self.delete() + + +def _abort_all_queues(): + for mixer in _mixer_registry.values(): + if len(mixer.actions) == 0: + continue + do_start_stop = mixer.stopped + if do_start_stop: + mixer.start() + for action in list(mixer.actions): + mixer.wait(mixer.cancel(action)) + mixer.wait() + assert len(mixer.actions) == 0, mixer.actions + if do_start_stop: + mixer.abort(ignore_errors=False) + assert len(mixer.actions) == 0, mixer.actions diff --git a/expyfun/_sound_controllers/_sound_controller.py b/expyfun/_sound_controllers/_sound_controller.py index 10187fce..bef54133 100644 --- a/expyfun/_sound_controllers/_sound_controller.py +++ b/expyfun/_sound_controllers/_sound_controller.py @@ -9,31 +9,42 @@ import operator import os import os.path as op - -import numpy as np - -from .._fixes import rfft, irfft, rfftfreq -from .._utils import logger, flush_logger, _check_params import warnings +import numpy as np -_BACKENDS = tuple(sorted( - op.splitext(x.lstrip('._'))[0] for x in os.listdir(op.dirname(__file__)) - if x.startswith('_') and x.endswith(('.py', '.pyc')) and - not x.startswith(('_sound_controller.py', '__init__.py')))) +from .._fixes import irfft, rfft, rfftfreq +from .._utils import _check_params, flush_logger, logger + +_BACKENDS = tuple( + sorted( + op.splitext(x.lstrip("._"))[0] + for x in os.listdir(op.dirname(__file__)) + if x.startswith("_") + and x.endswith((".py", ".pyc")) + and not x.startswith(("_sound_controller.py", "__init__.py")) + ) +) # libsoundio stub (kind of iffy) # https://gist.github.com/larsoner/fd9228f321d369c8a00c66a246fcc83f _SOUND_CARD_KEYS = ( - 'TYPE', 'SOUND_CARD_BACKEND', 'SOUND_CARD_API', - 'SOUND_CARD_NAME', 'SOUND_CARD_FS', 'SOUND_CARD_FIXED_DELAY', - 'SOUND_CARD_TRIGGER_CHANNELS', 'SOUND_CARD_API_OPTIONS', - 'SOUND_CARD_TRIGGER_SCALE', 'SOUND_CARD_TRIGGER_INSERTION', - 'SOUND_CARD_TRIGGER_ID_AFTER_ONSET', 'SOUND_CARD_DRIFT_TRIGGER', + "TYPE", + "SOUND_CARD_BACKEND", + "SOUND_CARD_API", + "SOUND_CARD_NAME", + "SOUND_CARD_FS", + "SOUND_CARD_FIXED_DELAY", + "SOUND_CARD_TRIGGER_CHANNELS", + "SOUND_CARD_API_OPTIONS", + "SOUND_CARD_TRIGGER_SCALE", + "SOUND_CARD_TRIGGER_INSERTION", + "SOUND_CARD_TRIGGER_ID_AFTER_ONSET", + "SOUND_CARD_DRIFT_TRIGGER", ) -class SoundCardController(object): +class SoundCardController: """Use a sound card. Parameters @@ -90,33 +101,32 @@ class SoundCardController(object): the configuration file. """ - def __init__(self, params, stim_fs, n_channels=2, trigger_duration=0.01, - ec=None): + def __init__(self, params, stim_fs, n_channels=2, trigger_duration=0.01, ec=None): self.ec = ec defaults = dict( - SOUND_CARD_BACKEND='auto', + SOUND_CARD_BACKEND="auto", SOUND_CARD_TRIGGER_CHANNELS=0, - SOUND_CARD_TRIGGER_SCALE=1. / float(2 ** 31 - 1), - SOUND_CARD_TRIGGER_INSERTION='prepend', + SOUND_CARD_TRIGGER_SCALE=1.0 / float(2**31 - 1), + SOUND_CARD_TRIGGER_INSERTION="prepend", SOUND_CARD_TRIGGER_ID_AFTER_ONSET=False, - SOUND_CARD_DRIFT_TRIGGER='end', + SOUND_CARD_DRIFT_TRIGGER="end", ) # any omitted become None - params = _check_params(params, _SOUND_CARD_KEYS, defaults, 'params') - if params['SOUND_CARD_FS'] is not None: - params['SOUND_CARD_FS'] = float(params['SOUND_CARD_FS']) - - self.backend, self.backend_name = _import_backend( - params['SOUND_CARD_BACKEND']) - self._n_channels_stim = int(params['SOUND_CARD_TRIGGER_CHANNELS']) - trig_scale = float(params['SOUND_CARD_TRIGGER_SCALE']) + params = _check_params(params, _SOUND_CARD_KEYS, defaults, "params") + if params["SOUND_CARD_FS"] is not None: + params["SOUND_CARD_FS"] = float(params["SOUND_CARD_FS"]) + + self.backend, self.backend_name = _import_backend(params["SOUND_CARD_BACKEND"]) + self._n_channels_stim = int(params["SOUND_CARD_TRIGGER_CHANNELS"]) + trig_scale = float(params["SOUND_CARD_TRIGGER_SCALE"]) self._id_after_onset = ( - str(params['SOUND_CARD_TRIGGER_ID_AFTER_ONSET']).lower() == 'true') + str(params["SOUND_CARD_TRIGGER_ID_AFTER_ONSET"]).lower() == "true" + ) self._extra_onset_triggers = list() - drift_trigger = params['SOUND_CARD_DRIFT_TRIGGER'] + drift_trigger = params["SOUND_CARD_DRIFT_TRIGGER"] if np.isscalar(drift_trigger): drift_trigger = [drift_trigger] # convert possible command-line option - if isinstance(drift_trigger, str) and drift_trigger != 'end': + if isinstance(drift_trigger, str) and drift_trigger != "end": drift_trigger = eval(drift_trigger) if isinstance(drift_trigger, str): drift_trigger = [drift_trigger] @@ -124,55 +134,59 @@ def __init__(self, params, stim_fs, n_channels=2, trigger_duration=0.01, drift_trigger = list(drift_trigger) # make mutable for trig in drift_trigger: if isinstance(trig, str): - assert trig == 'end', trig + assert trig == "end", trig else: assert isinstance(trig, (int, float)), type(trig) self._drift_trigger_time = drift_trigger assert self._n_channels_stim >= 0 self._n_channels = int(operator.index(n_channels)) del n_channels - insertion = str(params['SOUND_CARD_TRIGGER_INSERTION']) - if insertion not in ('prepend', 'append'): - raise ValueError('SOUND_CARD_TRIGGER_INSERTION must be "prepend" ' - 'or "append", got %r' % (insertion,)) - self._stim_sl = slice(None, None, 1 if insertion == 'prepend' else -1) - extra = '' + insertion = str(params["SOUND_CARD_TRIGGER_INSERTION"]) + if insertion not in ("prepend", "append"): + raise ValueError( + 'SOUND_CARD_TRIGGER_INSERTION must be "prepend" ' + 'or "append", got %r' % (insertion,) + ) + self._stim_sl = slice(None, None, 1 if insertion == "prepend" else -1) + extra = "" if self._n_channels_stim: - extra = ('%d %sed stim and ' - % (self._n_channels_stim, insertion)) + extra = "%d %sed stim and " % (self._n_channels_stim, insertion) else: - extra = '' + extra = "" del insertion - logger.info('Expyfun: Setting up sound card using %s backend with %s' - '%d playback channels' - % (self.backend_name, extra, self._n_channels)) - self._kwargs = {key: params['SOUND_CARD_' + key.upper()] for key in ( - 'fs', 'api', 'name', 'fixed_delay', 'api_options')} + logger.info( + "Expyfun: Setting up sound card using %s backend with %s" + "%d playback channels" % (self.backend_name, extra, self._n_channels) + ) + self._kwargs = { + key: params["SOUND_CARD_" + key.upper()] + for key in ("fs", "api", "name", "fixed_delay", "api_options") + } temp_sound = np.zeros((self._n_channels_tot, 1000)) temp_sound = self.backend.SoundPlayer(temp_sound, **self._kwargs) - self.fs = temp_sound.fs - temp_sound.stop(wait=False) + self.fs = float(temp_sound.fs) + self._mixer = getattr(temp_sound, "_mixer", None) del temp_sound # Need to generate at RMS=1 to match TDT circuit, and use a power of # 2 length for the RingBuffer (here make it >= 15 sec) - n_samples = 2 ** int(np.ceil(np.log2(self.fs * 15.))) + n_samples = 2 ** int(np.ceil(np.log2(self.fs * 15.0))) noise = np.random.normal(0, 1.0, (self._n_channels, n_samples)) # Low-pass if necessary if stim_fs < self.fs: # note we can use cheap DFT method here b/c # circular convolution won't matter for AWGN (yay!) - freqs = rfftfreq(noise.shape[-1], 1. / self.fs) + freqs = rfftfreq(noise.shape[-1], 1.0 / self.fs) noise = rfft(noise, axis=-1) - noise[:, np.abs(freqs) > stim_fs / 2.] = 0.0 + noise[:, np.abs(freqs) > stim_fs / 2.0] = 0.0 noise = irfft(noise, axis=-1) # ensure true RMS of 1.0 (DFT method also lowers RMS, compensate here) noise /= np.sqrt(np.mean(noise * noise)) noise = np.concatenate( - (np.zeros((self._n_channels_stim, noise.shape[1]), noise.dtype), - noise)) + (np.zeros((self._n_channels_stim, noise.shape[1]), noise.dtype), noise) + ) self.noise_array = noise self.noise_level = 0.01 self.noise = None @@ -183,8 +197,10 @@ def __init__(self, params, stim_fs, n_channels=2, trigger_duration=0.01, flush_logger() def __repr__(self): - return ('' - % (self._n_channels, self._n_channels_stim)) + return "" % ( + self._n_channels, + self._n_channels_stim, + ) @property def _n_channels_tot(self): @@ -194,7 +210,8 @@ def start_noise(self): """Start noise.""" if not self._noise_playing: self.noise = self.backend.SoundPlayer( - self.noise_array * self.noise_level, loop=True, **self._kwargs) + self.noise_array * self.noise_level, loop=True, **self._kwargs + ) self.noise.play() def stop_noise(self, wait=False): @@ -235,46 +252,49 @@ def load_buffer(self, samples): sample_len = len(samples) extra = sample_len - stim_len if extra > 0: # stim shorter than samples (typical) - stim = np.pad(stim, ((0, extra), (0, 0)), 'constant') + stim = np.pad(stim, ((0, extra), (0, 0)), "constant") elif extra < 0: # samples shorter than stim (very brief stim) - samples = np.pad(samples, ((0, -extra), (0, 0)), 'constant') + samples = np.pad(samples, ((0, -extra), (0, 0)), "constant") # place the drift triggers trig2 = self._make_digital_trigger([2]) trig2_len = trig2.shape[0] trig2_starts = [] for trig2_time in self._drift_trigger_time: - if trig2_time == 'end': + if trig2_time == "end": stim[-trig2_len:] = np.bitwise_or(stim[-trig2_len:], trig2) - trig2_starts += [sample_len-trig2_len] + trig2_starts += [sample_len - trig2_len] else: trig2_start = int(np.round(trig2_time * self.fs)) - if ((trig2_start >= 0 and trig2_start <= stim_len) or - (trig2_start < 0 and abs(trig2_start) >= extra)): - warnings.warn('Drift triggers overlap' - ' with onset triggers.') - if ((trig2_start > 0 and - trig2_start > sample_len-trig2_len) or - (trig2_start < 0 and - abs(trig2_start) >= sample_len)): - warnings.warn('Drift trigger at {} seconds occurs' - ' outside stimulus window, ' - 'not stamping ' - 'trigger.'.format(trig2_time)) + if (trig2_start >= 0 and trig2_start <= stim_len) or ( + trig2_start < 0 and abs(trig2_start) >= extra + ): + warnings.warn("Drift triggers overlap" " with onset triggers.") + if (trig2_start > 0 and trig2_start > sample_len - trig2_len) or ( + trig2_start < 0 and abs(trig2_start) >= sample_len + ): + warnings.warn( + f"Drift trigger at {trig2_time} seconds occurs" + " outside stimulus window, " + "not stamping " + "trigger." + ) continue - stim[trig2_start:trig2_start+trig2_len] = \ - np.bitwise_or(stim[trig2_start:trig2_start+trig2_len], - trig2) + stim[trig2_start : trig2_start + trig2_len] = np.bitwise_or( + stim[trig2_start : trig2_start + trig2_len], trig2 + ) if trig2_start > 0: trig2_starts += [trig2_start] else: trig2_starts += [sample_len + trig2_start] if np.any(np.diff(trig2_starts) < trig2_len): - warnings.warn('Some 2-triggers overlap, times should be at ' - 'least {} seconds apart.'.format(trig2_len / - self.fs)) - self.ec.write_data_line('Drift triggers were stamped at the ' - 'folowing times: ', - str([t2s/self.fs for t2s in trig2_starts])) + warnings.warn( + "Some 2-triggers overlap, times should be at " + f"least {trig2_len / self.fs} seconds apart." + ) + self.ec.write_data_line( + "Drift triggers were stamped at the following times: ", + str([t2s / self.fs for t2s in trig2_starts]), + ) stim = self._scale_digital_trigger(stim) samples = np.concatenate((stim, samples)[self._stim_sl], axis=1) self.audio = self.backend.SoundPlayer(samples.T, **self._kwargs) @@ -317,16 +337,16 @@ def _make_digital_trigger(self, trigs, delay=None): stim = np.zeros((n_samples, self._n_channels_stim), np.int32) offset = 0 for trig in trigs: - stim[offset:offset + n_on] = trig + stim[offset : offset + n_on] = trig offset += n_each return stim def _scale_digital_trigger(self, triggers): - return ((triggers << 8) * - self._trig_scale).astype(np.float32) + return ((triggers << 8) * self._trig_scale).astype(np.float32) - def stamp_triggers(self, triggers, delay=None, wait_for_last=True, - is_trial_id=False): + def stamp_triggers( + self, triggers, delay=None, wait_for_last=True, is_trial_id=False + ): """Stamp a list of triggers with a given inter-trigger delay. Parameters @@ -350,8 +370,7 @@ def stamp_triggers(self, triggers, delay=None, wait_for_last=True, delay = 2 * self._trigger_duration stim = self._make_digital_trigger(triggers, delay) stim = self._scale_digital_trigger(stim) - stim = np.pad( - stim, ((0, 0), (0, self._n_channels)[self._stim_sl]), 'constant') + stim = np.pad(stim, ((0, 0), (0, self._n_channels)[self._stim_sl]), "constant") stim = self.backend.SoundPlayer(stim.T, **self._kwargs) stim.play() t_each = self._trigger_duration + delay @@ -404,11 +423,13 @@ def halt(self): """Halt.""" self.stop(wait=True) self.stop_noise(wait=True) + abort_all = getattr(self.backend, "_abort_all_queues", lambda: None) + abort_all() def _import_backend(backend): # Auto mode is special, will loop through all possible backends - if backend == 'auto': + if backend == "auto": backends = list() for backend in _BACKENDS: try: @@ -417,21 +438,21 @@ def _import_backend(backend): pass backends = sorted([backend._PRIORITY, backend] for backend in backends) if len(backends) == 0: - raise RuntimeError('Could not load any sound backend: %s' - % (_BACKENDS,)) + raise RuntimeError("Could not load any sound backend: %s" % (_BACKENDS,)) backend = op.splitext(op.basename(backends[0][1].__file__))[0][1:] if backend not in _BACKENDS: - raise ValueError('Unknown sound card backend %r, must be one of %s' - % (backend, ('auto',) + _BACKENDS)) - lib = importlib.import_module('._' + backend, - package='expyfun._sound_controllers') + raise ValueError( + "Unknown sound card backend %r, must be one of %s" + % (backend, ("auto",) + _BACKENDS) + ) + lib = importlib.import_module("._" + backend, package="expyfun._sound_controllers") return lib, backend -class SoundPlayer(object): +class SoundPlayer: """Play sounds via the sound card.""" def __new__(self, data, **kwargs): """Create a new instance.""" - backend = kwargs.pop('backend', 'auto') + backend = kwargs.pop("backend", "auto") return _import_backend(backend)[0].SoundPlayer(data, **kwargs) diff --git a/expyfun/_tdt_controller.py b/expyfun/_tdt_controller.py index 89ac7276..859e3479 100644 --- a/expyfun/_tdt_controller.py +++ b/expyfun/_tdt_controller.py @@ -7,31 +7,41 @@ # License: BSD (3-clause) import time -import numpy as np -from os import path as op -from functools import partial import warnings +from functools import partial +from os import path as op + +import numpy as np -from ._utils import _check_params, logger, ZeroClock from ._input_controllers import Keyboard +from ._utils import ZeroClock, _check_params, logger def _dummy_fun(self, name, ret, *args, **kwargs): - logger.info('dummy-tdt: {0} {1}'.format(name, str(args)[:20] + ' ... ' + - str(kwargs)[:20] + ' ...')) + logger.info( + "dummy-tdt: {0} {1}".format( + name, str(args)[:20] + " ... " + str(kwargs)[:20] + " ..." + ) + ) return ret -class DummyRPcoX(object): +class DummyRPcoX: """Dummy RPcoX.""" def __init__(self, model, interface): self.model = model self.interface = interface - names = ['LoadCOF', 'ClearCOF', 'Run', 'ZeroTag', 'SetTagVal', - 'GetSFreq', 'Halt'] - returns = [True, True, True, True, True, - 24414.0125, True] + names = [ + "LoadCOF", + "ClearCOF", + "Run", + "ZeroTag", + "SetTagVal", + "GetSFreq", + "Halt", + ] + returns = [True, True, True, True, True, 24414.0125, True] for name, ret in zip(names, returns): setattr(self, name, partial(_dummy_fun, self, name, ret)) self._clock = ZeroClock() @@ -40,7 +50,7 @@ def __init__(self, model, interface): def WriteTagVEX(self, name, offset, kind, data): """Write tag data.""" - if name == 'datainleft': + if name == "datainleft": self._stim_dur = len(data) / self.GetSFreq() return True @@ -54,14 +64,14 @@ def SoftTrg(self, trignum): def GetTagVal(self, name): """Get a tag value.""" - if name == 'masterclock': + if name == "masterclock": return self._clock.get_time() - elif name == 'npressabs': + elif name == "npressabs": return 0 - elif name == 'playing': - return (time.time() - self._play_start < self._stim_dur) + elif name == "playing": + return time.time() - self._play_start < self._stim_dur else: - raise ValueError('unknown tag "{0}"'.format(name)) + raise ValueError(f'unknown tag "{name}"') class TDTController(Keyboard): # lgtm [py/missing-call-to-init] @@ -100,36 +110,49 @@ class TDTController(Keyboard): # lgtm [py/missing-call-to-init] def __init__(self, tdt_params, ec): self.ec = ec - defaults = dict(TDT_MODEL='dummy', TDT_DELAY='0', TDT_TRIG_DELAY='0', - TYPE='tdt') # if not listed -> None - keys = ['TYPE', 'TDT_MODEL', 'TDT_CIRCUIT_PATH', 'TDT_INTERFACE', - 'TDT_DELAY', 'TDT_TRIG_DELAY'] - tdt_params = _check_params(tdt_params, keys, defaults, 'tdt_params') - if tdt_params['TYPE'] != 'tdt': - raise ValueError('tdt_params["TYPE"] must be "tdt", not ' - '{0}'.format(tdt_params['TYPE'])) - for key in ('TDT_DELAY', 'TDT_TRIG_DELAY'): + defaults = dict( + TDT_MODEL="dummy", TDT_DELAY="0", TDT_TRIG_DELAY="0", TYPE="tdt" + ) # if not listed -> None + keys = [ + "TYPE", + "TDT_MODEL", + "TDT_CIRCUIT_PATH", + "TDT_INTERFACE", + "TDT_DELAY", + "TDT_TRIG_DELAY", + ] + tdt_params = _check_params(tdt_params, keys, defaults, "tdt_params") + if tdt_params["TYPE"] != "tdt": + raise ValueError( + 'tdt_params["TYPE"] must be "tdt", not ' "{0}".format( + tdt_params["TYPE"] + ) + ) + for key in ("TDT_DELAY", "TDT_TRIG_DELAY"): tdt_params[key] = int(tdt_params[key]) - if tdt_params['TDT_DELAY'] < 0: - raise ValueError('tdt_delay must be non-negative.') - self._model = tdt_params['TDT_MODEL'] - legal_models = ['RM1', 'RP2', 'RZ6', 'RP2legacy', 'dummy'] + if tdt_params["TDT_DELAY"] < 0: + raise ValueError("tdt_delay must be non-negative.") + self._model = tdt_params["TDT_MODEL"] + legal_models = ["RM1", "RP2", "RZ6", "RP2legacy", "dummy"] if self.model not in legal_models: - raise ValueError('TDT_MODEL="{0}" must be one of ' - '{1}'.format(self.model, legal_models)) - - if tdt_params['TDT_CIRCUIT_PATH'] is None and self.model != 'dummy': - cl = dict(RM1='RM1', RP2='RM1', RP2legacy='RP2legacy', RZ6='RZ6') - self._circuit = op.join(op.dirname(__file__), 'data', - 'expCircuitF32_' + cl[self._model] + - '.rcx') + raise ValueError( + f'TDT_MODEL="{self.model}" must be one of ' f"{legal_models}" + ) + + if tdt_params["TDT_CIRCUIT_PATH"] is None and self.model != "dummy": + cl = dict(RM1="RM1", RP2="RM1", RP2legacy="RP2legacy", RZ6="RZ6") + self._circuit = op.join( + op.dirname(__file__), + "data", + "expCircuitF32_" + cl[self._model] + ".rcx", + ) else: - self._circuit = tdt_params['TDT_CIRCUIT_PATH'] - if self.model != 'dummy' and not op.isfile(self._circuit): - raise IOError('Could not find file {}'.format(self._circuit)) - if tdt_params['TDT_INTERFACE'] is None: - tdt_params['TDT_INTERFACE'] = 'USB' - self._interface = tdt_params['TDT_INTERFACE'] + self._circuit = tdt_params["TDT_CIRCUIT_PATH"] + if self.model != "dummy" and not op.isfile(self._circuit): + raise OSError(f"Could not find file {self._circuit}") + if tdt_params["TDT_INTERFACE"] is None: + tdt_params["TDT_INTERFACE"] = "USB" + self._interface = tdt_params["TDT_INTERFACE"] self._n_channels = 2 # initialize RPcoX connection @@ -143,28 +166,33 @@ def __init__(self, tdt_params, ec): self.connection = self.rpcox.ConnectRM1(IntName=interface, DevNum=1) """ # MID-LEVEL APPROACH - if tdt_params['TDT_MODEL'] != 'dummy': + if tdt_params["TDT_MODEL"] != "dummy": from tdt.util import connect_rpcox - use_model = self.model if self.model != 'RP2legacy' else 'RP2' + + use_model = self.model if self.model != "RP2legacy" else "RP2" try: - self.rpcox = connect_rpcox(name=use_model, - interface=self.interface, - device_id=1, address=None) + self.rpcox = connect_rpcox( + name=use_model, interface=self.interface, device_id=1, address=None + ) except Exception as exp: - raise OSError('Could not connect to {}, is it turned on? ' - '(TDT message: "{}")'.format(self._model, exp)) + raise OSError( + f"Could not connect to {self._model}, is it turned on? " + f'(TDT message: "{exp}")' + ) else: - msg = ('TDT is in dummy mode. No sound or triggers will ' - 'be produced. Check TDT configuration and TDTpy ' - 'installation.') + msg = ( + "TDT is in dummy mode. No sound or triggers will " + "be produced. Check TDT configuration and TDTpy " + "installation." + ) logger.warning(msg) # log it warnings.warn(msg) # make it red self.rpcox = DummyRPcoX(self._model, self._interface) if self.rpcox is not None: - logger.info('Expyfun: RPcoX connection established') + logger.info("Expyfun: RPcoX connection established") else: - raise IOError('Problem initializing RPcoX.') + raise OSError("Problem initializing RPcoX.") """ # start zBUS (may be needed for devices other than RM1) self.zbus = connect_zbus(interface=interface) @@ -175,30 +203,31 @@ def __init__(self, tdt_params, ec): """ # load circuit if not self.rpcox.LoadCOF(self.circuit): - logger.debug('Expyfun: Problem loading circuit. Clearing...') + logger.debug("Expyfun: Problem loading circuit. Clearing...") try: if self.rpcox.ClearCOF(): - logger.debug('Expyfun: TDT circuit cleared') + logger.debug("Expyfun: TDT circuit cleared") time.sleep(0.25) if not self.rpcox.LoadCOF(self.circuit): - raise RuntimeError('Second loading attempt failed') + raise RuntimeError("Second loading attempt failed") except Exception: - raise IOError('Expyfun: Problem loading circuit.') - logger.info('Expyfun: Circuit loaded to {1} via {2}:\n{0}' - ''.format(self.circuit, self.model, self.interface)) + raise OSError("Expyfun: Problem loading circuit.") + logger.info( + f"Expyfun: Circuit loaded to {self.model} via {self.interface}:\n" + f"{self.circuit}" + ) # run circuit if self.rpcox.Run(): - logger.info('Expyfun: TDT circuit running') + logger.info("Expyfun: TDT circuit running") else: - raise SystemError('Expyfun: Problem starting TDT circuit.') + raise SystemError("Expyfun: Problem starting TDT circuit.") time.sleep(0.25) self._set_noise_corr() - self._set_delay(tdt_params['TDT_DELAY'], - tdt_params['TDT_TRIG_DELAY']) + self._set_delay(tdt_params["TDT_DELAY"], tdt_params["TDT_TRIG_DELAY"]) # Set output values to zero (esp. first few) - for tag in ('datainleft', 'datainright'): + for tag in ("datainleft", "datainright"): self.rpcox.ZeroTag(tag) - self.rpcox.SetTagVal('trgname', 0) + self.rpcox.SetTagVal("trgname", 0) self._used_params = tdt_params def _add_keyboard_init(self, ec, force_quit_keys): @@ -206,11 +235,11 @@ def _add_keyboard_init(self, ec, force_quit_keys): # do BaseKeyboard init last, to make sure circuit is running Keyboard.__init__(self, ec, force_quit_keys) -# ############################### AUDIO METHODS ############################### + # ############################### AUDIO METHODS ############################### def _set_noise_corr(self, val=0): """Helper to set the noise correlation, only -1, 0, 1 supported""" assert val in (-1, 0, 1) - self.rpcox.SetTagVal('noise_corr', int(val)) + self.rpcox.SetTagVal("noise_corr", int(val)) def load_buffer(self, data): """Load audio samples into TDT buffer. @@ -223,21 +252,20 @@ def load_buffer(self, data): """ assert data.dtype == np.float32 # Leave the first sample zero so on reset the output goes to zero - self.rpcox.WriteTagVEX('datainleft', 1, 'F32', data[:, 0]) - self.rpcox.WriteTagVEX('datainright', 1, 'F32', data[:, 1]) - self.rpcox.SetTagVal('nsamples', max(data.shape[0] + 1, 1)) + self.rpcox.WriteTagVEX("datainleft", 1, "F32", data[:, 0]) + self.rpcox.WriteTagVEX("datainright", 1, "F32", data[:, 1]) + self.rpcox.SetTagVal("nsamples", max(data.shape[0] + 1, 1)) def play(self): - """Send the soft trigger to start the ring buffer playback. - """ - self.rpcox.SetTagVal('trgname', 1) + """Send the soft trigger to start the ring buffer playback.""" + self.rpcox.SetTagVal("trgname", 1) self._trigger(1) - logger.debug('Expyfun: Starting TDT ring buffer') + logger.debug("Expyfun: Starting TDT ring buffer") @property def playing(self): """Is a sound currently playing""" - return bool(int(self.rpcox.GetTagVal('playing'))) + return bool(int(self.rpcox.GetTagVal("playing"))) def stop(self, wait=True): """Send the soft trigger to stop and reset the ring buffer playback. @@ -248,13 +276,12 @@ def stop(self, wait=True): Unused by the TDT. """ self._trigger(2) - logger.debug('Expyfun: Stopping TDT audio') + logger.debug("Expyfun: Stopping TDT audio") def start_noise(self): - """Send the soft trigger to start the noise generator. - """ + """Send the soft trigger to start the noise generator.""" self._trigger(3) - logger.debug('Expyfun: Starting TDT noise') + logger.debug("Expyfun: Starting TDT noise") def stop_noise(self, wait=True): """Send the soft trigger to stop the noise generator. @@ -265,7 +292,7 @@ def stop_noise(self, wait=True): Unused by the TDT. """ self._trigger(4) - logger.debug('Expyfun: Stopping TDT noise') + logger.debug("Expyfun: Stopping TDT noise") def set_noise_level(self, level): """Set the noise level. @@ -275,21 +302,21 @@ def set_noise_level(self, level): level : float The new level. """ - self.rpcox.SetTagVal('noiselev', level) + self.rpcox.SetTagVal("noiselev", level) def _set_delay(self, delay, delay_trig): - """Set the delay (in ms) of the system - """ + """Set the delay (in ms) of the system""" assert isinstance(delay, int) # this should never happen assert isinstance(delay_trig, int) - self.rpcox.SetTagVal('onsetdel', delay) - logger.info('Expyfun: Setting TDT delay to %s' % delay) - self.rpcox.SetTagVal('trigdel', delay_trig) - logger.info('Expyfun: Setting TDT trigger delay to %s' % delay_trig) - -# ############################### TRIGGER METHODS ############################# - def stamp_triggers(self, triggers, delay=None, wait_for_last=True, - is_trial_id=False): + self.rpcox.SetTagVal("onsetdel", delay) + logger.info("Expyfun: Setting TDT delay to %s" % delay) + self.rpcox.SetTagVal("trigdel", delay_trig) + logger.info("Expyfun: Setting TDT trigger delay to %s" % delay_trig) + + # ############################### TRIGGER METHODS ############################# + def stamp_triggers( + self, triggers, delay=None, wait_for_last=True, is_trial_id=False + ): """Stamp a list of triggers with a given inter-trigger delay. Parameters @@ -308,7 +335,7 @@ def stamp_triggers(self, triggers, delay=None, wait_for_last=True, if delay is None: delay = 0.02 # we have a fixed trig duration of 0.01 for ti, trig in enumerate(triggers): - self.rpcox.SetTagVal('trgname', trig) + self.rpcox.SetTagVal("trgname", trig) self._trigger(6) if ti < len(triggers) - 1 or wait_for_last: self.ec.wait_secs(delay) @@ -322,35 +349,34 @@ def _trigger(self, trig): Trigger number to send to TDT. """ if not self.rpcox.SoftTrg(trig): - logger.warning('SoftTrg failure for trigger: {}'.format(trig)) + logger.warning(f"SoftTrg failure for trigger: {trig}") -# ############################### KEYBOARD METHODS ############################ + # ############################### KEYBOARD METHODS ############################ def _get_timebase(self): - """Return time since circuit was started (in seconds). - """ - return self.rpcox.GetTagVal('masterclock') / float(self.fs) + """Return time since circuit was started (in seconds).""" + return self.rpcox.GetTagVal("masterclock") / float(self.fs) def _clear_events(self): - """Clear keyboard buffers. - """ + """Clear keyboard buffers.""" self._trigger(7) self._clear_keyboard_events() - def _retrieve_events(self, live_keys, type='presses'): - """Values and timestamps currently in keyboard buffer. - """ - if type != 'presses': + def _retrieve_events(self, live_keys, type="presses"): # noqa: A002 + """Values and timestamps currently in keyboard buffer.""" + if type != "presses": raise RuntimeError("TDT Cannot get key release events") # get values from the tdt - press_count = int(round(self.rpcox.GetTagVal('npressabs'))) + press_count = int(round(self.rpcox.GetTagVal("npressabs"))) if press_count > 0: # this one is indexed from zero - press_times = self.rpcox.ReadTagVEX('presstimesabs', 0, - press_count, 'I32', 'I32', 1) + press_times = self.rpcox.ReadTagVEX( + "presstimesabs", 0, press_count, "I32", "I32", 1 + ) # this one is indexed from one (silly) - press_vals = self.rpcox.ReadTagVEX('pressvalsabs', 1, press_count, - 'I32', 'I32', 1) + press_vals = self.rpcox.ReadTagVEX( + "pressvalsabs", 1, press_count, "I32", "I32", 1 + ) press_times = np.array(press_times[0], float) / self.fs press_vals = np.log2(np.array(press_vals[0], float)) + 1 press_vals = [str(int(round(p))) for p in press_vals] @@ -361,7 +387,7 @@ def _retrieve_events(self, live_keys, type='presses'): presses.extend(self._retrieve_keyboard_events([])) return presses - def _correct_presses(self, events, timestamp, relative_to, kind='presses'): + def _correct_presses(self, events, timestamp, relative_to, kind="presses"): """Correct timing of presses and check for quit press""" events = [(k, s + self.time_correction, kind) for k, s in events] self.log_presses(events) @@ -376,9 +402,9 @@ def _correct_presses(self, events, timestamp, relative_to, kind='presses'): def halt(self): """Wrapper for tdt.util.RPcoX.Halt().""" self.rpcox.Halt() - logger.debug('Expyfun: Halting TDT circuit') + logger.debug("Expyfun: Halting TDT circuit") -# ############################ READ-ONLY PROPERTIES ########################### + # ############################ READ-ONLY PROPERTIES ########################### @property def fs(self): """Playback frequency of the audio (samples / second).""" @@ -408,5 +434,11 @@ def get_tdt_rates(): rates : dict The sample rates. """ - return {'6k': 6103.515625, '12k': 12207.03125, '25k': 24414.0625, - '50k': 48828.125, '100k': 97656.25, '200k': 195312.5} + return { + "6k": 6103.515625, + "12k": 12207.03125, + "25k": 24414.0625, + "50k": 48828.125, + "100k": 97656.25, + "200k": 195312.5, + } diff --git a/expyfun/_trigger_controllers.py b/expyfun/_trigger_controllers.py index b735e846..24d43c98 100644 --- a/expyfun/_trigger_controllers.py +++ b/expyfun/_trigger_controllers.py @@ -6,12 +6,13 @@ # License: BSD (3-clause) import sys + import numpy as np -from ._utils import verbose_dec, string_types, logger +from ._utils import logger, verbose_dec -class ParallelTrigger(object): +class ParallelTrigger: """Parallel port and dummy triggering support. .. warning:: When using the parallel port, calling @@ -46,34 +47,42 @@ class ParallelTrigger(object): """ @verbose_dec - def __init__(self, mode='dummy', address=None, trigger_duration=0.01, - ec=None, verbose=None): + def __init__( + self, mode="dummy", address=None, trigger_duration=0.01, ec=None, verbose=None + ): self.ec = ec - if mode == 'parallel': - if sys.platform.startswith('linux'): - address = '/dev/parport0' if address is None else address - if not isinstance(address, string_types): - raise ValueError('addrss must be a string or None, got %s ' - 'of type %s' % (address, type(address))) + if mode == "parallel": + if sys.platform.startswith("linux"): + address = "/dev/parport0" if address is None else address + if not isinstance(address, str): + raise ValueError( + "address must be a string or None, got %s " + "of type %s" % (address, type(address)) + ) from parallel import Parallel - logger.info('Expyfun: Using address %s' % (address,)) + + logger.info("Expyfun: Using address %s" % (address,)) self._port = Parallel(address) self._portname = address self._set_data = self._port.setData - elif sys.platform.startswith('win'): + elif sys.platform.startswith("win"): from ctypes import windll - if not hasattr(windll, 'inpout32'): + + if not hasattr(windll, "inpout32"): raise SystemError( - 'Must have inpout32 installed, see:\n\n' - 'http://www.highrez.co.uk/downloads/inpout32/') + "Must have inpout32 installed, see:\n\n" + "http://www.highrez.co.uk/downloads/inpout32/" + ) - base = '0x378' if address is None else address - logger.info('Expyfun: Using base address %s' % (base,)) - if isinstance(base, string_types): + base = "0x378" if address is None else address + logger.info("Expyfun: Using base address %s" % (base,)) + if isinstance(base, str): base = int(base, 16) if not isinstance(base, int): - raise ValueError('address must be int or None, got %s of ' - 'type %s' % (base, type(base))) + raise ValueError( + "address must be int or None, got %s of " + "type %s" % (base, type(base)) + ) self._port = windll.inpout32 mask = np.uint8(1 << 5 | 1 << 6 | 1 << 7) # Use ECP to put the port into byte mode @@ -87,18 +96,20 @@ def __init__(self, mode='dummy', address=None, trigger_duration=0.01, self._set_data = lambda data: self._port.Out32(base, data) self._portname = str(base) else: - raise NotImplementedError('Parallel port triggering only ' - 'supported on Linux and Windows') + raise NotImplementedError( + "Parallel port triggering only " "supported on Linux and Windows" + ) else: # mode == 'dummy': self._port = self._portname = None self._trigger_list = list() - self._set_data = lambda x: (self._trigger_list.append(x) - if x != 0 else None) + self._set_data = lambda x: ( + self._trigger_list.append(x) if x != 0 else None + ) self.trigger_duration = trigger_duration self.mode = mode def __repr__(self): - return '' % (self.mode, self._portname) + return "" % (self.mode, self._portname) def _stamp_trigger(self, trig): """Fake stamping.""" @@ -106,8 +117,9 @@ def _stamp_trigger(self, trig): self.ec.wait_secs(self.trigger_duration) self._set_data(0) - def stamp_triggers(self, triggers, delay=None, wait_for_last=True, - is_trial_id=False): + def stamp_triggers( + self, triggers, delay=None, wait_for_last=True, is_trial_id=False + ): """Stamp a list of triggers with a given inter-trigger delay. Parameters @@ -132,7 +144,7 @@ def stamp_triggers(self, triggers, delay=None, wait_for_last=True, def close(self): """Release hardware interfaces.""" - if hasattr(self, '_port'): + if hasattr(self, "_port"): del self._port def __del__(self): @@ -160,17 +172,16 @@ def decimals_to_binary(decimals, n_bits): """ decimals = np.array(decimals, int) if decimals.ndim != 1 or (decimals < 0).any(): - raise ValueError('decimals must be 1D with all nonnegative values') + raise ValueError("decimals must be 1D with all nonnegative values") n_bits = np.array(n_bits, int) if decimals.shape != n_bits.shape: - raise ValueError('n_bits must have same shape as decimals') + raise ValueError("n_bits must have same shape as decimals") if (n_bits <= 0).any(): - raise ValueError('all n_bits must be positive') + raise ValueError("all n_bits must be positive") binary = list() for d, b in zip(decimals, n_bits): - if d > 2 ** b - 1: - raise ValueError('cannot convert number {0} using {1} bits' - ''.format(d, b)) + if d > 2**b - 1: + raise ValueError(f"cannot convert number {d} using {b} bits" "") binary.extend([int(bb) for bb in np.binary_repr(d, b)]) assert len(binary) == n_bits.sum() # make sure we didn't do something dumb return binary @@ -192,21 +203,23 @@ def binary_to_decimals(binary, n_bits): Array of integers. """ if not np.array_equal(binary, np.array(binary, bool)): - raise ValueError('binary must only contain zeros and ones') + raise ValueError("binary must only contain zeros and ones") binary = np.array(binary, bool) if binary.ndim != 1: - raise ValueError('binary must be 1 dimensional') + raise ValueError("binary must be 1 dimensional") n_bits = np.atleast_1d(n_bits).astype(int) if np.any(n_bits <= 0): - raise ValueError('n_bits must all be > 0') + raise ValueError("n_bits must all be > 0") if n_bits.sum() != len(binary): - raise ValueError('the sum of n_bits must be equal to the number of ' - 'elements in binary') + raise ValueError( + "the sum of n_bits must be equal to the number of " "elements in binary" + ) offset = 0 outs = [] for nb in n_bits: - outs.append(np.sum(binary[offset:offset + nb] * - (2 ** np.arange(nb - 1, -1, -1)))) + outs.append( + np.sum(binary[offset : offset + nb] * (2 ** np.arange(nb - 1, -1, -1))) + ) offset += nb assert offset == len(binary) return np.array(outs) diff --git a/expyfun/_utils.py b/expyfun/_utils.py index 412191d8..d40908a7 100644 --- a/expyfun/_utils.py +++ b/expyfun/_utils.py @@ -4,63 +4,48 @@ # # License: BSD (3-clause) -import warnings -import operator -from copy import deepcopy -import subprocess +import atexit +import datetime import importlib +import inspect +import json +import logging +import operator import os import os.path as op -import inspect +import ssl +import subprocess import sys -import time import tempfile +import time import traceback -import ssl -from shutil import rmtree -import atexit -import json +import warnings +from copy import deepcopy from functools import partial -import logging -import datetime -from timeit import default_timer as clock +from shutil import rmtree from threading import Timer +from timeit import default_timer as clock +from urllib.request import urlopen import numpy as np import scipy as sp - -from ._externals import decorator +from decorator import decorator # set this first thing to make sure it "takes" try: import pyglet - pyglet.options['debug_gl'] = False + + pyglet.options["debug_gl"] = False del pyglet except Exception: pass -# for py3k (eventually) -if sys.version.startswith('2'): - string_types = basestring # noqa - input = raw_input # noqa, input is raw_input in py3k - text_type = unicode # noqa - from __builtin__ import reload - from urllib2 import urlopen # noqa - from cStringIO import StringIO # noqa -else: - string_types = str - text_type = str - from urllib.request import urlopen - input = input - from io import StringIO # noqa, analysis:ignore - from importlib import reload # noqa, analysis:ignore - ############################################################################### # LOGGING EXP = 25 -logging.addLevelName(EXP, 'EXP') +logging.addLevelName(EXP, "EXP") def exp(self, message, *args, **kwargs): @@ -69,7 +54,7 @@ def exp(self, message, *args, **kwargs): logging.Logger.exp = exp -logger = logging.getLogger('expyfun') +logger = logging.getLogger("expyfun") def flush_logger(): @@ -94,26 +79,32 @@ def set_log_level(verbose=None, return_old_level=False): If True, return the old verbosity level. """ if verbose is None: - verbose = get_config('EXPYFUN_LOGGING_LEVEL', 'INFO') + verbose = get_config("EXPYFUN_LOGGING_LEVEL", "INFO") elif isinstance(verbose, bool): - verbose = 'INFO' if verbose is True else 'WARNING' - if isinstance(verbose, string_types): + verbose = "INFO" if verbose is True else "WARNING" + if isinstance(verbose, str): verbose = verbose.upper() - logging_types = dict(DEBUG=logging.DEBUG, INFO=logging.INFO, - WARNING=logging.WARNING, ERROR=logging.ERROR, - CRITICAL=logging.CRITICAL) + logging_types = dict( + DEBUG=logging.DEBUG, + INFO=logging.INFO, + WARNING=logging.WARNING, + ERROR=logging.ERROR, + CRITICAL=logging.CRITICAL, + ) if verbose not in logging_types: - raise ValueError('verbose must be of a valid type') + raise ValueError("verbose must be of a valid type") verbose = logging_types[verbose] old_verbose = logger.level logger.setLevel(verbose) - return (old_verbose if return_old_level else None) + return old_verbose if return_old_level else None -def set_log_file(fname=None, - output_format='%(asctime)s - %(levelname)-7s - %(message)s', - overwrite=None): +def set_log_file( + fname=None, + output_format="%(asctime)s - %(levelname)-7s - %(message)s", + overwrite=None, +): """Convenience function for setting the log to print to a file Parameters @@ -138,10 +129,12 @@ def set_log_file(fname=None, logger.removeHandler(h) if fname is not None: if op.isfile(fname) and overwrite is None: - warnings.warn('Log entries will be appended to the file. Use ' - 'overwrite=False to avoid this message in the ' - 'future.') - mode = 'w' if overwrite is True else 'a' + warnings.warn( + "Log entries will be appended to the file. Use " + "overwrite=False to avoid this message in the " + "future." + ) + mode = "w" if overwrite is True else "a" lh = logging.FileHandler(fname, mode=mode) else: """ we should just be able to do: @@ -158,9 +151,10 @@ def set_log_file(fname=None, ############################################################################### # RANDOM UTILITIES -building_doc = any('sphinx-build' in ((''.join(i[4]).lower() + i[1]) - if i[4] is not None else '') - for i in inspect.stack()) +building_doc = any( + "sphinx-build" in (("".join(i[4]).lower() + i[1]) if i[4] is not None else "") + for i in inspect.stack() +) def run_subprocess(command, **kwargs): @@ -176,7 +170,7 @@ def run_subprocess(command, **kwargs): command : list of str Command to run as subprocess (see subprocess.Popen documentation). **kwargs : objects - Keywoard arguments to pass to ``subprocess.Popen``. + Keyword arguments to pass to ``subprocess.Popen``. Returns ------- @@ -192,10 +186,13 @@ def run_subprocess(command, **kwargs): p = subprocess.Popen(command, **kw) stdout_, stderr = p.communicate() - output = (stdout_.decode(), stderr.decode()) + output = ( + stdout_.decode() if stdout_ else "", + stderr.decode() if stderr else "", + ) if p.returncode: err_fun = subprocess.CalledProcessError.__init__ - if 'output' in _get_args(err_fun): + if "output" in _get_args(err_fun): raise subprocess.CalledProcessError(p.returncode, command, output) else: raise subprocess.CalledProcessError(p.returncode, command) @@ -203,7 +200,7 @@ def run_subprocess(command, **kwargs): return output -class ZeroClock(object): +class ZeroClock: """Clock that uses "clock" function but starts at zero on init.""" def __init__(self): @@ -222,10 +219,10 @@ def date_str(): datestr : str The date string. """ - return str(datetime.datetime.today()).replace(':', '_') + return str(datetime.datetime.today()).replace(":", "_") -class WrapStdOut(object): +class WrapStdOut: """Ridiculous class to work around how doctest captures stdout.""" def __getattr__(self, name): @@ -258,7 +255,7 @@ def __init__(self): def cleanup(self): if self._del_after is True: if self._print_del is True: - print('Deleting {} ...'.format(self._path)) + print(f"Deleting {self._path} ...") rmtree(self._path, ignore_errors=True) @@ -270,10 +267,9 @@ def check_units(units): units : str Must be ``'norm'``, ``'deg'``, ``'pix'``, or ``'cm'``. """ - good_units = ['norm', 'pix', 'deg', 'cm'] + good_units = ["norm", "pix", "deg", "cm"] if units not in good_units: - raise ValueError('"units" must be one of {}, not {}' - ''.format(good_units, units)) + raise ValueError(f'"units" must be one of {good_units}, not {units}' "") ############################################################################### @@ -281,7 +277,8 @@ def check_units(units): # Following deprecated class copied from scikit-learn -class deprecated(object): + +class deprecated: """Decorator to mark a function or class as deprecated. Issue a warning when the function is called/the class is instantiated and @@ -305,14 +302,8 @@ class deprecated(object): # scikit-learn will not import on all platforms b/c it can be # sklearn or scikits.learn, so a self-contained example is used above - def __init__(self, extra=''): - """ - Parameters - ---------- - extra: string - to be added to the deprecation messages - - """ + def __init__(self, extra=""): + # extra: string to be added to the deprecation messages self.extra = extra def __call__(self, obj): @@ -333,9 +324,10 @@ def _decorate_class(self, cls): def wrapped(*args, **kwargs): warnings.warn(msg, category=DeprecationWarning) return init(*args, **kwargs) + cls.__init__ = wrapped - wrapped.__name__ = '__init__' + wrapped.__name__ = "__init__" wrapped.__doc__ = self._update_doc(init.__doc__) wrapped.deprecated_original = init @@ -366,20 +358,28 @@ def _update_doc(self, olddoc): return newdoc -if hasattr(inspect, 'signature'): # py35 +if hasattr(inspect, "signature"): # py35 + def _get_args(function, varargs=False): params = inspect.signature(function).parameters - args = [key for key, param in params.items() - if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)] + args = [ + key + for key, param in params.items() + if param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + ] if varargs: - varargs = [param.name for param in params.values() - if param.kind == param.VAR_POSITIONAL] + varargs = [ + param.name + for param in params.values() + if param.kind == param.VAR_POSITIONAL + ] if len(varargs) == 0: varargs = None return args, varargs else: return args else: + def _get_args(function, varargs=False): out = inspect.getargspec(function) # args, varargs, keywords, defaults if varargs: @@ -407,13 +407,13 @@ def verbose_dec(function, *args, **kwargs): """ arg_names = _get_args(function) - if len(arg_names) > 0 and arg_names[0] == 'self': - default_level = getattr(args[0], 'verbose', None) + if len(arg_names) > 0 and arg_names[0] == "self": + default_level = getattr(args[0], "verbose", None) else: default_level = None - if('verbose' in arg_names): - verbose_level = args[arg_names.index('verbose')] + if "verbose" in arg_names: + verbose_level = args[arg_names.index("verbose")] else: verbose_level = default_level @@ -434,7 +434,8 @@ def verbose_dec(function, *args, **kwargs): def _new_pyglet(): import pyglet - return _compare_version(pyglet.version, '>=', '1.4') + + return _compare_version(pyglet.version, ">=", "1.4") def _has_video(raise_error=False): @@ -448,7 +449,7 @@ def _has_video(raise_error=False): good = False else: if raise_error: - print('Found FFmpegSource for new Pyglet') + print("Found FFmpegSource for new Pyglet") else: try: from pyglet.media.avbin import AVbinSource # noqa @@ -461,60 +462,72 @@ def _has_video(raise_error=False): good = False else: if raise_error: - print('Found AVbinSource for old Pyglet 1') + print("Found AVbinSource for old Pyglet 1") else: if raise_error: - print('Found AVbinSource for old Pyglet 2') + print("Found AVbinSource for old Pyglet 2") if raise_error and not good: - raise RuntimeError('Video support not enabled, got exception(s):\n' - '\n***********************\n'.join(exceptions)) + raise RuntimeError( + "Video support not enabled, got exception(s):\n" + "\n***********************\n".join(exceptions) + ) return good def requires_video(): """Require FFmpeg/AVbin.""" import pytest - return pytest.mark.skipif(not _has_video(), reason='Requires FFmpeg/AVbin') + + return pytest.mark.skipif(not _has_video(), reason="Requires FFmpeg/AVbin") def requires_opengl21(func): """Require OpenGL.""" - import pytest import pyglet.gl + import pytest + vendor = pyglet.gl.gl_info.get_vendor() version = pyglet.gl.gl_info.get_version() sufficient = pyglet.gl.gl_info.have_version(2, 0) - return pytest.mark.skipif(not sufficient, - reason='OpenGL too old: %s %s' - % (vendor, version,))(func) + return pytest.mark.skipif( + not sufficient, + reason="OpenGL too old: %s %s" + % ( + vendor, + version, + ), + )(func) def requires_lib(lib): """Requires lib decorator.""" import pytest + try: importlib.import_module(lib) except Exception as exp: val = True - reason = 'Needs %s (%s)' % (lib, exp) + reason = "Needs %s (%s)" % (lib, exp) else: val = False - reason = '' + reason = "" return pytest.mark.skipif(val, reason=reason) def _has_scipy_version(version): - return _compare_version(sp.__version__, '>=', version) + return _compare_version(sp.__version__, ">=", version) def _get_user_home_path(): """Return standard preferences path""" # this has been checked on OSX64, Linux64, and Win32 - val = os.getenv('APPDATA' if 'nt' == os.name.lower() else 'HOME', None) + val = os.getenv("APPDATA" if "nt" == os.name.lower() else "HOME", None) if val is None: - raise ValueError('expyfun config file path could ' - 'not be determined, please report this ' - 'error to expyfun developers') + raise ValueError( + "expyfun config file path could " + "not be determined, please report this " + "error to expyfun developers" + ) return val @@ -532,13 +545,13 @@ def fetch_data_file(fname): fname : str The filename on the local system where the file was downloaded. """ - path = get_config('EXPYFUN_DATA_PATH', op.join(_get_user_home_path(), - '.expyfun', 'data')) + path = get_config( + "EXPYFUN_DATA_PATH", op.join(_get_user_home_path(), ".expyfun", "data") + ) fname_out = op.join(path, fname) if not op.isdir(op.dirname(fname_out)): os.makedirs(op.dirname(fname_out)) - fname_url = ('https://github.com/LABSN/expyfun-data/raw/master/{0}' - ''.format(fname)) + fname_url = f"https://github.com/LABSN/expyfun-data/raw/master/{fname}" "" try: # until we get proper certificates context = ssl._create_unverified_context() @@ -548,7 +561,7 @@ def fetch_data_file(fname): this_urlopen = urlopen if not op.isfile(fname_out): try: - with open(fname_out, 'wb') as fid: + with open(fname_out, "wb") as fid: www = this_urlopen(fname_url, timeout=30.0) try: fid.write(www.read()) @@ -570,40 +583,41 @@ def get_config_path(): will be '%APPDATA%\.expyfun\expyfun.json'. On every other system, this will be $HOME/.expyfun/expyfun.json. """ - val = op.join(_get_user_home_path(), '.expyfun', 'expyfun.json') + val = op.join(_get_user_home_path(), ".expyfun", "expyfun.json") return val # List the known configuration values -known_config_types = ('RESPONSE_DEVICE', - 'AUDIO_CONTROLLER', - 'DB_OF_SINE_AT_1KHZ_1RMS', - 'EXPYFUN_EYELINK', - 'SOUND_CARD_API', - 'SOUND_CARD_API_OPTIONS', - 'SOUND_CARD_BACKEND', - 'SOUND_CARD_FS', - 'SOUND_CARD_NAME', - 'SOUND_CARD_FIXED_DELAY', - 'SOUND_CARD_TRIGGER_CHANNELS', - 'SOUND_CARD_TRIGGER_INSERTION', - 'SOUND_CARD_TRIGGER_SCALE', - 'SOUND_CARD_TRIGGER_ID_AFTER_ONSET', - 'SOUND_CARD_DRIFT_TRIGGER', - 'TDT_CIRCUIT_PATH', - 'TDT_DELAY', - 'TDT_INTERFACE', - 'TDT_MODEL', - 'TDT_TRIG_DELAY', - 'TRIGGER_CONTROLLER', - 'TRIGGER_ADDRESS', - 'WINDOW_SIZE', - 'SCREEN_NUM', - 'SCREEN_WIDTH', - 'SCREEN_DISTANCE', - 'SCREEN_SIZE_PIX', - 'EXPYFUN_LOGGING_LEVEL', - ) +known_config_types = ( + "RESPONSE_DEVICE", + "AUDIO_CONTROLLER", + "DB_OF_SINE_AT_1KHZ_1RMS", + "EXPYFUN_EYELINK", + "SOUND_CARD_API", + "SOUND_CARD_API_OPTIONS", + "SOUND_CARD_BACKEND", + "SOUND_CARD_FS", + "SOUND_CARD_NAME", + "SOUND_CARD_FIXED_DELAY", + "SOUND_CARD_TRIGGER_CHANNELS", + "SOUND_CARD_TRIGGER_INSERTION", + "SOUND_CARD_TRIGGER_SCALE", + "SOUND_CARD_TRIGGER_ID_AFTER_ONSET", + "SOUND_CARD_DRIFT_TRIGGER", + "TDT_CIRCUIT_PATH", + "TDT_DELAY", + "TDT_INTERFACE", + "TDT_MODEL", + "TDT_TRIG_DELAY", + "TRIGGER_CONTROLLER", + "TRIGGER_ADDRESS", + "WINDOW_SIZE", + "SCREEN_NUM", + "SCREEN_WIDTH", + "SCREEN_DISTANCE", + "SCREEN_SIZE_PIX", + "EXPYFUN_LOGGING_LEVEL", +) # These allow for partial matches: 'NAME_1' is okay key if 'NAME' is listed known_config_wildcards = () @@ -628,8 +642,8 @@ def get_config(key=None, default=None, raise_error=False): value : str | None The preference key value. """ - if key is not None and not isinstance(key, string_types): - raise ValueError('key must be a string') + if key is not None and not isinstance(key, str): + raise ValueError("key must be a string") # first, check to see if key is in env if key is not None and key in os.environ: @@ -641,7 +655,7 @@ def get_config(key=None, default=None, raise_error=False): key_found = False val = default else: - with open(config_path, 'r') as fid: + with open(config_path) as fid: config = json.load(fid) if key is None: return config @@ -651,13 +665,14 @@ def get_config(key=None, default=None, raise_error=False): if not key_found and raise_error is True: meth_1 = 'os.environ["%s"] = VALUE' % key meth_2 = 'expyfun.utils.set_config("%s", VALUE)' % key - raise KeyError('Key "%s" not found in environment or in the ' - 'expyfun config file:\n%s\nTry either:\n' - ' %s\nfor a temporary solution, or:\n' - ' %s\nfor a permanent one. You can also ' - 'set the environment variable before ' - 'running python.' - % (key, config_path, meth_1, meth_2)) + raise KeyError( + 'Key "%s" not found in environment or in the ' + "expyfun config file:\n%s\nTry either:\n" + " %s\nfor a temporary solution, or:\n" + " %s\nfor a permanent one. You can also " + "set the environment variable before " + "running python." % (key, config_path, meth_1, meth_2) + ) return val @@ -675,25 +690,27 @@ def set_config(key, value): """ if key is None: return sorted(known_config_types) - if not isinstance(key, string_types): - raise ValueError('key must be a string') + if not isinstance(key, str): + raise ValueError("key must be a string") # While JSON allow non-string types, we allow users to override config # settings using env, which are strings, so we enforce that here - if not isinstance(value, string_types) and value is not None: - raise ValueError('value must be a string or None') - if key not in known_config_types and not \ - any(k in key for k in known_config_wildcards): + if not isinstance(value, str) and value is not None: + raise ValueError("value must be a string or None") + if key not in known_config_types and not any( + k in key for k in known_config_wildcards + ): warnings.warn('Setting non-standard config type: "%s"' % key) # Read all previous values config_path = get_config_path() if op.isfile(config_path): - with open(config_path, 'r') as fid: + with open(config_path) as fid: config = json.load(fid) else: config = dict() - logger.info('Attempting to create new expyfun configuration ' - 'file:\n%s' % config_path) + logger.info( + "Attempting to create new expyfun configuration " "file:\n%s" % config_path + ) if value is None: config.pop(key, None) else: @@ -703,7 +720,7 @@ def set_config(key, value): directory = op.split(config_path)[0] if not op.isdir(directory): os.mkdir(directory) - with open(config_path, 'w') as fid: + with open(config_path, "w") as fid: json.dump(config, fid, sort_keys=True, indent=0) @@ -711,7 +728,7 @@ def set_config(key, value): # MISC -def fake_button_press(ec, button='1', delay=0.): +def fake_button_press(ec, button="1", delay=0.0): """Fake a button press after a delay Notes @@ -720,29 +737,34 @@ def fake_button_press(ec, button='1', delay=0.): It uses threads to ensure that control is passed back, so other commands can be called (like wait_for_presses). """ + def send(): ec._response_handler._on_pyglet_keypress(button, [], True) - Timer(delay, send).start() if delay > 0. else send() + Timer(delay, send).start() if delay > 0.0 else send() -def fake_mouse_click(ec, pos, button='left', delay=0.): + +def fake_mouse_click(ec, pos, button="left", delay=0.0): """Fake a mouse click after a delay""" button = dict(left=1, middle=2, right=4)[button] # trans to pyglet def send(): ec._mouse_handler._on_pyglet_mouse_click(pos[0], pos[1], button, []) - Timer(delay, send).start() if delay > 0. else send() + + Timer(delay, send).start() if delay > 0.0 else send() def _check_pyglet_version(raise_error=False): - """Check pyglet version, return True if usable. - """ + """Check pyglet version, return True if usable.""" import pyglet - is_usable = _compare_version(pyglet.version, '>=', '1.2') + + is_usable = _compare_version(pyglet.version, ">=", "1.2") if raise_error is True and is_usable is False: - raise ImportError('On Linux, you must run at least Pyglet ' - 'version 1.2, and you are running ' - '{0}'.format(pyglet.version)) + raise ImportError( + "On Linux, you must run at least Pyglet " + "version 1.2, and you are running " + f"{pyglet.version}" + ) return is_usable @@ -797,7 +819,7 @@ def running_rms(signal, win_length): # sig2 = signal * signal c1 = np.cumsum(sig2) - out = c1[win_length - 1:].copy() + out = c1[win_length - 1 :].copy() if len(out) == 0: # len(signal) < len(win_length) out = np.array([np.sqrt(c1[-1] / signal.size)]) else: @@ -830,21 +852,23 @@ def _fix_audio_dims(signal, n_channels): signal = np.asarray(np.atleast_2d(signal), dtype=np.float32) # Check dimensionality if signal.ndim != 2: - raise ValueError('Sound data must have one or two dimensions, got %s.' - % (signal.ndim,)) + raise ValueError( + "Sound data must have one or two dimensions, got %s." % (signal.ndim,) + ) # Return data with correct dimensions if n_channels == 2 and signal.shape[0] == 1: signal = np.tile(signal, (n_channels, 1)) if signal.shape[0] != n_channels: - raise ValueError('signal channel count %d did not match required ' - 'channel count %d' % (signal.shape[0], n_channels)) + raise ValueError( + "signal channel count %d did not match required " + "channel count %d" % (signal.shape[0], n_channels) + ) return signal def _sanitize(text_like): - """Cast as string, encode as UTF-8 and sanitize any escape characters. - """ - return text_type(text_like).encode('unicode_escape').decode('utf-8') + """Cast as string, encode as UTF-8 and sanitize any escape characters.""" + return str(text_like).encode("unicode_escape").decode("utf-8") def _sort_keys(x): @@ -855,7 +879,7 @@ def _sort_keys(x): return keys -def object_diff(a, b, pre=''): +def object_diff(a, b, pre=""): """Compute all differences between two python variables Parameters @@ -877,72 +901,76 @@ def object_diff(a, b, pre=''): ----- Taken from mne-python with permission. """ - out = '' + out = "" if type(a) != type(b): - out += pre + ' type mismatch (%s, %s)\n' % (type(a), type(b)) + out += pre + " type mismatch (%s, %s)\n" % (type(a), type(b)) elif isinstance(a, dict): k1s = _sort_keys(a) k2s = _sort_keys(b) m1 = set(k2s) - set(k1s) if len(m1): - out += pre + ' x1 missing keys %s\n' % (m1) + out += pre + " x1 missing keys %s\n" % (m1) for key in k1s: if key not in k2s: - out += pre + ' x2 missing key %s\n' % key + out += pre + " x2 missing key %s\n" % key else: - out += object_diff(a[key], b[key], pre + 'd1[%s]' % repr(key)) + out += object_diff(a[key], b[key], pre + "d1[%s]" % repr(key)) elif isinstance(a, (list, tuple)): if len(a) != len(b): - out += pre + ' length mismatch (%s, %s)\n' % (len(a), len(b)) + out += pre + " length mismatch (%s, %s)\n" % (len(a), len(b)) else: for xx1, xx2 in zip(a, b): - out += object_diff(xx1, xx2, pre='') - elif isinstance(a, (string_types, int, float, bytes)): + out += object_diff(xx1, xx2, pre="") + elif isinstance(a, (str, int, float, bytes)): if a != b: - out += pre + ' value mismatch (%s, %s)\n' % (a, b) + out += pre + " value mismatch (%s, %s)\n" % (a, b) elif a is None: if b is not None: - out += pre + ' a is None, b is not (%s)\n' % (b) + out += pre + " a is None, b is not (%s)\n" % (b) elif isinstance(a, np.ndarray): if not np.array_equal(a, b): - out += pre + ' array mismatch\n' + out += pre + " array mismatch\n" else: - raise RuntimeError(pre + ': unsupported type %s (%s)' % (type(a), a)) + raise RuntimeError(pre + ": unsupported type %s (%s)" % (type(a), a)) return out def _check_skip_backend(backend): - from expyfun._sound_controllers import _import_backend import pytest + + from expyfun._sound_controllers import _import_backend + if isinstance(backend, dict): # actually an AC - backend = backend['SOUND_CARD_BACKEND'] + backend = backend["SOUND_CARD_BACKEND"] try: _import_backend(backend) except Exception as exc: - pytest.skip('Skipping test for backend %s: %s' % (backend, exc)) + pytest.skip("Skipping test for backend %s: %s" % (backend, exc)) def _check_params(params, keys, defaults, name): if not isinstance(params, dict): - raise TypeError('{0} must be a dict, got type {1}' - .format(name, type(params))) + raise TypeError(f"{name} must be a dict, got type {type(params)}") params = deepcopy(params) if not isinstance(params, dict): - raise TypeError('{0} must be a dict, got {1}' - .format(name, type(params))) + raise TypeError(f"{name} must be a dict, got {type(params)}") # Set sensible defaults for values that are not passed for k in keys: params[k] = params.get(k, get_config(k, defaults.get(k, None))) # Check keys for k in params.keys(): if k not in keys: - raise KeyError('Unrecognized key in {0}["{1}"], must be ' - 'one of {2}'.format(name, k, ', '.join(keys))) + raise KeyError( + 'Unrecognized key in {0}["{1}"], must be ' "one of {2}".format( + name, k, ", ".join(keys) + ) + ) return params def _get_display(): import pyglet + try: display = pyglet.canvas.get_display() except AttributeError: # < 1.4 @@ -952,9 +980,6 @@ def _get_display(): # Adapted from MNE-Python def _compare_version(version_a, operator, version_b): - try: - from pkg_resources import parse_version as parse # noqa - except ImportError: - from distutils.version import LooseVersion as parse # noqa + from packaging.version import parse # noqa return eval(f'parse("{version_a}") {operator} parse("{version_b}")') diff --git a/expyfun/_version.py b/expyfun/_version.py index 8933c937..9fc20a7e 100644 --- a/expyfun/_version.py +++ b/expyfun/_version.py @@ -1 +1 @@ -__version__ = '2.0.0.dev0' +__version__ = "2.0.0.dev0" diff --git a/expyfun/analyze/__init__.py b/expyfun/analyze/__init__.py index c5edef1d..87792884 100644 --- a/expyfun/analyze/__init__.py +++ b/expyfun/analyze/__init__.py @@ -6,7 +6,6 @@ """ # -*- coding: utf-8 -*- -from ._analyze import (dprime, logit, sigmoid, fit_sigmoid, - rt_chisq, press_times_to_hmfc) +from ._analyze import dprime, logit, sigmoid, fit_sigmoid, rt_chisq, press_times_to_hmfc from ._viz import barplot, box_off, plot_screen, format_pval from ._recon import restore_values diff --git a/expyfun/analyze/_analyze.py b/expyfun/analyze/_analyze.py index f8119d29..defac1dc 100644 --- a/expyfun/analyze/_analyze.py +++ b/expyfun/analyze/_analyze.py @@ -1,19 +1,14 @@ -# -*- coding: utf-8 -*- -"""Analysis functions (mostly for psychophysics data). -""" +"""Analysis functions (mostly for psychophysics data).""" -from collections import namedtuple import warnings +from collections import namedtuple import numpy as np import scipy.stats as ss from scipy.optimize import curve_fit -from .._utils import string_types - -def press_times_to_hmfc(presses, targets, foils, tmin, tmax, - return_type='counts'): +def press_times_to_hmfc(presses, targets, foils, tmin, tmax, return_type="counts"): """Convert press times to hits/misses/FA/CR and reaction times Parameters @@ -58,15 +53,15 @@ def press_times_to_hmfc(presses, targets, foils, tmin, tmax, press by this function. However, there is no such de-bouncing of responses to "other" times. """ - known_types = ['counts', 'rts'] - if isinstance(return_type, string_types): + known_types = ["counts", "rts"] + if isinstance(return_type, str): singleton = True return_type = [return_type] else: singleton = False for r in return_type: - if not isinstance(r, string_types) or r not in known_types: - raise ValueError('r must be one of %s, got %s' % (known_types, r)) + if not isinstance(r, str) or r not in known_types: + raise ValueError("r must be one of %s, got %s" % (known_types, r)) # Sanity check that targets and foils don't overlap (due to tmin/tmax) targets = np.atleast_1d(targets) foils = np.atleast_1d(foils) @@ -77,7 +72,7 @@ def press_times_to_hmfc(presses, targets, foils, tmin, tmax, order = np.argsort(stim_times) stim_times = stim_times[order] if not np.all(stim_times[:-1] + tmax <= stim_times[1:] + tmin): - raise ValueError('Analysis windows for targets and foils overlap') + raise ValueError("Analysis windows for targets and foils overlap") # figure out what targ/mask times our presses correspond to press_to_stim = np.searchsorted(stim_times, presses - tmin) - 1 if len(press_to_stim) > 0: @@ -88,8 +83,7 @@ def press_times_to_hmfc(presses, targets, foils, tmin, tmax, assert (stim_times <= presses).all() # figure out which presses were valid (to target or masker) - valid_mask = ((presses >= stim_times + tmin) & - (presses <= stim_times + tmax)) + valid_mask = (presses >= stim_times + tmin) & (presses <= stim_times + tmax) n_other = np.sum(~valid_mask) press_to_stim = press_to_stim[valid_mask] presses = presses[valid_mask] @@ -105,14 +99,16 @@ def press_times_to_hmfc(presses, targets, foils, tmin, tmax, del used_map_idx # figure out which valid presses were to target or masker - target_mask = (order <= len(targets)) + target_mask = order <= len(targets) n_hit = np.sum(target_mask) n_fa = len(target_mask) - n_hit n_miss = len(targets) - n_hit n_cr = len(foils) - n_fa - outs = dict(counts=(n_hit, n_miss, n_fa, n_cr, n_other), - rts=(diffs[target_mask], diffs[~target_mask])) - assert outs['counts'][:4:2] == tuple(map(len, outs['rts'])) + outs = dict( + counts=(n_hit, n_miss, n_fa, n_cr, n_other), + rts=(diffs[target_mask], diffs[~target_mask]), + ) + assert outs["counts"][:4:2] == tuple(map(len, outs["rts"])) outs = tuple(outs[r] for r in return_type) if singleton: outs = outs[0] @@ -141,7 +137,7 @@ def logit(prop, max_events=None): """ prop = np.atleast_1d(prop).astype(float) if np.any([prop > 1, prop < 0]): - raise ValueError('Proportions must be in the range [0, 1].') + raise ValueError("Proportions must be in the range [0, 1].") if max_events is not None: # add equivalent of half an event to 0s, and subtract same from 1s max_events = np.atleast_1d(max_events) * np.ones_like(prop) @@ -150,10 +146,10 @@ def logit(prop, max_events=None): prop[loc] = corr_factor[loc] for loc in zip(*np.where(prop == 1)): prop[loc] = 1 - corr_factor[loc] - return np.log(prop / (1. - prop)) + return np.log(prop / (1.0 - prop)) -def sigmoid(x, lower=0., upper=1., midpt=0., slope=1.): +def sigmoid(x, lower=0.0, upper=1.0, midpt=0.0, slope=1.0): """Calculate sigmoidal values along the x-axis Parameters @@ -213,23 +209,21 @@ def fit_sigmoid(x, y, p0=None, fixed=()): # Initial estimates x = np.asarray(x) y = np.asarray(y) - k = 2 * 4. / (np.max(x) - np.min(x)) + k = 2 * 4.0 / (np.max(x) - np.min(x)) if p0 is None: p0 = [None] * 4 p0 = list(p0) - for ii, p in enumerate([np.min(y), np.max(y), - np.mean([np.max(x), np.min(x)]), k]): + for ii, p in enumerate([np.min(y), np.max(y), np.mean([np.max(x), np.min(x)]), k]): p0[ii] = p if p0[ii] is None else p0[ii] p0 = np.array(p0, dtype=np.float64) if p0.size != 4 or p0.ndim != 1: - raise ValueError('p0 must have 4 elements, or be None') + raise ValueError("p0 must have 4 elements, or be None") # Fixing values - p_types = ('lower', 'upper', 'midpt', 'slope') + p_types = ("lower", "upper", "midpt", "slope") for f in fixed: if f not in p_types: - raise ValueError('fixed {0} not in parameter list {1}' - ''.format(f, p_types)) + raise ValueError(f"fixed {f} not in parameter list {p_types}" "") fixed = np.array([(True if f in fixed else False) for f in p_types], bool) kwargs = dict() @@ -243,7 +237,7 @@ def fit_sigmoid(x, y, p0=None, fixed=()): idx.append(ii) p0 = p0[idx] if len(idx) == 0: - raise RuntimeError('cannot fit with all fixed values') + raise RuntimeError("cannot fit with all fixed values") def wrapper(*args): assert len(args) == len(keys) + 1 @@ -255,7 +249,7 @@ def wrapper(*args): assert len(idx) == len(out) for ii, o in zip(idx, out): kwargs[p_types[ii]] = o - return namedtuple('params', p_types)(**kwargs) + return namedtuple("params", p_types)(**kwargs) def rt_chisq(x, axis=None, warn=True): @@ -294,30 +288,29 @@ def rt_chisq(x, axis=None, warn=True): """ x = np.asarray(x) if np.any(np.less(x, 0)): # save the user some pain - raise ValueError('x cannot have negative values') + raise ValueError("x cannot have negative values") if axis is None: df, _, scale = ss.chi2.fit(x, floc=0) else: + def fit(x): return np.array(ss.chi2.fit(x, floc=0)) + params = np.apply_along_axis(fit, axis=axis, arr=x) # df, loc, scale - pmut = np.concatenate((np.atleast_1d(axis), - np.delete(np.arange(x.ndim), axis))) + pmut = np.concatenate((np.atleast_1d(axis), np.delete(np.arange(x.ndim), axis))) df = np.transpose(params, pmut)[0] scale = np.transpose(params, pmut)[2] quartiles = np.percentile(x, (25, 75)) whiskers = quartiles + np.array((-1.5, 1.5)) * np.diff(quartiles) - n_bad = np.sum(np.logical_or(np.less(x, whiskers[0]), - np.greater(x, whiskers[1]))) + n_bad = np.sum(np.logical_or(np.less(x, whiskers[0]), np.greater(x, whiskers[1]))) if n_bad > 0 and warn: - warnings.warn('{0} likely bad values in x (of {1})' - ''.format(n_bad, x.size)) + warnings.warn(f"{n_bad} likely bad values in x (of {x.size})" "") peak = np.maximum(0, (df - 2)) * scale return peak def dprime(hmfc, zero_correction=True, return_bias=False, two_interval=False): - u"""Estimate d′ and bias. + """Estimate d′ and bias. Parameters ---------- @@ -366,13 +359,11 @@ def dprime(hmfc, zero_correction=True, return_bias=False, two_interval=False): """ hmfc = _check_dprime_inputs(hmfc) a = 0.5 if zero_correction else 0.0 - z_hr = ss.norm.ppf((hmfc[..., 0] + a) / - (hmfc[..., 0] + hmfc[..., 1] + 2 * a)) - z_fr = ss.norm.ppf((hmfc[..., 2] + a) / - (hmfc[..., 2] + hmfc[..., 3] + 2 * a)) - cf = 1. / np.sqrt(2) if two_interval else 1. + z_hr = ss.norm.ppf((hmfc[..., 0] + a) / (hmfc[..., 0] + hmfc[..., 1] + 2 * a)) + z_fr = ss.norm.ppf((hmfc[..., 2] + a) / (hmfc[..., 2] + hmfc[..., 3] + 2 * a)) + cf = 1.0 / np.sqrt(2) if two_interval else 1.0 dp = cf * (z_hr - z_fr) - bias = (z_hr + z_fr) / -2. + bias = (z_hr + z_fr) / -2.0 return (dp, bias) if return_bias else dp @@ -386,10 +377,13 @@ def _check_dprime_inputs(hmfc): """ hmfc = np.asarray(hmfc) if hmfc.shape[-1] != 4: - raise ValueError('Array must have last dimension 4') + raise ValueError("Array must have last dimension 4") if hmfc.dtype not in (np.int64, np.int32): - warnings.warn('Argument (%s) to dprime() cast to np.int64; floating ' - 'point values will have been truncated.' % hmfc.dtype, - RuntimeWarning, stacklevel=3) + warnings.warn( + "Argument (%s) to dprime() cast to np.int64; floating " + "point values will have been truncated." % hmfc.dtype, + RuntimeWarning, + stacklevel=3, + ) hmfc = hmfc.astype(np.int64) return hmfc diff --git a/expyfun/analyze/_recon.py b/expyfun/analyze/_recon.py index e49cc2de..12afc1c6 100644 --- a/expyfun/analyze/_recon.py +++ b/expyfun/analyze/_recon.py @@ -1,5 +1,4 @@ -"""Functions for fixing data. -""" +"""Functions for fixing data.""" import numpy as np from scipy import linalg @@ -36,8 +35,10 @@ def restore_values(correct, other, idx): correct = np.array(correct, np.float64) other = np.array(other, np.float64) if correct.ndim != 1 or other.ndim != 1 or other.size > correct.size: - raise RuntimeError('correct and other must be 1D, and correct must ' - 'be at least as long as other') + raise RuntimeError( + "correct and other must be 1D, and correct must " + "be at least as long as other" + ) keep = np.ones(len(correct), bool) for ii in idx: keep[ii] = False @@ -49,7 +50,7 @@ def restore_values(correct, other, idx): X = np.dot(X, other) test = np.dot(np.array((np.ones_like(use), use)).T, X) if not np.allclose(other, test): # validate fit - raise RuntimeError('data could not be fit') + raise RuntimeError("data could not be fit") miss = correct[replace] vals = np.dot(np.array((np.ones_like(miss), miss)).T, X) out = np.zeros(len(correct), np.float64) diff --git a/expyfun/analyze/_viz.py b/expyfun/analyze/_viz.py index c83ee779..5bac4dd1 100644 --- a/expyfun/analyze/_viz.py +++ b/expyfun/analyze/_viz.py @@ -1,13 +1,11 @@ -"""Analysis visualization functions -""" +"""Analysis visualization functions""" -import numpy as np from itertools import chain -from .._utils import string_types +import numpy as np -def format_pval(pval, latex=True, scheme='default'): +def format_pval(pval, latex=True, scheme="default"): """Format a p-value using one of several schemes. Parameters @@ -36,38 +34,42 @@ def format_pval(pval, latex=True, scheme='default'): expon = np.trunc(np.log10(pval)).astype(int) # exponents pv = np.zeros_like(pval, dtype=object) if latex: - wrap = '$' - brk_l = '{{' - brk_r = '}}' + wrap = "$" + brk_l = "{{" + brk_r = "}}" else: - wrap = '' - brk_l = '' - brk_r = '' - if scheme == 'ross': # (exact value up to 4 decimal places) - pv[pval >= 0.0001] = [wrap + 'p = {:.4f}'.format(x) + wrap - for x in pval[pval > 0.0001]] - pv[pval < 0.0001] = [wrap + 'p < 10^' + brk_l + '{}'.format(x) + - brk_r + wrap for x in expon[pval < 0.0001]] - elif scheme == 'stars': - star = '{*}' if latex else '*' - pv[pval >= 0.05] = wrap + '' + wrap + wrap = "" + brk_l = "" + brk_r = "" + if scheme == "ross": # (exact value up to 4 decimal places) + pv[pval >= 0.0001] = [wrap + f"p = {x:.4f}" + wrap for x in pval[pval > 0.0001]] + pv[pval < 0.0001] = [ + wrap + "p < 10^" + brk_l + f"{x}" + brk_r + wrap + for x in expon[pval < 0.0001] + ] + elif scheme == "stars": + star = "{*}" if latex else "*" + pv[pval >= 0.05] = wrap + "" + wrap pv[pval < 0.05] = wrap + star + wrap pv[pval < 0.01] = wrap + star * 2 + wrap pv[pval < 0.001] = wrap + star * 3 + wrap else: # scheme == 'default' - pv[pval >= 0.05] = wrap + 'n.s.' + wrap - pv[pval < 0.05] = wrap + 'p < 0.05' + wrap - pv[pval < 0.01] = wrap + 'p < 0.01' + wrap - pv[pval < 0.001] = wrap + 'p < 0.001' + wrap - pv[pval < 0.0001] = [wrap + 'p < 10^' + brk_l + '{}'.format(x) + - brk_r + wrap for x in expon[pval < 0.0001]] + pv[pval >= 0.05] = wrap + "n.s." + wrap + pv[pval < 0.05] = wrap + "p < 0.05" + wrap + pv[pval < 0.01] = wrap + "p < 0.01" + wrap + pv[pval < 0.001] = wrap + "p < 0.001" + wrap + pv[pval < 0.0001] = [ + wrap + "p < 10^" + brk_l + f"{x}" + brk_r + wrap + for x in expon[pval < 0.0001] + ] if single_value: pv = pv[0] - return(pv) + return pv def _instantiate(obj, typ): - """Returns obj if obj is not None, else returns new instance of typ + """Return obj if obj is not None, else returns new instance of typ. + obj : an object An object (most likely one that a user passed into a function) that, if ``None``, should be initiated as an empty object of some other type. @@ -77,13 +79,31 @@ def _instantiate(obj, typ): return typ() if obj is None else obj -def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, - groups=None, eq_group_widths=False, gap_size=0.2, - brackets=None, bracket_text=None, bracket_inline=False, - bracket_group_lines=False, bar_names=None, group_names=None, - bar_kwargs=None, err_kwargs=None, line_kwargs=None, - bracket_kwargs=None, pval_kwargs=None, figure_kwargs=None, - smart_defaults=True, fname=None, ax=None): +def barplot( + h, + axis=-1, + ylim=None, + err_bars=None, + lines=False, + groups=None, + eq_group_widths=False, + gap_size=0.2, + brackets=None, + bracket_text=None, + bracket_inline=False, + bracket_group_lines=False, + bar_names=None, + group_names=None, + bar_kwargs=None, + err_kwargs=None, + line_kwargs=None, + bracket_kwargs=None, + pval_kwargs=None, + figure_kwargs=None, + smart_defaults=True, + fname=None, + ax=None, +): """Makes barplots w/ optional line overlays, grouping, & signif. brackets. Parameters @@ -198,7 +218,9 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, bracket color: dark gray (30%) """ - from matplotlib import pyplot as plt, rcParams + from matplotlib import pyplot as plt + from matplotlib import rcParams + try: from pandas.core.frame import DataFrame except Exception: @@ -210,24 +232,26 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, bar_names = h.columns.tolist() if axis == 0 else h.index.tolist() # check arg errors if gap_size < 0 or gap_size >= 1: - raise ValueError('Barplot argument "gap_size" must be in the range ' - '[0, 1).') + raise ValueError('Barplot argument "gap_size" must be in the range ' "[0, 1).") if err_bars is not None: - if isinstance(err_bars, string_types) and \ - err_bars not in ['sd', 'se', 'ci']: - raise ValueError('err_bars must be "sd", "se", or "ci" (or an ' - 'array of error bar magnitudes).') + if isinstance(err_bars, str) and err_bars not in ["sd", "se", "ci"]: + raise ValueError( + 'err_bars must be "sd", "se", or "ci" (or an ' + "array of error bar magnitudes)." + ) if brackets is not None: if any([len(x) != 2 for x in brackets]): - raise ValueError('Each top-level element of brackets must have ' - 'length 2.') + raise ValueError( + "Each top-level element of brackets must have " "length 2." + ) if not len(brackets) == len(bracket_text): - raise ValueError('Mismatch between number of brackets and bracket ' - 'labels.') + raise ValueError( + "Mismatch between number of brackets and bracket " "labels." + ) # handle single-element args - if isinstance(bracket_text, string_types): + if isinstance(bracket_text, str): bracket_text = [bracket_text] - if isinstance(group_names, string_types): + if isinstance(group_names, str): group_names = [group_names] # arg defaults: if arg is None, instantiate as given type brackets = _instantiate(brackets, list) @@ -239,20 +263,20 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, bracket_kwargs = _instantiate(bracket_kwargs, dict) # user-supplied Axes if ax is not None: - bar_kwargs['axes'] = ax + bar_kwargs["axes"] = ax # smart defaults if smart_defaults: - if 'color' not in bar_kwargs.keys(): - bar_kwargs['color'] = '0.7' - if 'color' not in line_kwargs.keys(): - line_kwargs['color'] = 'k' - if 'ecolor' not in err_kwargs.keys(): - err_kwargs['ecolor'] = 'k' - if 'color' not in bracket_kwargs.keys(): - bracket_kwargs['color'] = '0.3' + if "color" not in bar_kwargs.keys(): + bar_kwargs["color"] = "0.7" + if "color" not in line_kwargs.keys(): + line_kwargs["color"] = "k" + if "ecolor" not in err_kwargs.keys(): + err_kwargs["ecolor"] = "k" + if "color" not in bracket_kwargs.keys(): + bracket_kwargs["color"] = "0.3" # fix bar alignment (defaults to 'center' in more recent versions of MPL) - if 'align' not in bar_kwargs.keys(): - bar_kwargs['align'] = 'edge' + if "align" not in bar_kwargs.keys(): + bar_kwargs["align"] = "edge" # parse heights h = np.array(h) if len(h.shape) > 2: @@ -265,54 +289,61 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, groups = [list(x) for x in groups] # forgive list/tuple mix-ups # calculate bar positions non_gap = 1 - gap_size - offset = gap_size / 2. + offset = gap_size / 2.0 if eq_group_widths: group_sizes = np.array([float(len(_grp)) for _grp in groups], int) group_widths = [non_gap for _ in groups] group_edges = [offset + _ix for _ix in range(len(groups))] group_ixs = list(chain.from_iterable([range(x) for x in group_sizes])) - bar_widths = np.repeat(np.array(group_widths) / group_sizes, - group_sizes).tolist() - bar_edges = (np.repeat(group_edges, group_sizes) + - bar_widths * np.array(group_ixs)).tolist() + bar_widths = np.repeat( + np.array(group_widths) / group_sizes, group_sizes + ).tolist() + bar_edges = ( + np.repeat(group_edges, group_sizes) + bar_widths * np.array(group_ixs) + ).tolist() else: bar_widths = [[non_gap for _ in _grp] for _grp in groups] # next line: offset + cumul. gap widths + cumul. bar widths - bar_edges = [[offset + _ix * gap_size + _bar * non_gap - for _bar in _grp] for _ix, _grp in enumerate(groups)] + bar_edges = [ + [offset + _ix * gap_size + _bar * non_gap for _bar in _grp] + for _ix, _grp in enumerate(groups) + ] group_widths = [np.sum(_width) for _width in bar_widths] group_edges = [_edge[0] for _edge in bar_edges] bar_edges = list(chain.from_iterable(bar_edges)) bar_widths = list(chain.from_iterable(bar_widths)) - bar_centers = np.array(bar_edges) + np.array(bar_widths) / 2. - group_centers = np.array(group_edges) + np.array(group_widths) / 2. + bar_centers = np.array(bar_edges) + np.array(bar_widths) / 2.0 + group_centers = np.array(group_edges) + np.array(group_widths) / 2.0 # calculate error bars err = np.zeros(num_bars) # default if no err_bars if err_bars is not None: if h.ndim == 2: - if err_bars == 'sd': # sample standard deviation + if err_bars == "sd": # sample standard deviation err = h.std(axis) - elif err_bars == 'se': # standard error + elif err_bars == "se": # standard error err = h.std(axis) / np.sqrt(h.shape[axis]) else: # 95% conf int err = 1.96 * h.std(axis) / np.sqrt(h.shape[axis]) else: # h.ndim == 1 - if isinstance(err_bars, string_types): - raise ValueError('string arguments to "err_bars" ignored when ' - '"h" has fewer than 2 dimensions.') + if isinstance(err_bars, str): + raise ValueError( + 'string arguments to "err_bars" ignored when ' + '"h" has fewer than 2 dimensions.' + ) elif not h.shape == np.array(err_bars).shape: - raise ValueError('When "err_bars" is array-like it must have ' - 'the same shape as "h".') + raise ValueError( + 'When "err_bars" is array-like it must have ' + 'the same shape as "h".' + ) err = np.atleast_1d(err_bars) - bar_kwargs['yerr'] = err + bar_kwargs["yerr"] = err # plot (bars and error bars) if ax is None: plt.figure(**figure_kwargs) p = plt.subplot(111) else: p = ax - b = p.bar(bar_edges, heights, bar_widths, error_kw=err_kwargs, - **bar_kwargs) + b = p.bar(bar_edges, heights, bar_widths, error_kw=err_kwargs, **bar_kwargs) # plot within-subject lines if lines: _h = h if axis == 0 else h.T @@ -326,41 +357,53 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, brk_min_h = np.diff(p.get_ylim()) * 0.05 # temporarily plot a textbox to get its height t = plt.annotate(bracket_text[0], (0, 0), **pval_kwargs) - t.set_bbox(dict(boxstyle='round, pad=0.25')) + t.set_bbox(dict(boxstyle="round, pad=0.25")) plt.draw() bb = t.get_bbox_patch().get_window_extent() - txth = np.diff(p.transData.inverted().transform(bb), - axis=0).ravel()[-1] + txth = np.diff(p.transData.inverted().transform(bb), axis=0).ravel()[-1] if bracket_inline: - txth = txth / 2. + txth = txth / 2.0 t.remove() # find highest points if lines and h.ndim == 2: # brackets must be above lines & error bars - apex = np.amax(np.r_[np.atleast_2d(heights + err), - np.atleast_2d(np.amax(h, axis))], axis=0) + apex = np.amax( + np.r_[np.atleast_2d(heights + err), np.atleast_2d(np.amax(h, axis))], + axis=0, + ) else: apex = np.atleast_1d(heights + err) apex = np.maximum(apex, 0) # for negative-going bars apex = apex + brk_offset gr_apex = np.array([np.amax(apex[_g]) for _g in groups]) # boolean for whether each half of a bracket is a group - is_group = [[hasattr(_b, 'append') for _b in _br] for _br in brackets] + is_group = [[hasattr(_b, "append") for _b in _br] for _br in brackets] # bracket left & right coords - brk_lr = [[group_centers[groups.index(_ix)] if _g - else bar_centers[_ix] for _ix, _g in zip(_brk, _isg)] - for _brk, _isg in zip(brackets, is_group)] + brk_lr = [ + [ + group_centers[groups.index(_ix)] if _g else bar_centers[_ix] + for _ix, _g in zip(_brk, _isg) + ] + for _brk, _isg in zip(brackets, is_group) + ] # bracket L/R midpoints (label position) brk_c = [np.mean(_lr) for _lr in brk_lr] # bracket bottom coords (first pass) - brk_b = [[gr_apex[groups.index(_ix)] if _g else apex[_ix] - for _ix, _g in zip(_brk, _isg)] - for _brk, _isg in zip(brackets, is_group)] + brk_b = [ + [ + gr_apex[groups.index(_ix)] if _g else apex[_ix] + for _ix, _g in zip(_brk, _isg) + ] + for _brk, _isg in zip(brackets, is_group) + ] # main bracket positioning loop brk_t = [] for _ix, (_brk, _isg) in enumerate(zip(brackets, is_group)): # which bars does this bracket span? - spanned_bars = list(chain.from_iterable( - [_b if hasattr(_b, 'append') else [_b] for _b in _brk])) + spanned_bars = list( + chain.from_iterable( + [_b if hasattr(_b, "append") else [_b] for _b in _brk] + ) + ) spanned_bars = range(min(spanned_bars), max(spanned_bars) + 1) # raise apex a bit extra if prev bracket label centered on bar prev_label_pos = brk_c[_ix - 1] if _ix else -1 @@ -375,41 +418,42 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, apex[label_bar_more] += txth gr_apex = np.array([np.amax(apex[_g]) for _g in groups]) # recalc lower tips of bracket: apex / gr_apex may have changed - brk_b[_ix] = [gr_apex[groups.index(_b)] if _g else apex[_b] - for _b, _g in zip(_brk, _isg)] + brk_b[_ix] = [ + gr_apex[groups.index(_b)] if _g else apex[_b] + for _b, _g in zip(_brk, _isg) + ] # calculate top span position _min_t = max(apex[spanned_bars]) + brk_min_h brk_t.append(_min_t) # raise apex on spanned bars to account for bracket - apex[spanned_bars] = np.maximum(apex[spanned_bars], - _min_t) + brk_offset + apex[spanned_bars] = np.maximum(apex[spanned_bars], _min_t) + brk_offset gr_apex = np.array([np.amax(apex[_g]) for _g in groups]) # draw horz line spanning groups if desired if bracket_group_lines: for _brk, _isg, _blr in zip(brackets, is_group, brk_b): for _bk, _g, _b in zip(_brk, _isg, _blr): if _g: - _lr = [bar_centers[_ix] - for _ix in groups[groups.index(_bk)]] + _lr = [bar_centers[_ix] for _ix in groups[groups.index(_bk)]] _lr = (min(_lr), max(_lr)) p.plot(_lr, (_b, _b), **bracket_kwargs) # draw (left, right, bottom-left, bottom-right, top, center, string) - for ((_l, _r), (_bl, _br), _t, _c, _s) in zip(brk_lr, brk_b, brk_t, - brk_c, bracket_text): + for (_l, _r), (_bl, _br), _t, _c, _s in zip( + brk_lr, brk_b, brk_t, brk_c, bracket_text + ): # bracket text - _t = float(_t) # on newer Pandas it can be shape (1,) - defaults = dict(ha='center', annotation_clip=False, - textcoords='offset points') + _t = np.array(_t).item() # on newer Pandas it can be shape (1,) + defaults = dict( + ha="center", annotation_clip=False, textcoords="offset points" + ) for k, v in defaults.items(): if k not in pval_kwargs.keys(): pval_kwargs[k] = v - if 'va' not in pval_kwargs.keys(): - pval_kwargs['va'] = 'center' if bracket_inline else 'baseline' - if 'xytext' not in pval_kwargs.keys(): - pval_kwargs['xytext'] = (0, 0) if bracket_inline else (0, 2) + if "va" not in pval_kwargs.keys(): + pval_kwargs["va"] = "center" if bracket_inline else "baseline" + if "xytext" not in pval_kwargs.keys(): + pval_kwargs["xytext"] = (0, 0) if bracket_inline else (0, 2) txt = p.annotate(_s, (_c, _t), **pval_kwargs) - txt.set_bbox(dict(facecolor='w', alpha=0, - boxstyle='round, pad=0.2')) + txt.set_bbox(dict(facecolor="w", alpha=0, boxstyle="round, pad=0.2")) plt.draw() # bracket lines lline = ((_l, _l), (_bl, _t)) @@ -417,10 +461,9 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, tline = ((_l, _r), (_t, _t)) if bracket_inline: bb = txt.get_bbox_patch().get_window_extent() - txtw = np.diff(p.transData.inverted().transform(bb), - axis=0).ravel()[0] - _m = _c - txtw / 2. - _n = _c + txtw / 2. + txtw = np.diff(p.transData.inverted().transform(bb), axis=0).ravel()[0] + _m = _c - txtw / 2.0 + _n = _c + txtw / 2.0 tline = [((_l, _m), (_t, _t)), ((_n, _r), (_t, _t))] else: tline = [((_l, _r), (_t, _t))] @@ -432,17 +475,23 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, p.set_ybound(ybnd[0], _t + txth) # annotation box_off(p) - p.tick_params(axis='x', length=0, pad=12) + p.tick_params(axis="x", length=0, pad=12) p.xaxis.set_ticks(bar_centers) if bar_names is not None: - p.xaxis.set_ticklabels(bar_names, va='baseline') + p.xaxis.set_ticklabels(bar_names, va="baseline") if group_names is not None: ymin = ylim[0] if ylim is not None else p.get_ylim()[0] - yoffset = -2.5 * rcParams['font.size'] + yoffset = -2.5 * rcParams["font.size"] for gn, gp in zip(group_names, group_centers): - p.annotate(gn, xy=(gp, ymin), xytext=(0, yoffset), - xycoords='data', textcoords='offset points', - ha='center', va='baseline') + p.annotate( + gn, + xy=(gp, ymin), + xytext=(0, yoffset), + xycoords="data", + textcoords="offset points", + ha="center", + va="baseline", + ) # axis limits p.set_xlim(0, bar_edges[-1] + bar_widths[-1] + gap_size / 2) if ylim is not None: @@ -450,6 +499,7 @@ def barplot(h, axis=-1, ylim=None, err_bars=None, lines=False, # output file if fname is not None: from os.path import splitext + fmt = splitext(fname)[-1][1:] plt.savefig(fname, format=fmt, transparent=True) # return handles for subplot and barplot instances @@ -467,10 +517,10 @@ def box_off(ax): """ ax.get_xaxis().tick_bottom() ax.get_yaxis().tick_left() - ax.tick_params(axis='x', direction='out') - ax.tick_params(axis='y', direction='out') - ax.spines['right'].set_color('none') - ax.spines['top'].set_color('none') + ax.tick_params(axis="x", direction="out") + ax.tick_params(axis="y", direction="out") + ax.spines["right"].set_color("none") + ax.spines["top"].set_color("none") def plot_screen(screen, ax=None): @@ -490,12 +540,13 @@ def plot_screen(screen, ax=None): The axes used to plot the image. """ import matplotlib.pyplot as plt + screen = np.array(screen) if screen.ndim != 3 or screen.shape[2] not in [3, 4]: - raise ValueError('screen must be a 3D array with 3 or 4 channels') + raise ValueError("screen must be a 3D array with 3 or 4 channels") if ax is None: plt.figure() ax = plt.axes([0, 0, 1, 1]) ax.imshow(screen) - ax.axis('off') + ax.axis("off") return ax diff --git a/expyfun/analyze/tests/test_analyze_functions.py b/expyfun/analyze/tests/test_analyze_functions.py index 81cefb7a..f7db6345 100644 --- a/expyfun/analyze/tests/test_analyze_functions.py +++ b/expyfun/analyze/tests/test_analyze_functions.py @@ -2,6 +2,7 @@ import pytest from numpy.testing import assert_allclose, assert_array_equal + try: from scipy.special import logit as splogit except ImportError: @@ -18,10 +19,9 @@ def assert_rts_equal(actual, desired): assert isinstance(desired, (list, tuple)) assert len(actual) == 2 assert len(desired) == 2 - kinds = ['hits', 'fas'] + kinds = ["hits", "fas"] for act, des, kind in zip(actual, desired, kinds): - assert_allclose(act, des, atol=1e-7, - err_msg='{0} mismatch'.format(kind)) + assert_allclose(act, des, atol=1e-7, err_msg=f"{kind} mismatch") def assert_hmfc(presses, targets, foils, hmfco, rts, tmin=0.1, tmax=0.6): @@ -30,14 +30,16 @@ def assert_hmfc(presses, targets, foils, hmfco, rts, tmin=0.1, tmax=0.6): assert_array_equal(out, hmfco) out = ea.press_times_to_hmfc(presses, targets, foils, tmin, tmax) assert_array_equal(out, hmfco) - out = ea.press_times_to_hmfc(presses, targets, foils, tmin, tmax, - return_type=['counts', 'rts']) + out = ea.press_times_to_hmfc( + presses, targets, foils, tmin, tmax, return_type=["counts", "rts"] + ) assert_array_equal(out[0][:4:2], list(map(len, out[1]))) assert_array_equal(out[0], hmfco) assert_rts_equal(out[1], rts) # reversing targets and foils - out = ea.press_times_to_hmfc(presses, foils, targets, tmin, tmax, - return_type=['counts', 'rts']) + out = ea.press_times_to_hmfc( + presses, foils, targets, tmin, tmax, return_type=["counts", "rts"] + ) assert_array_equal(out[0], np.array(hmfco)[[2, 3, 0, 1, 4]]) assert_rts_equal(out[1], rts[::-1]) @@ -45,7 +47,7 @@ def assert_hmfc(presses, targets, foils, hmfco, rts, tmin=0.1, tmax=0.6): def test_presses_to_hmfc(): """Test converting press times to HMFCO and RTs.""" # Simple example - targets = [0., 1.] + targets = [0.0, 1.0] foils = [0.5, 1.5] presses = [0.1, 1.6] # presses right at tmin/tmax @@ -76,8 +78,8 @@ def test_presses_to_hmfc(): # A complicated example: multiple preses to targ targets = [0, 2, 3] foils = [1, 4] - tmin, tmax = 0., 0.5 - presses = [0.111, 0.2, 1.101, 1.3, 2.222, 2.333, 2.7, 5.] + tmin, tmax = 0.0, 0.5 + presses = [0.111, 0.2, 1.101, 1.3, 2.222, 2.333, 2.7, 5.0] hmfco = [2, 1, 1, 1, 2] rts = [[0.111, 0.222], [0.101]] assert_hmfc(presses, targets, foils, hmfco, rts) @@ -95,35 +97,37 @@ def test_presses_to_hmfc(): # lots of presses targets = [1, 2, 5, 6, 7] foils = [0, 3, 4, 8] - presses = [0.201, 2.101, 4.202, 5.102, 6.103, 10.] + presses = [0.201, 2.101, 4.202, 5.102, 6.103, 10.0] hmfco = [3, 2, 2, 2, 1] rts = [[0.101, 0.102, 0.103], [0.201, 0.202]] assert_hmfc(presses, targets, foils, hmfco, rts) # Bad inputs - pytest.raises(ValueError, ea.press_times_to_hmfc, - presses, targets, foils, tmin, 1.1) - pytest.raises(ValueError, ea.press_times_to_hmfc, - presses, targets, foils, tmin, tmax, 'foo') + pytest.raises( + ValueError, ea.press_times_to_hmfc, presses, targets, foils, tmin, 1.1 + ) + pytest.raises( + ValueError, ea.press_times_to_hmfc, presses, targets, foils, tmin, tmax, "foo" + ) def test_dprime(): """Test dprime accuracy.""" - with pytest.warns(RuntimeWarning, match='cast to'): - pytest.raises(IndexError, ea.dprime, 'foo') - pytest.raises(ValueError, ea.dprime, ['foo', 0, 0, 0]) - with pytest.warns(RuntimeWarning, match='truncated'): + with pytest.warns(RuntimeWarning, match="cast to"): + pytest.raises(IndexError, ea.dprime, "foo") + pytest.raises(ValueError, ea.dprime, ["foo", 0, 0, 0]) + with pytest.warns(RuntimeWarning, match="truncated"): ea.dprime((1.1, 0, 0, 0)) for resp, want in ( - ((1, 1, 1, 1), [0, 0]), - ((1, 0, 0, 1), [1.34898, 0.]), - ((0, 1, 0, 1), [0, 0.67449]), - ((0, 0, 1, 1), [0, 0]), - ((1, 0, 1, 0), [0, -0.67449]), - ((0, 1, 1, 0), [-1.34898, 0.]), - ((0, 1, 1, 0), [-1.34898, 0.])): - assert_allclose(ea.dprime(resp, return_bias=True), - want, atol=1e-5) + ((1, 1, 1, 1), [0, 0]), + ((1, 0, 0, 1), [1.34898, 0.0]), + ((0, 1, 0, 1), [0, 0.67449]), + ((0, 0, 1, 1), [0, 0]), + ((1, 0, 1, 0), [0, -0.67449]), + ((0, 1, 1, 0), [-1.34898, 0.0]), + ((0, 1, 1, 0), [-1.34898, 0.0]), + ): + assert_allclose(ea.dprime(resp, return_bias=True), want, atol=1e-5) assert_allclose([np.inf, -np.inf], ea.dprime((1, 0, 2, 1), False, True)) pytest.raises(ValueError, ea.dprime, np.ones((5, 4, 3))) pytest.raises(ValueError, ea.dprime, (1, 2, 3)) @@ -135,7 +139,7 @@ def test_logit(): """Test logit calculations.""" pytest.raises(ValueError, ea.logit, 2) # On some versions, this throws warnings about divide-by-zero - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): assert ea.logit(0) == -np.inf assert ea.logit(1) == np.inf assert ea.logit(1, max_events=1) < np.inf @@ -156,14 +160,14 @@ def test_sigmoid(): """Test sigmoidal fitting and generation.""" n_pts = 1000 x = np.random.RandomState(0).randn(n_pts) - p0 = (0., 1., 0., 1.) + p0 = (0.0, 1.0, 0.0, 1.0) y = ea.sigmoid(x, *p0) assert np.all(np.logical_and(y <= 1, y >= 0)) p = ea.fit_sigmoid(x, y) assert_allclose(p, p0, atol=1e-4, rtol=1e-4) with warnings.catch_warnings(record=True): # scipy convergence - warnings.simplefilter('ignore') - p = ea.fit_sigmoid(x, y, (0, 1, None, None), ('upper', 'lower')) + warnings.simplefilter("ignore") + p = ea.fit_sigmoid(x, y, (0, 1, None, None), ("upper", "lower")) assert_allclose(p, p0, atol=1e-4, rtol=1e-4) y += np.random.rand(n_pts) * 0.01 @@ -175,13 +179,13 @@ def test_rt_chisq(): """Test reaction time chi-square fitting.""" # 1D should return single float foo = np.random.RandomState(0).rand(30) - pytest.raises(ValueError, ea.rt_chisq, foo - 1.) + pytest.raises(ValueError, ea.rt_chisq, foo - 1.0) assert_equal(np.array(ea.rt_chisq(foo, warn=False)).shape, ()) # 2D should return array with shape of input but without ``axis`` dimension foo = np.random.rand(30).reshape((2, 3, 5)) for axis in range(-1, foo.ndim): bar = ea.rt_chisq(foo, axis=axis, warn=False) assert_array_equal(np.delete(foo.shape, axis), np.array(bar.shape)) - foo_bad = np.concatenate((np.random.rand(30), [100.])) - with pytest.warns(UserWarning, match='likely bad'): + foo_bad = np.concatenate((np.random.rand(30), [100.0])) + with pytest.warns(UserWarning, match="likely bad"): bar = ea.rt_chisq(foo_bad) diff --git a/expyfun/analyze/tests/test_recon.py b/expyfun/analyze/tests/test_recon.py index e42187c9..35ae0c53 100644 --- a/expyfun/analyze/tests/test_recon.py +++ b/expyfun/analyze/tests/test_recon.py @@ -5,8 +5,7 @@ def test_restore(): - """Test restoring missing values - """ + """Test restoring missing values""" n = 20 x = np.arange(n, dtype=float) y = x * 10 - 1.5 @@ -14,7 +13,7 @@ def test_restore(): keep[[0, 4, -1]] = False missing = np.where(~keep)[0] keep = np.where(keep)[0] - y = x[keep] * 10. - 1.5 + y = x[keep] * 10.0 - 1.5 y2 = restore_values(x, y, missing)[0] - x2 = (y2 + 1.5) / 10. + x2 = (y2 + 1.5) / 10.0 assert_allclose(x, x2, atol=1e-7) diff --git a/expyfun/analyze/tests/test_viz.py b/expyfun/analyze/tests/test_viz.py index 739a51e6..2bbbcaf5 100644 --- a/expyfun/analyze/tests/test_viz.py +++ b/expyfun/analyze/tests/test_viz.py @@ -1,7 +1,7 @@ -import numpy as np -from os import path as op import warnings +from os import path as op +import numpy as np import pytest from numpy.testing import assert_equal @@ -13,16 +13,19 @@ def _check_warnings(w): """Silly helper to deal with MPL deprecation warnings.""" - assert all(['expyfun' not in ww.filename for ww in w]) + assert all(["expyfun" not in ww.filename for ww in w]) -@requires_lib('pandas') +@requires_lib("pandas") def test_barplot_with_pandas(): """Test bar plot function pandas support.""" import pandas as pd - tmp = pd.DataFrame(np.arange(20).reshape((4, 5)), - columns=['a', 'b', 'c', 'd', 'e'], - index=['one', 'two', 'three', 'four']) + + tmp = pd.DataFrame( + np.arange(20).reshape((4, 5)), + columns=["a", "b", "c", "d", "e"], + index=["one", "two", "three", "four"], + ) ea.barplot(tmp) ea.barplot(tmp, axis=0, lines=True) @@ -31,13 +34,14 @@ def test_barplot_with_pandas(): def tmp_err(): # noqa rng = np.random.RandomState(0) tmp = np.ones(4) + rng.rand(4) - err = 0.1 + tmp / 5. + err = 0.1 + tmp / 5.0 return tmp, err def test_barplot_degenerate(tmp_err): """Test bar plot degenerate cases.""" import matplotlib.pyplot as plt + tmp, err = tmp_err # too many data dimensions: pytest.raises(ValueError, ea.barplot, np.arange(8).reshape((2, 2, 2))) @@ -46,73 +50,111 @@ def test_barplot_degenerate(tmp_err): # shape mismatch between data & error bars: pytest.raises(ValueError, ea.barplot, tmp, err_bars=np.arange(3)) # bad err_bar string: - pytest.raises(ValueError, ea.barplot, tmp, err_bars='foo') + pytest.raises(ValueError, ea.barplot, tmp, err_bars="foo") # cannot calculate 'sd' error bars with only 1 value per bar: - pytest.raises(ValueError, ea.barplot, tmp, err_bars='sd') + pytest.raises(ValueError, ea.barplot, tmp, err_bars="sd") # mismatched lengths of brackets & bracket_text: - pytest.raises(ValueError, ea.barplot, tmp, brackets=[(0, 1)], - bracket_text=['foo', 'bar']) + pytest.raises( + ValueError, ea.barplot, tmp, brackets=[(0, 1)], bracket_text=["foo", "bar"] + ) # bad bracket spec: - pytest.raises(ValueError, ea.barplot, tmp, brackets=[(1,)], - bracket_text=['foo']) - plt.close('all') + pytest.raises(ValueError, ea.barplot, tmp, brackets=[(1,)], bracket_text=["foo"]) + plt.close("all") def test_barplot_single(tmp_err): """Test with single data point and single error bar spec.""" import matplotlib.pyplot as plt + tmp, err = tmp_err ea.barplot(2, err_bars=0.2) - plt.close('all') + plt.close("all") @pytest.mark.timeout(15) def test_barplot_single_spec(tmp_err): """Test with one data point per bar and user-specified err ranges.""" import matplotlib.pyplot as plt + tmp, err = tmp_err _, axs = plt.subplots(1, 5, sharey=False) - ea.barplot(tmp, err_bars=err, brackets=([2, 3], [0, 1]), ax=axs[0], - bracket_text=['foo', 'bar'], bracket_inline=True) - ea.barplot(tmp, err_bars=err, brackets=((0, 2), (1, 3)), ax=axs[1], - bracket_text=['foo', 'bar']) - ea.barplot(tmp, err_bars=err, brackets=[[2, 1], [0, 3]], ax=axs[2], - bracket_text=['foo', 'bar']) - ea.barplot(tmp, err_bars=err, brackets=[(0, 1), (0, 2), (0, 3)], - bracket_text=['foo', 'bar', 'baz'], ax=axs[3]) - ea.barplot(tmp, err_bars=err, brackets=[(0, 1), (2, 3), (0, 2), (1, 3)], - bracket_text=['foo', 'bar', 'baz', 'snafu'], ax=axs[4]) - ea.barplot(tmp, groups=[[0, 1, 2], [3]], eq_group_widths=True, - brackets=[(0, 1), (1, 2), ([0, 1, 2], 3)], - bracket_text=['foo', 'bar', 'baz'], - bracket_group_lines=True) - plt.close('all') + ea.barplot( + tmp, + err_bars=err, + brackets=([2, 3], [0, 1]), + ax=axs[0], + bracket_text=["foo", "bar"], + bracket_inline=True, + ) + ea.barplot( + tmp, + err_bars=err, + brackets=((0, 2), (1, 3)), + ax=axs[1], + bracket_text=["foo", "bar"], + ) + ea.barplot( + tmp, + err_bars=err, + brackets=[[2, 1], [0, 3]], + ax=axs[2], + bracket_text=["foo", "bar"], + ) + ea.barplot( + tmp, + err_bars=err, + brackets=[(0, 1), (0, 2), (0, 3)], + bracket_text=["foo", "bar", "baz"], + ax=axs[3], + ) + ea.barplot( + tmp, + err_bars=err, + brackets=[(0, 1), (2, 3), (0, 2), (1, 3)], + bracket_text=["foo", "bar", "baz", "snafu"], + ax=axs[4], + ) + ea.barplot( + tmp, + groups=[[0, 1, 2], [3]], + eq_group_widths=True, + brackets=[(0, 1), (1, 2), ([0, 1, 2], 3)], + bracket_text=["foo", "bar", "baz"], + bracket_group_lines=True, + ) + plt.close("all") @pytest.mark.timeout(10) def test_barplot_multiple(): """Test with multiple data points per bar and calculated ranges.""" import matplotlib.pyplot as plt + rng = np.random.RandomState(0) tmp = (rng.randn(20) + np.arange(20)).reshape((5, 4)) # 2-dim _, axs = plt.subplots(1, 4, sharey=False) - ea.barplot(tmp, lines=True, err_bars='sd', ax=axs[0], smart_defaults=False) - ea.barplot(tmp, lines=True, err_bars='ci', ax=axs[1], axis=0) - ea.barplot(tmp, lines=True, err_bars='se', ax=axs[2], ylim=(0, 30)) - ea.barplot(tmp, lines=True, err_bars='se', ax=axs[3], - groups=[[0, 1, 2], [3, 4]], bracket_group_lines=True, - brackets=[(0, 1), (1, 2), (3, 4), ([0, 1, 2], [3, 4])], - bracket_text=['foo', 'bar', 'baz', 'snafu']) - extns = ['pdf'] # jpg, tif not supported; 'png', 'raw', 'svg' not tested + ea.barplot(tmp, lines=True, err_bars="sd", ax=axs[0], smart_defaults=False) + ea.barplot(tmp, lines=True, err_bars="ci", ax=axs[1], axis=0) + ea.barplot(tmp, lines=True, err_bars="se", ax=axs[2], ylim=(0, 30)) + ea.barplot( + tmp, + lines=True, + err_bars="se", + ax=axs[3], + groups=[[0, 1, 2], [3, 4]], + bracket_group_lines=True, + brackets=[(0, 1), (1, 2), (3, 4), ([0, 1, 2], [3, 4])], + bracket_text=["foo", "bar", "baz", "snafu"], + ) + extns = ["pdf"] # jpg, tif not supported; 'png', 'raw', 'svg' not tested for ext in extns: - fname = op.join(temp_dir, 'temp.' + ext) + fname = op.join(temp_dir, "temp." + ext) with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') - ea.barplot(tmp, groups=[[0, 1, 2], [3]], err_bars='sd', axis=0, - fname=fname) + warnings.simplefilter("always") + ea.barplot(tmp, groups=[[0, 1, 2], [3]], err_bars="sd", axis=0, fname=fname) plt.close() _check_warnings(w) - plt.close('all') + plt.close("all") def test_plot_screen(): @@ -126,10 +168,10 @@ def test_plot_screen(): def test_format_pval(): """Test p-value formatting.""" foo = ea.format_pval(1e-10, latex=False) - bar = ea.format_pval(1e-10, scheme='ross') + bar = ea.format_pval(1e-10, scheme="ross") baz = ea.format_pval([0.2, 0.02]) - qux = ea.format_pval(0.002, scheme='stars') - assert_equal(foo, 'p < 10^-9') - assert_equal(bar, '$p < 10^{{-9}}$') - assert_equal(baz[0], '$n.s.$') - assert_equal(qux, '${*}{*}$') + qux = ea.format_pval(0.002, scheme="stars") + assert_equal(foo, "p < 10^-9") + assert_equal(bar, "$p < 10^{{-9}}$") + assert_equal(baz[0], "$n.s.$") + assert_equal(qux, "${*}{*}$") diff --git a/expyfun/codeblocks/__init__.py b/expyfun/codeblocks/__init__.py index 6f3798e6..342c3d60 100644 --- a/expyfun/codeblocks/__init__.py +++ b/expyfun/codeblocks/__init__.py @@ -8,5 +8,4 @@ # Copyright (c) 2014, LABSN. # Distributed under the (new) BSD License. See LICENSE.txt for more info. -from ._pupillometry import (find_pupil_dynamic_range, - find_pupil_tone_impulse_response) +from ._pupillometry import find_pupil_dynamic_range, find_pupil_tone_impulse_response diff --git a/expyfun/codeblocks/_pupillometry.py b/expyfun/codeblocks/_pupillometry.py index 021094ec..fe2946a6 100644 --- a/expyfun/codeblocks/_pupillometry.py +++ b/expyfun/codeblocks/_pupillometry.py @@ -1,12 +1,11 @@ -"""Analysis functions (mostly for psychophysics data). -""" +"""Analysis functions (mostly for psychophysics data).""" import numpy as np -from ..visual import FixationDot -from ..analyze import sigmoid from .._utils import logger, verbose_dec +from ..analyze import sigmoid from ..stimuli import window_edges +from ..visual import FixationDot def _check_pyeparse(): @@ -20,12 +19,13 @@ def _check_pyeparse(): def _load_raw(el, fname): """Helper to load some pupil data""" import pyeparse + fname = el.transfer_remote_file(fname) # Load and parse data - logger.info('Pupillometry: Parsing local file "{0}"'.format(fname)) + logger.info(f'Pupillometry: Parsing local file "{fname}"') raw = pyeparse.RawEDF(fname) raw.remove_blink_artifacts() - events = raw.find_events('SYNCTIME', 1) + events = raw.find_events("SYNCTIME", 1) return raw, events @@ -61,14 +61,17 @@ def find_pupil_dynamic_range(ec, el, prompt=True, verbose=None): """ _check_pyeparse() import pyeparse + if el.recording: el.stop() el.calibrate() if prompt: - ec.screen_prompt('We will now determine the dynamic ' - 'range of your pupil.\n\n' - 'Press a button to continue.') - levels = np.concatenate(([0.], 2 ** np.arange(8) / 255.)) + ec.screen_prompt( + "We will now determine the dynamic " + "range of your pupil.\n\n" + "Press a button to continue." + ) + levels = np.concatenate(([0.0], 2 ** np.arange(8) / 255.0)) fixs = levels + 0.2 n_rep = 2 # inter-rep interval (allow system to reset) @@ -76,15 +79,14 @@ def find_pupil_dynamic_range(ec, el, prompt=True, verbose=None): # amount of time between levels settle_time = 3.0 if not el.dummy_mode else 0.3 fix = FixationDot(ec) - fix.set_colors([fixs[0] * np.ones(3), 'k']) - ec.set_background_color('k') + fix.set_colors([fixs[0] * np.ones(3), "k"]) + ec.set_background_color("k") fix.draw() ec.flip() for ri in range(n_rep): ec.wait_secs(iri) for ii, (lev, fc) in enumerate(zip(levels, fixs)): - ec.identify_trial(ec_id='FPDR_%02i' % (ii + 1), - el_id=[ii + 1], ttl_id=()) + ec.identify_trial(ec_id="FPDR_%02i" % (ii + 1), el_id=[ii + 1], ttl_id=()) bgcolor = np.ones(3) * lev fcolor = np.ones(3) * fc ec.set_background_color(bgcolor) @@ -95,13 +97,12 @@ def find_pupil_dynamic_range(ec, el, prompt=True, verbose=None): ec.check_force_quit() ec.stop() ec.trial_ok() - ec.set_background_color('k') - fix.set_colors([fixs[0] * np.ones(3), 'k']) + ec.set_background_color("k") + fix.set_colors([fixs[0] * np.ones(3), "k"]) fix.draw() ec.flip() el.stop() # stop the recording - ec.screen_prompt('Processing data, please wait...', max_wait=0, - clear_after=False) + ec.screen_prompt("Processing data, please wait...", max_wait=0, clear_after=False) # now we need to parse the data if el.dummy_mode: @@ -115,17 +116,18 @@ def find_pupil_dynamic_range(ec, el, prompt=True, verbose=None): epochs = pyeparse.Epochs(raw, events, 1, -0.5, settle_time) assert len(epochs) == len(levels) * n_rep idx = epochs.n_times // 2 - resp = np.median(epochs.get_data('ps')[:, idx:], 1) + resp = np.median(epochs.get_data("ps")[:, idx:], 1) bgcolor = np.mean(resp.reshape((n_rep, len(levels))), 0) idx = np.argmin(np.diff(bgcolor)) + 1 bgcolor = levels[idx] * np.ones(3) fcolor = fixs[idx] * np.ones(3) - logger.info('Pupillometry: optimal background color {0}'.format(bgcolor)) + logger.info(f"Pupillometry: optimal background color {bgcolor}") return bgcolor, fcolor, np.tile(levels, n_rep), resp -def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, - verbose=None, targ_is_fm=True): +def find_pupil_tone_impulse_response( + ec, el, bgcolor, fcolor, prompt=True, verbose=None, targ_is_fm=True +): """Find pupil impulse response using responses to tones Parameters @@ -162,6 +164,7 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, """ _check_pyeparse() import pyeparse + if el.recording: el.stop() @@ -175,14 +178,14 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, delay_range = np.array(delay_range) targ_prop = 0.25 stim_dur = 100e-3 - f0 = 1000. # Hz + f0 = 1000.0 # Hz rng = np.random.RandomState(0) isis = np.linspace(*delay_range, num=n_stimuli) n_targs = int(targ_prop * n_stimuli) targs = np.zeros(n_stimuli, bool) targs[np.linspace(0, n_stimuli - 1, n_targs + 2)[1:-1].astype(int)] = True - while(True): # ensure we randomize but don't start with a target + while True: # ensure we randomize but don't start with a target idx = rng.permutation(np.arange(n_stimuli)) isis = isis[idx] targs = targs[idx] @@ -196,8 +199,9 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, n_samp = int(fs * stim_dur) t = np.arange(n_samp).astype(float) / fs steady = np.sin(2 * np.pi * f0 * t) - wobble = np.sin(np.cumsum(f0 + 100 * np.sin(2 * np.pi * (1 / stim_dur) * t) - ) / fs * 2 * np.pi) + wobble = np.sin( + np.cumsum(f0 + 100 * np.sin(2 * np.pi * (1 / stim_dur) * t)) / fs * 2 * np.pi + ) std_stim, dev_stim = (steady, wobble) if targ_is_fm else (wobble, steady) std_stim = window_edges(std_stim * ec._stim_rms * np.sqrt(2), fs) dev_stim = window_edges(dev_stim * ec._stim_rms * np.sqrt(2), fs) @@ -207,17 +211,23 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, # ec.stop() ec.set_background_color(bgcolor) - targstr, tonestr = ('wobble', 'beep') if targ_is_fm else ('beep', 'wobble') - instr = ('Remember to press the button as quickly as possible following ' - 'each "{}" sound.\n\nPress the response button to ' - 'continue.'.format(targstr)) + targstr, tonestr = ("wobble", "beep") if targ_is_fm else ("beep", "wobble") + instr = ( + "Remember to press the button as quickly as possible following " + f'each "{targstr}" sound.\n\nPress the response button to ' + "continue." + ) if prompt: - notes = [('We will now determine the response of your pupil to sound ' - 'changes.\n\nYour job is to press the response button ' - 'as quickly as possible when you hear a "{1}" instead ' - 'of a "{0}".\n\nPress a button to hear the "{0}".' - ''.format(tonestr, targstr)), - ('Now press a button to hear the "{}".'.format(targstr))] + notes = [ + ( + "We will now determine the response of your pupil to sound " + "changes.\n\nYour job is to press the response button " + f'as quickly as possible when you hear a "{targstr}" instead ' + f'of a "{tonestr}".\n\nPress a button to hear the "{tonestr}".' + "" + ), + (f'Now press a button to hear the "{targstr}".'), + ] for text, stim in zip(notes, (std_stim, dev_stim)): ec.screen_prompt(text) ec.load_buffer(stim) @@ -235,10 +245,12 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, if ii in cal_stim: if ii != 0: el.stop() - perc = round((100. * ii) / n_stimuli) - ec.screen_prompt('Great work! You are {0}% done.\n\nFeel ' - 'free to take a break, then press the ' - 'button to continue.'.format(perc)) + perc = round((100.0 * ii) / n_stimuli) + ec.screen_prompt( + f"Great work! You are {perc}% done.\n\nFeel " + "free to take a break, then press the " + "button to continue." + ) el.calibrate() ec.screen_prompt(instr) # let's put the initial color up to allow the system to settle @@ -247,15 +259,15 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, ec.wait_secs(10.0) # let the pupil settle fix.draw() ec.load_buffer(dev_stim if targ else std_stim) - ec.identify_trial(ec_id='TONE_{0}'.format(int(targ)), - el_id=[int(targ)], ttl_id=[int(targ)]) + ec.identify_trial( + ec_id=f"TONE_{int(targ)}", el_id=[int(targ)], ttl_id=[int(targ)] + ) flip_times.append(ec.start_stimulus()) presses.append(ec.wait_for_presses(isi)) ec.stop() ec.trial_ok() el.stop() # stop the recording - ec.screen_prompt('Processing data, please wait...', max_wait=0, - clear_after=False) + ec.screen_prompt("Processing data, please wait...", max_wait=0, clear_after=False) flip_times = np.array(flip_times) tmin = -0.5 @@ -275,8 +287,7 @@ def find_pupil_tone_impulse_response(ec, el, bgcolor, fcolor, prompt=True, raws.append(raw) events.append(event) assert sum(len(event) for event in events) == n_stimuli - epochs = pyeparse.Epochs(raws, events, 1, - tmin=tmin, tmax=delay_range[0]) + epochs = pyeparse.Epochs(raws, events, 1, tmin=tmin, tmax=delay_range[0]) response = epochs.pupil_zscores() assert response.shape[0] == n_stimuli std_err = np.std(response[~targs], axis=0) diff --git a/expyfun/conftest.py b/expyfun/conftest.py index 6fda1829..f2a38734 100644 --- a/expyfun/conftest.py +++ b/expyfun/conftest.py @@ -1,12 +1,13 @@ -# -*- coding: utf-8 -*- # Author: Eric Larson # # License: BSD (3-clause) import os + import pytest -from expyfun._utils import _get_display + from expyfun._sound_controllers import _AUTO_BACKENDS +from expyfun._utils import _get_display # Unknown pytest problem with readline<->deprecated decorator try: @@ -15,40 +16,43 @@ pass -@pytest.mark.timeout(0) # importing plt will build font cache, slow on Azure -@pytest.fixture(scope='session') +@pytest.fixture(scope="session") def matplotlib_config(): """Configure matplotlib for viz tests.""" import matplotlib - matplotlib.use('agg') # don't pop up windows + + matplotlib.use("agg") # don't pop up windows import matplotlib.pyplot as plt - assert plt.get_backend() == 'agg' + + assert plt.get_backend() == "agg" # overwrite some params that can horribly slow down tests that # users might have changed locally (but should not otherwise affect # functionality) plt.ioff() - plt.rcParams['figure.dpi'] = 100 - os.environ['_EXPYFUN_WIN_INVISIBLE'] = 'true' + plt.rcParams["figure.dpi"] = 100 + os.environ["_EXPYFUN_WIN_INVISIBLE"] = "true" -@pytest.fixture(scope='function') +@pytest.fixture(scope="function") def hide_window(): """Hide the expyfun window.""" try: _get_display() except Exception as exp: - pytest.skip('Windowing unavailable (%s)' % exp) + pytest.skip("Windowing unavailable (%s)" % exp) -_SOUND_CARD_ACS = tuple({'TYPE': 'sound_card', 'SOUND_CARD_BACKEND': backend} - for backend in _AUTO_BACKENDS) +_SOUND_CARD_ACS = tuple( + {"TYPE": "sound_card", "SOUND_CARD_BACKEND": backend} for backend in _AUTO_BACKENDS +) for val in _SOUND_CARD_ACS: - if val['SOUND_CARD_BACKEND'] == 'pyglet': - val.update(SOUND_CARD_API=None, SOUND_CARD_NAME=None, - SOUND_CARD_FIXED_DELAY=None) + if val["SOUND_CARD_BACKEND"] == "pyglet": + val.update( + SOUND_CARD_API=None, SOUND_CARD_NAME=None, SOUND_CARD_FIXED_DELAY=None + ) -@pytest.fixture(scope="module", params=('tdt',) + _SOUND_CARD_ACS) +@pytest.fixture(scope="module", params=("tdt",) + _SOUND_CARD_ACS) def ac(request): """Get the backend name.""" yield request.param diff --git a/expyfun/io/__init__.py b/expyfun/io/__init__.py index c655b909..ae3fe8e5 100644 --- a/expyfun/io/__init__.py +++ b/expyfun/io/__init__.py @@ -4,12 +4,10 @@ File reading and writing routines. """ -# -*- coding: utf-8 -*- +from h5io import read_hdf5 as _read_hdf5, write_hdf5 as _write_hdf5 + from ._wav import read_wav, write_wav -from .._externals._h5io import (read_hdf5 as _read_hdf5, - write_hdf5 as _write_hdf5) -from ._parse import (read_tab, reconstruct_tracker, - reconstruct_dealer, read_tab_raw) +from ._parse import read_tab, reconstruct_tracker, reconstruct_dealer, read_tab_raw def read_hdf5(fname): @@ -29,7 +27,7 @@ def read_hdf5(fname): -------- write_hdf5 """ - return _read_hdf5(fname, title='expyfun') + return _read_hdf5(fname, title="expyfun") def write_hdf5(fname, data, overwrite=False, compression=4): @@ -54,4 +52,4 @@ def write_hdf5(fname, data, overwrite=False, compression=4): -------- read_hdf5 """ - return _write_hdf5(fname, data, overwrite, compression, title='expyfun') + return _write_hdf5(fname, data, overwrite, compression, title="expyfun") diff --git a/expyfun/io/_parse.py b/expyfun/io/_parse.py index 2729be71..ea8a0a8a 100644 --- a/expyfun/io/_parse.py +++ b/expyfun/io/_parse.py @@ -1,11 +1,9 @@ -# -*- coding: utf-8 -*- -"""File parsing functions -""" +"""File parsing functions""" import ast -from collections import OrderedDict import csv import json +from collections import OrderedDict import numpy as np @@ -33,23 +31,21 @@ def read_tab_raw(fname, return_params=False): -------- read_tab """ - with open(fname, 'r') as f: - csvr = csv.reader(f, delimiter='\t') + with open(fname) as f: + csvr = csv.reader(f, delimiter="\t") lines = [c for c in csvr] # first two lines are headers - assert len(lines[0]) == 1 and lines[0][0].startswith('# ') + assert len(lines[0]) == 1 and lines[0][0].startswith("# ") if return_params: line = lines[0][0][2:] try: - params = json.loads( - line, object_pairs_hook=OrderedDict) + params = json.loads(line, object_pairs_hook=OrderedDict) except json.decoder.JSONDecodeError: # old format - params = json.loads( - line.replace("'", '"'), object_pairs_hook=OrderedDict) + params = json.loads(line.replace("'", '"'), object_pairs_hook=OrderedDict) else: params = None - assert lines[1] == ['timestamp', 'event', 'value'] + assert lines[1] == ["timestamp", "event", "value"] lines = lines[2:] times = [float(line[0]) for line in lines] @@ -59,8 +55,13 @@ def read_tab_raw(fname, return_params=False): return (data, params) if return_params else data -def read_tab(fname, group_start='trial_id', group_end='trial_ok', - return_params=False, allow_last_missing=False): +def read_tab( + fname, + group_start="trial_id", + group_end="trial_ok", + return_params=False, + allow_last_missing=False, +): """Read .tab file from expyfun output and segment into trials. Parameters @@ -100,28 +101,24 @@ def read_tab(fname, group_start='trial_id', group_end='trial_ok', header = list(set([line[1] for line in lines])) header.sort() if group_start not in header: - raise ValueError('group_start "{0}" not in header: {1}' - ''.format(group_start, header)) + raise ValueError(f'group_start "{group_start}" not in header: {header}' "") if group_end == group_start: - raise ValueError('group_start cannot equal group_end, use ' - 'group_end=None') + raise ValueError("group_start cannot equal group_end, use " "group_end=None") header = [header.pop(header.index(group_start))] + header b1s = np.where([line[1] == group_start for line in lines])[0] if group_end is None: b2s = np.concatenate((b1s[1:], [len(lines)])) else: # group_end is not None if group_end not in header: - raise ValueError('group_end "{0}" not in header ({1})' - ''.format(group_end, header)) + raise ValueError(f'group_end "{group_end}" not in header ({header})' "") header.append(header.pop(header.index(group_end))) b2s = np.where([line[1] == group_end for line in lines])[0] if len(b1s) == len(b2s) + 1 and allow_last_missing: # old expyfun would sometimes not write the last trial_ok :( b2s = np.concatenate([b2s, [len(lines)]]) - lines.append((lines[-1][0] + 0.1, group_end, 'None')) + lines.append((lines[-1][0] + 0.1, group_end, "None")) if len(b1s) != len(b2s) or not np.all(b1s < b2s): - raise RuntimeError('bad bounds in {0}:\n{1}\n{2}' - .format(fname, b1s, b2s)) + raise RuntimeError(f"bad bounds in {fname}:\n{b1s}\n{b2s}") data = [] for b1, b2 in zip(b1s, b2s): assert lines[b1][1] == group_start # prevent stupidity @@ -155,40 +152,42 @@ def reconstruct_tracker(fname): the generation of the file.) If only one tracker is found in the file, it will still be stored in a list and will be accessible as ``tr[0]``. """ - from ..stimuli import TrackerUD, TrackerBinom, TrackerMHW + from ..stimuli import TrackerBinom, TrackerMHW, TrackerUD + # read in raw data raw = read_tab_raw(fname) # find tracker_identify and make list of IDs - tracker_idx = np.where([r[1] == 'tracker_identify' for r in raw])[0] + tracker_idx = np.where([r[1] == "tracker_identify" for r in raw])[0] if len(tracker_idx) == 0: - raise ValueError('There are no Trackers in this file.') + raise ValueError("There are no Trackers in this file.") tr = [] used_dict_idx = [] # they can have repeat names! used_stop_idx = [] for ii in tracker_idx: - tracker_id = ast.literal_eval(raw[ii][2])['tracker_id'] - tracker_type = ast.literal_eval(raw[ii][2])['tracker_type'] + tracker_id = ast.literal_eval(raw[ii][2])["tracker_id"] + tracker_type = ast.literal_eval(raw[ii][2])["tracker_type"] # find tracker_ID_init lines and get dict - init_str = 'tracker_' + str(tracker_id) + '_init' + init_str = "tracker_" + str(tracker_id) + "_init" tracker_dict_idx = np.where([r[1] == init_str for r in raw])[0] tracker_dict_idx = np.setdiff1d(tracker_dict_idx, used_dict_idx) tracker_dict_idx = tracker_dict_idx[0] used_dict_idx.append(tracker_dict_idx) tracker_dict = json.loads(raw[tracker_dict_idx][2]) - td = dict(TrackerUD=TrackerUD, TrackerBinom=TrackerBinom, - TrackerMHW=TrackerMHW) + td = dict(TrackerUD=TrackerUD, TrackerBinom=TrackerBinom, TrackerMHW=TrackerMHW) tr.append(td[tracker_type](**tracker_dict)) tr[-1]._tracker_id = tracker_id # make sure tracker has original ID - stop_str = 'tracker_' + str(tracker_id) + '_stop' + stop_str = "tracker_" + str(tracker_id) + "_stop" tracker_stop_idx = np.where([r[1] == stop_str for r in raw])[0] tracker_stop_idx = np.setdiff1d(tracker_stop_idx, used_stop_idx) if len(tracker_stop_idx) == 0: - raise ValueError('Tracker {} has not stopped. All Trackers ' - 'must be stopped.'.format(tracker_id)) + raise ValueError( + f"Tracker {tracker_id} has not stopped. All Trackers " + "must be stopped." + ) tracker_stop_idx = tracker_stop_idx[0] used_stop_idx.append(tracker_stop_idx) - responses = json.loads(raw[tracker_stop_idx][2])['responses'] + responses = json.loads(raw[tracker_stop_idx][2])["responses"] # feed in responses from tracker_ID_stop for r in responses: tr[-1].respond(r) @@ -214,20 +213,20 @@ def reconstruct_dealer(fname): still be stored in a list and will be assessible as ``td[0]``. """ from ..stimuli import TrackerDealer + raw = read_tab_raw(fname) # find info on dealer - dealer_idx = np.where([r[1] == 'dealer_identify' for r in raw])[0] + dealer_idx = np.where([r[1] == "dealer_identify" for r in raw])[0] if len(dealer_idx) == 0: - raise ValueError('There are no TrackerDealers in this file.') + raise ValueError("There are no TrackerDealers in this file.") dealer = [] for ii in dealer_idx: - dealer_id = ast.literal_eval(raw[ii][2])['dealer_id'] - dealer_init_str = 'dealer_' + str(dealer_id) + '_init' - dealer_dict_idx = np.where([r[1] == dealer_init_str - for r in raw])[0][0] + dealer_id = ast.literal_eval(raw[ii][2])["dealer_id"] + dealer_init_str = "dealer_" + str(dealer_id) + "_init" + dealer_dict_idx = np.where([r[1] == dealer_init_str for r in raw])[0][0] dealer_dict = ast.literal_eval(raw[dealer_dict_idx][2]) - dealer_trackers = dealer_dict['trackers'] + dealer_trackers = dealer_dict["trackers"] # match up tracker objects to id trackers = reconstruct_tracker(fname) @@ -237,22 +236,24 @@ def reconstruct_dealer(fname): tr_objects.append(trackers[idx]) # make the dealer object - max_lag = dealer_dict['max_lag'] - pace_rule = dealer_dict['pace_rule'] + max_lag = dealer_dict["max_lag"] + pace_rule = dealer_dict["pace_rule"] dealer.append(TrackerDealer(None, tr_objects, max_lag, pace_rule)) # force input responses/log data - dealer_stop_str = 'dealer_' + str(dealer_id) + '_stop' + dealer_stop_str = "dealer_" + str(dealer_id) + "_stop" dealer_stop_idx = np.where([r[1] == dealer_stop_str for r in raw])[0] if len(dealer_stop_idx) == 0: - raise ValueError('TrackerDealer {} has not stopped. All dealers ' - 'must be stopped.'.format(dealer_id)) + raise ValueError( + f"TrackerDealer {dealer_id} has not stopped. All dealers " + "must be stopped." + ) dealer_stop_log = json.loads(raw[dealer_stop_idx[0]][2]) - shape = tuple(dealer_dict['shape']) - log_response_history = dealer_stop_log['response_history'] - log_x_history = dealer_stop_log['x_history'] - log_tracker_history = dealer_stop_log['tracker_history'] + shape = tuple(dealer_dict["shape"]) + log_response_history = dealer_stop_log["response_history"] + log_x_history = dealer_stop_log["x_history"] + log_tracker_history = dealer_stop_log["tracker_history"] dealer[-1]._shape = shape dealer[-1]._trackers.shape = shape diff --git a/expyfun/io/_wav.py b/expyfun/io/_wav.py index e888a1e1..3daf2ef1 100644 --- a/expyfun/io/_wav.py +++ b/expyfun/io/_wav.py @@ -1,13 +1,12 @@ -# -*- coding: utf-8 -*- -"""WAV file IO functions -""" +"""WAV file IO functions""" + +import warnings +from os import path as op import numpy as np from scipy.io import wavfile -from os import path as op -import warnings -from .._utils import verbose_dec, logger, _has_scipy_version +from .._utils import _has_scipy_version, logger, verbose_dec @verbose_dec @@ -36,7 +35,7 @@ def read_wav(fname, verbose=None): orig_dtype = data.dtype max_val = _get_dtype_norm(orig_dtype) data = np.ascontiguousarray(data.astype(np.float64) / max_val) - _print_wav_info('Read', data, orig_dtype) + _print_wav_info("Read", data, orig_dtype) return data, fs @@ -61,26 +60,30 @@ def write_wav(fname, data, fs, dtype=np.int16, overwrite=False, verbose=None): If not None, override default verbose level. """ if not overwrite and op.isfile(fname): - raise IOError('File {} exists, overwrite=True must be ' - 'used'.format(op.basename(fname))) - if not np.dtype(type(fs)).kind == 'i': + raise OSError( + f"File {op.basename(fname)} exists, overwrite=True must be " "used" + ) + if not np.dtype(type(fs)).kind == "i": fs = int(fs) - warnings.warn('Warning: sampling rate is being cast to integer and ' - 'may be truncated.') + warnings.warn( + "Warning: sampling rate is being cast to integer and " "may be truncated." + ) data = np.atleast_2d(data) - if np.dtype(dtype).kind not in ['i', 'f']: - raise TypeError('dtype must be integer or float') - if np.dtype(dtype).kind == 'f': - if not _has_scipy_version('0.13'): - raise RuntimeError('cannot write float datatype unless ' - 'scipy >= 0.13 is installed') + if np.dtype(dtype).kind not in ["i", "f"]: + raise TypeError("dtype must be integer or float") + if np.dtype(dtype).kind == "f": + if not _has_scipy_version("0.13"): + raise RuntimeError( + "cannot write float datatype unless " "scipy >= 0.13 is installed" + ) elif np.dtype(dtype).itemsize == 8: - raise RuntimeError('Writing 64-bit integers is not supported') - if np.dtype(data.dtype).kind == 'f': - if np.dtype(dtype).kind == 'i' and np.max(np.abs(data)) > 1.: - raise ValueError('Data must be between -1 and +1 when saving ' - 'with an integer dtype') - _print_wav_info('Writing', data, dtype) + raise RuntimeError("Writing 64-bit integers is not supported") + if np.dtype(data.dtype).kind == "f": + if np.dtype(dtype).kind == "i" and np.max(np.abs(data)) > 1.0: + raise ValueError( + "Data must be between -1 and +1 when saving " "with an integer dtype" + ) + _print_wav_info("Writing", data, dtype) max_val = _get_dtype_norm(dtype) data = (data * max_val).astype(dtype) wavfile.write(fname, fs, data.T) @@ -88,15 +91,16 @@ def write_wav(fname, data, fs, dtype=np.int16, overwrite=False, verbose=None): def _print_wav_info(pre, data, dtype): """Helper to print WAV info""" - logger.info('{0} WAV file with {1} channel{3} and {2} samples ' - '(format {4})'.format(pre, data.shape[0], data.shape[1], - 's' if data.shape[0] != 1 else '', - dtype)) + logger.info( + "{0} WAV file with {1} channel{3} and {2} samples " "(format {4})".format( + pre, data.shape[0], data.shape[1], "s" if data.shape[0] != 1 else "", dtype + ) + ) def _get_dtype_norm(dtype): """Helper to get normalization factor for a given datatype""" - if np.dtype(dtype).kind == 'i': + if np.dtype(dtype).kind == "i": info = np.iinfo(dtype) maxval = min(-info.min, info.max) else: # == 'f' diff --git a/expyfun/io/tests/test_parse.py b/expyfun/io/tests/test_parse.py index 6e9d519f..17f33ecd 100644 --- a/expyfun/io/tests/test_parse.py +++ b/expyfun/io/tests/test_parse.py @@ -3,71 +3,78 @@ from numpy.testing import assert_equal from expyfun import ExperimentController, __version__ -from expyfun.io import read_tab, reconstruct_tracker, reconstruct_dealer from expyfun._utils import _TempDir -from expyfun.stimuli import TrackerUD, TrackerBinom, TrackerDealer +from expyfun.io import read_tab, reconstruct_dealer, reconstruct_tracker +from expyfun.stimuli import TrackerBinom, TrackerDealer, TrackerUD temp_dir = _TempDir() -std_args = ['test'] # experiment name -std_kwargs = dict(output_dir=temp_dir, full_screen=False, window_size=(1, 1), - participant='foo', session='01', stim_db=0.0, noise_db=0.0, - verbose=True, version='dev') - - -@pytest.mark.timeout(10) +std_args = ["test"] # experiment name +std_kwargs = dict( + output_dir=temp_dir, + full_screen=False, + window_size=(1, 1), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + verbose=True, + version="dev", +) + + +@pytest.mark.timeout(20) def test_parse_basic(hide_window, tmpdir): """Test .tab parsing.""" with ExperimentController(*std_args, **std_kwargs) as ec: - ec.identify_trial(ec_id='one', ttl_id=[0]) + ec.identify_trial(ec_id="one", ttl_id=[0]) ec.start_stimulus() - ec.write_data_line('misc', 'trial one') + ec.write_data_line("misc", "trial one") ec.stop() ec.trial_ok() - ec.write_data_line('misc', 'between trials') - ec.identify_trial(ec_id='two', ttl_id=[1]) + ec.write_data_line("misc", "between trials") + ec.identify_trial(ec_id="two", ttl_id=[1]) ec.start_stimulus() - ec.write_data_line('misc', 'trial two') + ec.write_data_line("misc", "trial two") ec.stop() ec.trial_ok() - ec.write_data_line('misc', 'end of experiment') + ec.write_data_line("misc", "end of experiment") - pytest.raises(ValueError, read_tab, ec.data_fname, group_start='foo') - pytest.raises(ValueError, read_tab, ec.data_fname, group_end='foo') - pytest.raises(ValueError, read_tab, ec.data_fname, group_end='trial_id') - pytest.raises(RuntimeError, read_tab, ec.data_fname, group_end='misc') + pytest.raises(ValueError, read_tab, ec.data_fname, group_start="foo") + pytest.raises(ValueError, read_tab, ec.data_fname, group_end="foo") + pytest.raises(ValueError, read_tab, ec.data_fname, group_end="trial_id") + pytest.raises(RuntimeError, read_tab, ec.data_fname, group_end="misc") data = read_tab(ec.data_fname) keys = list(data[0].keys()) assert_equal(len(keys), 6) - for key in ['trial_id', 'flip', 'play', 'stop', 'misc', 'trial_ok']: + for key in ["trial_id", "flip", "play", "stop", "misc", "trial_ok"]: assert key in keys - assert_equal(len(data[0]['misc']), 1) - assert_equal(len(data[1]['misc']), 1) + assert_equal(len(data[0]["misc"]), 1) + assert_equal(len(data[1]["misc"]), 1) data, params = read_tab(ec.data_fname, group_end=None, return_params=True) - assert_equal(len(data[0]['misc']), 2) # includes between-trials stuff - assert_equal(len(data[1]['misc']), 2) - assert_equal(params['version'], 'dev') - assert_equal(params['version_used'], __version__) - assert (params['file'].endswith('test_parse.py')) + assert_equal(len(data[0]["misc"]), 2) # includes between-trials stuff + assert_equal(len(data[1]["misc"]), 2) + assert_equal(params["version"], "dev") + assert_equal(params["version_used"], __version__) + assert params["file"].endswith("test_parse.py") # handle old files where the last trial_ok was missing - bad_fname = str(tmpdir.join('bad.tab')) - with open(ec.data_fname, 'r') as fid: + bad_fname = str(tmpdir.join("bad.tab")) + with open(ec.data_fname) as fid: lines = fid.readlines() - assert 'trial_ok' in lines[-3] - with open(bad_fname, 'w') as fid: + assert "trial_ok" in lines[-3] + with open(bad_fname, "w") as fid: # we used to write JSON badly fid.write(lines[0].replace('"', "'")) # and then sometimes missed the last trial_ok for line in lines[1:-3]: fid.write(line) - with pytest.raises(RuntimeError, match='bad bounds'): + with pytest.raises(RuntimeError, match="bad bounds"): read_tab(bad_fname) data, params = read_tab(ec.data_fname, return_params=True) - data_2, params_2 = read_tab( - bad_fname, return_params=True, allow_last_missing=True) + data_2, params_2 = read_tab(bad_fname, return_params=True, allow_last_missing=True) assert params == params_2 - t = data[-1].pop('trial_ok') - t_2 = data_2[-1].pop('trial_ok') + t = data[-1].pop("trial_ok") + t_2 = data_2[-1].pop("trial_ok") assert t != t_2 assert data_2 == data @@ -81,24 +88,24 @@ def test_reconstruct(hide_window): tr.respond(np.random.rand() < tr.x_current) tracker = reconstruct_tracker(ec.data_fname)[0] - assert (tracker.stopped) + assert tracker.stopped tracker.x_current # test with one TrackerBinom with ExperimentController(*std_args, **std_kwargs) as ec: - tr = TrackerBinom(ec, .05, .5, 10) + tr = TrackerBinom(ec, 0.05, 0.5, 10) while not tr.stopped: tr.respond(True) tracker = reconstruct_tracker(ec.data_fname)[0] - assert (tracker.stopped) + assert tracker.stopped tracker.x_current # tracker not stopped with ExperimentController(*std_args, **std_kwargs) as ec: tr = TrackerUD(ec, 1, 1, 3, 1, 5, np.inf, 3) tr.respond(np.random.rand() < tr.x_current) - assert (not tr.stopped) + assert not tr.stopped pytest.raises(ValueError, reconstruct_tracker, ec.data_fname) # test with dealer @@ -110,20 +117,20 @@ def test_reconstruct(hide_window): td.respond(np.random.rand() < x_current) dealer = reconstruct_dealer(ec.data_fname)[0] - assert (all(td._x_history == dealer._x_history)) - assert (all(td._tracker_history == dealer._tracker_history)) - assert (all(td._response_history == dealer._response_history)) - assert (td.shape == dealer.shape) - assert (td.trackers.shape == dealer.trackers.shape) + assert all(td._x_history == dealer._x_history) + assert all(td._tracker_history == dealer._tracker_history) + assert all(td._response_history == dealer._response_history) + assert td.shape == dealer.shape + assert td.trackers.shape == dealer.trackers.shape # no tracker/dealer in file with ExperimentController(*std_args, **std_kwargs) as ec: - ec.identify_trial(ec_id='one', ttl_id=[0]) + ec.identify_trial(ec_id="one", ttl_id=[0]) ec.start_stimulus() - ec.write_data_line('misc', 'trial one') + ec.write_data_line("misc", "trial one") ec.stop() ec.trial_ok() - ec.write_data_line('misc', 'end') + ec.write_data_line("misc", "end") pytest.raises(ValueError, reconstruct_tracker, ec.data_fname) pytest.raises(ValueError, reconstruct_dealer, ec.data_fname) diff --git a/expyfun/io/tests/test_wav.py b/expyfun/io/tests/test_wav.py index ad77f48b..46a901de 100644 --- a/expyfun/io/tests/test_wav.py +++ b/expyfun/io/tests/test_wav.py @@ -1,9 +1,8 @@ -# -*- coding: utf-8 -*- +from os import path as op + import numpy as np import pytest -from numpy.testing import (assert_array_almost_equal, assert_array_equal, - assert_equal) -from os import path as op +from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_equal from expyfun._utils import _has_scipy_version from expyfun.io import read_wav, write_wav @@ -11,7 +10,7 @@ def test_read_write_wav(tmpdir): """Test reading and writing WAV files.""" - fname = op.join(str(tmpdir), 'temp.wav') + fname = op.join(str(tmpdir), "temp.wav") data = np.r_[np.random.rand(1000), 1, -1] fs = 44100 @@ -25,12 +24,13 @@ def test_read_write_wav(tmpdir): pytest.raises(IOError, write_wav, fname, data, fs) # test forcing fs dtype to int - with pytest.warns(UserWarning, match='rate is being cast'): + with pytest.warns(UserWarning, match="rate is being cast"): write_wav(fname, data, float(fs), overwrite=True) # Use 64-bit int: not supported - pytest.raises(RuntimeError, write_wav, fname, data, fs, dtype=np.int64, - overwrite=True) + pytest.raises( + RuntimeError, write_wav, fname, data, fs, dtype=np.int64, overwrite=True + ) # Use 32-bit int: better write_wav(fname, data, fs, dtype=np.int32, overwrite=True) @@ -38,7 +38,7 @@ def test_read_write_wav(tmpdir): assert_equal(fs_read, fs) assert_array_almost_equal(data[np.newaxis, :], data_read, 7) - if _has_scipy_version('0.13'): + if _has_scipy_version("0.13"): # Use 32-bit float: better write_wav(fname, data, fs, dtype=np.float32, overwrite=True) data_read, fs_read = read_wav(fname) @@ -51,8 +51,9 @@ def test_read_write_wav(tmpdir): assert_equal(fs_read, fs) assert_array_equal(data[np.newaxis, :], data_read) else: - pytest.raises(RuntimeError, write_wav, fname, data, fs, - dtype=np.float32, overwrite=True) + pytest.raises( + RuntimeError, write_wav, fname, data, fs, dtype=np.float32, overwrite=True + ) # Now try multi-dimensional data data = np.tile(data[np.newaxis, :], (2, 1)) diff --git a/expyfun/stimuli/__init__.py b/expyfun/stimuli/__init__.py index 9531df4b..d71b6b7c 100644 --- a/expyfun/stimuli/__init__.py +++ b/expyfun/stimuli/__init__.py @@ -16,8 +16,13 @@ from ._tracker import TrackerUD, TrackerBinom, TrackerDealer, TrackerMHW from .._tdt_controller import get_tdt_rates from ._texture import texture_ERB -from ._crm import (crm_sentence, crm_response_menu, crm_prepare_corpus, - crm_info, CRMPreload) +from ._crm import ( + crm_sentence, + crm_response_menu, + crm_prepare_corpus, + crm_info, + CRMPreload, +) # for backward compat (not great to do this...) from ..io import read_wav, write_wav diff --git a/expyfun/stimuli/_crm.py b/expyfun/stimuli/_crm.py index 7d4ef047..33b7be21 100644 --- a/expyfun/stimuli/_crm.py +++ b/expyfun/stimuli/_crm.py @@ -1,60 +1,45 @@ -"""Functions for using the Coordinate Response Measure (CRM) corpus. -""" +"""Functions for using the Coordinate Response Measure (CRM) corpus.""" # Author: Ross Maddox # # License: BSD (3-clause) -from multiprocessing import cpu_count import os +from multiprocessing import cpu_count from os.path import join from zipfile import ZipFile import numpy as np -from ..io import read_wav, write_wav +from .. import visual as vis from .._parallel import parallel_func +from .._utils import _get_user_home_path, fetch_data_file +from ..io import read_wav, write_wav from ._stimuli import window_edges -from .. import visual as vis -from .._utils import fetch_data_file, _get_user_home_path _fs_binary = 40e3 # the sampling rate of the original corpus binaries _rms_binary = 0.099977227591239365 # the RMS of the original corpus binaries _rms_prepped = 0.01 # the RMS for preparation of the whole corpus at an fs -_sexes = { - 'male': 0, - 'female': 1, - 'm': 0, - 'f': 1, - 0: 0, - 1: 1} -_talker_nums = { - '0': 0, - '1': 1, - '2': 2, - '3': 3, - 0: 0, - 1: 1, - 2: 2, - 3: 3} +_sexes = {"male": 0, "female": 1, "m": 0, "f": 1, 0: 0, 1: 1} +_talker_nums = {"0": 0, "1": 1, "2": 2, "3": 3, 0: 0, 1: 1, 2: 2, 3: 3} _callsigns = { - 'charlie': 0, - 'ringo': 1, - 'laker': 2, - 'hopper': 3, - 'arrow': 4, - 'tiger': 5, - 'eagle': 6, - 'baron': 7, - 'c': 0, - 'r': 1, - 'l': 2, - 'h': 3, - 'a': 4, - 't': 5, - 'e': 6, - 'b': 7, + "charlie": 0, + "ringo": 1, + "laker": 2, + "hopper": 3, + "arrow": 4, + "tiger": 5, + "eagle": 6, + "baron": 7, + "c": 0, + "r": 1, + "l": 2, + "h": 3, + "a": 4, + "t": 5, + "e": 6, + "b": 7, 0: 0, 1: 1, 2: 2, @@ -62,37 +47,39 @@ 4: 4, 5: 5, 6: 6, - 7: 7} + 7: 7, +} _colors = { - 'blue': 0, - 'red': 1, - 'white': 2, - 'green': 3, - 'b': 0, - 'r': 1, - 'w': 2, - 'g': 3, + "blue": 0, + "red": 1, + "white": 2, + "green": 3, + "b": 0, + "r": 1, + "w": 2, + "g": 3, 0: 0, 1: 1, 2: 2, - 3: 3} + 3: 3, +} _numbers = { - 'one': 0, - 'two': 1, - 'three': 2, - 'four': 3, - 'five': 4, - 'six': 5, - 'seven': 6, - 'eight': 7, - '1': 0, - '2': 1, - '3': 2, - '4': 3, - '5': 4, - '6': 5, - '7': 6, - '8': 7, + "one": 0, + "two": 1, + "three": 2, + "four": 3, + "five": 4, + "six": 5, + "seven": 6, + "eight": 7, + "1": 0, + "2": 1, + "3": 2, + "4": 3, + "5": 4, + "6": 5, + "7": 6, + "8": 7, 0: 0, 1: 1, 2: 2, @@ -100,7 +87,8 @@ 4: 4, 5: 5, 6: 6, - 7: 7} + 7: 7, +} _n_sexes = 2 _n_talkers = 4 @@ -110,64 +98,72 @@ def _check(name, value): - if name.lower() == 'sex': + if name.lower() == "sex": param_dict = _sexes - elif name.lower() == 'talker_num': + elif name.lower() == "talker_num": param_dict = _talker_nums - elif name.lower() == 'callsign': + elif name.lower() == "callsign": param_dict = _callsigns - elif name.lower() == 'color': + elif name.lower() == "color": param_dict = _colors - elif name.lower() == 'number': + elif name.lower() == "number": param_dict = _numbers if isinstance(value, str): value = value.lower() - if value in param_dict.keys(): + if value in param_dict: return param_dict[value] else: - raise ValueError('{} is not a valid {}. Legal values are: {}' - .format(value, name, - sorted(k for k in param_dict.keys() - if isinstance(k, int)))) + raise ValueError( + f"{value} is not a valid {name}. Legal values are: " + f"{sorted(k for k in param_dict if isinstance(k, int))}" + ) def _get_talker_zip_file(sex, talker_num): talker_num_raw = _n_talkers * _sexes[sex] + _talker_nums[talker_num] - fn = fetch_data_file('crm/Talker%i.zip' % talker_num_raw) + fn = fetch_data_file("crm/Talker%i.zip" % talker_num_raw) return fn # Read a raw binary CRM file -def _read_binary(zip_file, callsign, color, number, - ramp_dur=0.01): +def _read_binary(zip_file, callsign, color, number, ramp_dur=0.01): talk_path = zip_file.filelist[0].orig_filename[:8] - raw = zip_file.read(talk_path + '/%02i%02i%02i.BIN' % ( - _callsigns[callsign], _colors[color], _numbers[number])) - x = np.frombuffer(raw, ' max_wait: - raise ValueError('min_wait must be <= max_wait') + raise ValueError("min_wait must be <= max_wait") start_time = ec.current_time mouse_cursor = ec.window._mouse_cursor cursor = ec.window.get_system_mouse_cursor(ec.window.CURSOR_HAND) colors = [c.lower() for c in colors] - units = 'norm' + units = "norm" vert = float(ec.window_size_pix[0]) / ec.window_size_pix[1] h_spacing = 0.1 v_spacing = h_spacing * vert @@ -395,57 +444,55 @@ def crm_response_menu(ec, colors=['blue', 'red', 'white', 'green'], colors_rgb = [[0, 0, 1], [1, 0, 0], [1, 1, 1], [0, 0.85, 0]] n_numbers = len(numbers) n_colors = len(colors) - h_start = -(n_numbers - 1) * h_spacing / 2. - v_start = (n_colors - 1) * v_spacing / 2. + h_start = -(n_numbers - 1) * h_spacing / 2.0 + v_start = (n_colors - 1) * v_spacing / 2.0 font_size = (72 / ec.dpi) * height * ec.window_size_pix[1] / 2 - h_nudge = h_spacing / 8. - v_nudge = v_spacing / 20. + h_nudge = h_spacing / 8.0 + v_nudge = v_spacing / 20.0 - colors = [_check('color', color) for color in colors] - numbers = [str(_check('number', number) + 1) for number in numbers] + colors = [_check("color", color) for color in colors] + numbers = [str(_check("number", number) + 1) for number in numbers] - if (len(colors) != len(np.unique(colors)) or - len(numbers) != len(np.unique(numbers))): - raise ValueError('There can be no repeated colors or numbers in the ' - 'menu.') + if len(colors) != len(np.unique(colors)) or len(numbers) != len(np.unique(numbers)): + raise ValueError("There can be no repeated colors or numbers in the " "menu.") # Draw the buttons rects = [] for ni, number in enumerate(numbers): for ci, color in enumerate(colors): - pos = [ni * h_spacing + h_start, - -ci * v_spacing + v_start, - width, height] - rects += [vis.Rectangle( - ec, pos, units=units, - fill_color=colors_rgb[color])] + pos = [ni * h_spacing + h_start, -ci * v_spacing + v_start, width, height] + rects += [vis.Rectangle(ec, pos, units=units, fill_color=colors_rgb[color])] rects[-1].draw() - ec.screen_text(number, [pos[0] + h_nudge, pos[1] + v_nudge], - color='black', - wrap=False, units=units, font_size=font_size) + ec.screen_text( + number, + [pos[0] + h_nudge, pos[1] + v_nudge], + color="black", + wrap=False, + units=units, + font_size=font_size, + ) ec.flip() - ec.write_data_line('crm_menu') + ec.write_data_line("crm_menu") # Wait for min_wait and get the click while ec.current_time - start_time < min_wait: ec.check_force_quit() ec.window.set_mouse_cursor(cursor) max_wait = np.maximum(0, max_wait - (ec.current_time - start_time)) - but = ec.wait_for_click_on(rects, max_wait=max_wait, - live_buttons='left')[1] + but = ec.wait_for_click_on(rects, max_wait=max_wait, live_buttons="left")[1] ec.flip() ec.window.set_mouse_cursor(mouse_cursor) if but is not None: sub = np.unravel_index(but, (n_numbers, n_colors)) - resp = ('brwg'[colors[sub[1]]], numbers[sub[0]]) - ec.write_data_line('crm_response', resp[0] + ',' + resp[1]) + resp = ("brwg"[colors[sub[1]]], numbers[sub[0]]) + ec.write_data_line("crm_response", resp[0] + "," + resp[1]) return resp else: - ec.write_data_line('crm_timeout') + ec.write_data_line("crm_timeout") return (None, None) -class CRMPreload(object): +class CRMPreload: """Store the CRM corpus in memory for fast access. Parameters @@ -466,13 +513,13 @@ class CRMPreload(object): where the raw CRM originals are stored. """ - def __init__(self, fs, ref_rms=0.01, ramp_dur=0.01, stereo=False, - path=None): + def __init__(self, fs, ref_rms=0.01, ramp_dur=0.01, stereo=False, path=None): if path is None: - path = join(_get_user_home_path(), '.expyfun', 'data', 'crm') + path = join(_get_user_home_path(), ".expyfun", "data", "crm") if not os.path.isdir(join(path, str(fs))): - raise RuntimeError('prepare_corpus has not yet been run ' - 'for sampling rate of %i' % fs) + raise RuntimeError( + "prepare_corpus has not yet been run " "for sampling rate of %i" % fs + ) self._excluded = [] self._all_stim = {} for sex in range(_n_sexes): @@ -480,12 +527,20 @@ def __init__(self, fs, ref_rms=0.01, ramp_dur=0.01, stereo=False, for cal in range(_n_callsigns): for col in range(_n_colors): for num in range(_n_numbers): - stim_id = '%i%i%i%i%i' % (sex, tal, cal, col, num) + stim_id = "%i%i%i%i%i" % (sex, tal, cal, col, num) try: - self._all_stim[stim_id] = \ - crm_sentence(fs, sex, tal, cal, col, num, - ref_rms, ramp_dur, stereo, - path) + self._all_stim[stim_id] = crm_sentence( + fs, + sex, + tal, + cal, + col, + num, + ref_rms, + ramp_dur, + stereo, + path, + ) except Exception: self._excluded += [stim_id] @@ -528,11 +583,15 @@ def sentence(self, sex, talker_num, callsign, color, number): index of ``'1'`` is 0, so care must be taken if using indices for the number argument. """ - stim_id = '%i%i%i%i%i' % ( - _check('sex', sex), _check('talker_num', talker_num), - _check('callsign', callsign), _check('color', color), - _check('number', number)) + stim_id = "%i%i%i%i%i" % ( + _check("sex", sex), + _check("talker_num", talker_num), + _check("callsign", callsign), + _check("color", color), + _check("number", number), + ) if stim_id in self._excluded: - raise RuntimeError('prepare_corpus has not yet been run for the ' - 'requested talker') + raise RuntimeError( + "prepare_corpus has not yet been run for the " "requested talker" + ) return self._all_stim[stim_id].copy() diff --git a/expyfun/stimuli/_hrtf.py b/expyfun/stimuli/_hrtf.py index 55049432..85477da5 100644 --- a/expyfun/stimuli/_hrtf.py +++ b/expyfun/stimuli/_hrtf.py @@ -1,11 +1,10 @@ -"""Stimulus generation functions -""" +"""Stimulus generation functions""" import numpy as np -from ..io import read_hdf5 from .._fixes import irfft -from .._utils import fetch_data_file, _fix_audio_dims +from .._utils import _fix_audio_dims, fetch_data_file +from ..io import read_hdf5 # This was used to generate "barb_anech.gz": # @@ -54,42 +53,42 @@ def _get_hrtf(angle, source, fs, interp=False): Functions", Australian Government Department of Defence: Defence Science and Technology Organization, Melbourne, Victoria, Australia, 2007. """ - fname = fetch_data_file('hrtf/{0}_{1}.hdf5'.format(source, fs)) + fname = fetch_data_file(f"hrtf/{source}_{fs}.hdf5") data = read_hdf5(fname) - angles = data['angles'] + angles = data["angles"] leftward = False read_angle = float(angle) if angle < 0: leftward = True read_angle = float(-angle) if read_angle not in angles and not interp: - raise ValueError('angle "{0}" must be one of +/-{1}' - ''.format(angle, list(angles))) - brir = data['brir'] + raise ValueError(f'angle "{angle}" must be one of +/-{list(angles)}' "") + brir = data["brir"] if read_angle in angles: interp = False if not interp: idx = np.where(angles == read_angle)[0] if len(idx) != 1: - raise ValueError('angle "{0}" not uniquely found in angles' - ''.format(angle)) + raise ValueError(f'angle "{angle}" not uniquely found in angles' "") brir = brir[idx[0]] else: # interpolation - if source != 'cipic': - raise ValueError('source must be ''cipic'' when interp=True') + if source != "cipic": + raise ValueError("source must be " "cipic" " when interp=True") # pull in files containing known hrtfs and extract magnitude and phase - fname = fetch_data_file('hrtf/pair_cipic_{0}.hdf5'.format(fs)) + fname = fetch_data_file(f"hrtf/pair_cipic_{fs}.hdf5") data = read_hdf5(fname) - hrtf_amp = data['hrtf_amp'] - hrtf_phase = data['hrtf_phase'] - pairs = data['pairs'] + hrtf_amp = data["hrtf_amp"] + hrtf_phase = data["hrtf_phase"] + pairs = data["pairs"] # isolate appropriate pair of amplitude and phase idx = np.searchsorted(angles, read_angle) if idx > len(pairs): - raise ValueError('angle magnitude "{0}" must be smaller than "{1}"' - ''.format(read_angle, pairs[-1][-1])) + raise ValueError( + f'angle magnitude "{read_angle}" must be smaller than "{pairs[-1][-1]}"' + "" + ) knowns = np.array([angles[idx - 1], angles[idx]]) index = np.where(pairs == knowns)[0][0] hrtf_amp = hrtf_amp[index] @@ -98,20 +97,18 @@ def _get_hrtf(angle, source, fs, interp=False): # weighted averages of log magnitude and unwrapped phase step = float(knowns[1] - knowns[0]) weights = (step - np.abs(read_angle - knowns)) / step - hrtf_amp = np.prod(hrtf_amp ** weights[:, np.newaxis, np.newaxis], - axis=0) - hrtf_phase = np.sum(hrtf_phase * weights[:, np.newaxis, np.newaxis], - axis=0) + hrtf_amp = np.prod(hrtf_amp ** weights[:, np.newaxis, np.newaxis], axis=0) + hrtf_phase = np.sum(hrtf_phase * weights[:, np.newaxis, np.newaxis], axis=0) # reconstruct hrtf and convert to time domain hrtf = hrtf_amp * np.exp(1j * hrtf_phase) brir = irfft(hrtf, int(hrtf.shape[-1])) - return brir, data['fs'], leftward + return brir, data["fs"], leftward -def convolve_hrtf(data, fs, angle, source='cipic', interp=False): - """Convolve a signal with a head-related transfer function +def convolve_hrtf(data, fs, angle, source="cipic", interp=False): + """Convolve a signal with a head-related transfer function. Technically we will be convolving with binaural room impulse responses (BRIRs), but HRTFs (freq-domain equiv. representations) @@ -147,7 +144,7 @@ def convolve_hrtf(data, fs, angle, source='cipic', interp=False): Additional documentation: - http://earlab.bu.edu/databases/collections/cipic/documentation/hrir_data_documentation.pdf # noqa + http://earlab.bu.edu/databases/collections/cipic/documentation/hrir_data_documentation.pdf The data were modified to suit our experimental needs. Below is the licensing information for the CIPIC data: @@ -183,16 +180,17 @@ def convolve_hrtf(data, fs, angle, source='cipic', interp=False): CIPIC- Center for Image Processing and Integrated Computing University of California 1 Shields Avenue Davis, CA 95616-8553 - """ + """ # noqa: E501 fs = float(fs) angle = float(angle) - known_sources = ['barb', 'cipic'] + known_sources = ["barb", "cipic"] known_fs = [24414, 44100] # must be sorted if source not in known_sources: - raise ValueError('Source "{0}" unknown, must be one of {1}' - ''.format(source, known_sources)) + raise ValueError( + f'Source "{source}" unknown, must be one of {known_sources}' "" + ) if not isinstance(interp, bool): - raise ValueError('interp must be bool') + raise ValueError("interp must be bool") data = np.array(data, np.float64) data = _fix_audio_dims(data, n_channels=1).ravel() @@ -205,6 +203,7 @@ def convolve_hrtf(data, fs, angle, source='cipic', interp=False): order = [1, 0] if leftward else [0, 1] if not np.allclose(brir_fs, fs, rtol=0, atol=0.5): from mne.filter import resample + brir = [resample(b, fs, brir_fs) for b in brir] out = np.array([np.convolve(data, brir[o]) for o in order]) return out diff --git a/expyfun/stimuli/_mls.py b/expyfun/stimuli/_mls.py index cac8d35b..f5e453cb 100644 --- a/expyfun/stimuli/_mls.py +++ b/expyfun/stimuli/_mls.py @@ -1,26 +1,25 @@ -# -*- coding: utf-8 -*- # Copyright (c) 2014, LABSN. # Distributed under the (new) BSD License. See LICENSE.txt for more info. -"""Maximum-length sequence (MLS) impulse-response finding functions -""" +"""Maximum-length sequence (MLS) impulse-response finding functions""" from os import path as op + import numpy as np from .._fixes import irfft, rfft -from .._utils import verbose_dec, logger +from .._utils import logger, verbose_dec -_mls_file = op.join(op.dirname(__file__), '..', 'data', 'mls.bin') +_mls_file = op.join(op.dirname(__file__), "..", "data", "mls.bin") _max_bits = 14 # determined by how the file was made, see _max_len_wrapper def _check_n_bits(n_bits): """Helper to make sure we have a usable number of bits""" if not isinstance(n_bits, int): - raise TypeError('n_bits must be an integer') + raise TypeError("n_bits must be an integer") if n_bits < 2 or n_bits > _max_bits: - raise ValueError('n_bits must be between 2 and %s' % _max_bits) + raise ValueError("n_bits must be between 2 and %s" % _max_bits) def _max_len_wrapper(n_bits): @@ -40,21 +39,21 @@ def _max_len_wrapper(n_bits): n_bits = int(n_bits) _check_n_bits(n_bits) # This was used to generate the sequences: - #from scipy.signal import max_len_seq - #_mlss = np.concatenate([max_len_seq(n) > 0 + # from scipy.signal import max_len_seq + # _mlss = np.concatenate([max_len_seq(n) > 0 # for n in range(2, _max_bits + 1)]) - #with open(_mls_file, 'wb') as fid: + # with open(_mls_file, 'wb') as fid: # fid.write(_mlss.tostring()) - _lims = np.cumsum([0] + [2 ** n - 1 for n in range(2, 15)]) + _lims = np.cumsum([0] + [2**n - 1 for n in range(2, 15)]) _mlss = np.fromfile(_mls_file, dtype=bool) _mlss = [_mlss[l1:l2].copy() for l1, l2 in zip(_lims[:-1], _lims[1:])] - return _mlss[n_bits - 2] * 2. - 1 + return _mlss[n_bits - 2] * 2.0 - 1 # Once this is in upstream scipy, we can add this: -#try: +# try: # from scipy.signal import max_len_seq as _max_len_seq -#except: +# except: _max_len_seq = _max_len_wrapper @@ -69,11 +68,10 @@ def repeated_mls(n_samp, n_repeats): The number of repeats to use. """ if not isinstance(n_samp, int) or not isinstance(n_repeats, int): - raise TypeError('n_samp and n_repeats must both be integers') + raise TypeError("n_samp and n_repeats must both be integers") n_bits = max(int(np.ceil(np.log2(n_samp + 1))), 2) if n_bits > _max_bits: - raise ValueError('Only lengths up to %s supported' - % (2 ** _max_bits - 1)) + raise ValueError("Only lengths up to %s supported" % (2**_max_bits - 1)) mls = 0.5 * _max_len_seq(n_bits) + 0.5 n_resp = len(mls) * (n_repeats + 1) - 1 mls = np.tile(mls, n_repeats) @@ -96,37 +94,43 @@ def compute_mls_impulse_response(response, mls, n_repeats, verbose=None): If not ``None``, override default verbose level. """ if mls.ndim != 1 or response.ndim != 1: - raise ValueError('response and mls must both be one-dimensional') + raise ValueError("response and mls must both be one-dimensional") if not isinstance(n_repeats, int): - raise TypeError('n_repeats must be an integer') + raise TypeError("n_repeats must be an integer") if not np.array_equal(np.sort(np.unique(mls)), [0, 1]): - raise ValueError('MLS must be sequence of 0s and 1s') + raise ValueError("MLS must be sequence of 0s and 1s") if mls.size % n_repeats != 0: - raise ValueError('MLS length (%s) is not a multiple of the number ' - 'of repeats (%s)' % (mls.size, n_repeats)) + raise ValueError( + "MLS length (%s) is not a multiple of the number " + "of repeats (%s)" % (mls.size, n_repeats) + ) mls_len = mls.size // n_repeats n_bits = int(np.round(np.log2(mls_len + 1))) - n_check = 2 ** n_bits + n_check = 2**n_bits if n_check != mls_len + 1: - raise RuntimeError('length of MLS must be one shorter than a power ' - 'of 2, got %s (close to %s)' % (mls_len, n_check)) - logger.info('MLS using %s bits detected' % n_bits) + raise RuntimeError( + "length of MLS must be one shorter than a power " + "of 2, got %s (close to %s)" % (mls_len, n_check) + ) + logger.info("MLS using %s bits detected" % n_bits) n_len = response.size + 1 if n_len % mls_len != 0: n_rep = int(np.round(n_len / float(mls_len))) n_len = mls_len * n_rep - 1 - raise ValueError('length of data must be one shorter than a ' - 'multiple of the MLS length (%s), found a length ' - 'of %s which is close to %s (%s repeats)' - % (mls_len, response.size, n_len, n_rep)) + raise ValueError( + "length of data must be one shorter than a " + "multiple of the MLS length (%s), found a length " + "of %s which is close to %s (%s repeats)" + % (mls_len, response.size, n_len, n_rep) + ) # Now that we know our signal, we can actually deconvolve. # First, wrap the end back to the beginning - resp_wrap = response[:n_repeats * mls_len].copy() - resp_wrap[:mls_len - 1] += response[n_repeats * mls_len:] + resp_wrap = response[: n_repeats * mls_len].copy() + resp_wrap[: mls_len - 1] += response[n_repeats * mls_len :] # Compute the circular crosscorrelation, w/correction for MLS scaling correction = np.empty(len(mls) // 2 + 1) - correction.fill(1. / (2 ** (n_bits - 2) * n_repeats)) - correction[0] = 1. / ((4 ** (n_bits - 1)) * n_repeats) + correction.fill(1.0 / (2 ** (n_bits - 2) * n_repeats)) + correction[0] = 1.0 / ((4 ** (n_bits - 1)) * n_repeats) y = irfft(correction * rfft(resp_wrap) * rfft(mls).conj()) # Average out repeats h_est = np.mean(np.reshape(y, (n_repeats, mls_len)), axis=0) diff --git a/expyfun/stimuli/_stimuli.py b/expyfun/stimuli/_stimuli.py index 627e8e77..0adbbd45 100644 --- a/expyfun/stimuli/_stimuli.py +++ b/expyfun/stimuli/_stimuli.py @@ -1,17 +1,17 @@ -# -*- coding: utf-8 -*- """Generic stimulus generation functions.""" import warnings +from threading import Timer + import numpy as np from scipy import signal -from threading import Timer -from ..io import read_wav from .._sound_controllers import SoundPlayer -from .._utils import _wait_secs, string_types +from .._utils import _wait_secs +from ..io import read_wav -def window_edges(sig, fs, dur=0.01, axis=-1, window='hann', edges='both'): +def window_edges(sig, fs, dur=0.01, axis=-1, window="hann", edges="both"): """Window the edges of a signal (e.g., to prevent "pops") Parameters @@ -40,25 +40,27 @@ def window_edges(sig, fs, dur=0.01, axis=-1, window='hann', edges='both'): sig_len = sig.shape[axis] win_len = int(dur * fs) if win_len > sig_len: - raise RuntimeError('cannot create window of size {0} samples (dur={1})' - 'for signal with length {2}' - ''.format(win_len, dur, sig_len)) - if window == 'dpss': + raise RuntimeError( + f"cannot create window of size {win_len} samples (dur={dur})" + f"for signal with length {sig_len}" + "" + ) + if window == "dpss": from mne.time_frequency.multitaper import dpss_windows + win = dpss_windows(2 * win_len + 1, 1, 1)[0][0][:win_len] win -= win[0] win /= win.max() else: win = signal.windows.get_window(window, 2 * win_len)[:win_len] - valid_edges = ('leading', 'trailing', 'both') + valid_edges = ("leading", "trailing", "both") if edges not in valid_edges: - raise ValueError('edges must be one of {0}, not "{1}"' - ''.format(valid_edges, edges)) + raise ValueError(f'edges must be one of {valid_edges}, not "{edges}"' "") # now we can actually do the calculation flattop = np.ones(sig_len, dtype=np.float64) - if edges in ('trailing', 'both'): # eliminate trailing + if edges in ("trailing", "both"): # eliminate trailing flattop[-win_len:] *= win[::-1] - if edges in ('leading', 'both'): # eliminate leading + if edges in ("leading", "both"): # eliminate leading flattop[:win_len] *= win shape = np.ones_like(sig.shape) shape[axis] = sig.shape[axis] @@ -82,7 +84,7 @@ def rms(data, axis=-1, keepdims=False): return np.sqrt(np.mean(data * data, axis=axis, keepdims=keepdims)) -def play_sound(sound, fs=None, norm=True, wait=False, backend='auto'): +def play_sound(sound, fs=None, norm=True, wait=False, backend="auto"): """Play a sound Parameters @@ -108,20 +110,20 @@ def play_sound(sound, fs=None, norm=True, wait=False, backend='auto'): """ sound = np.array(sound) fs_default = 44100 - if isinstance(sound, string_types): + if isinstance(sound, str): sound, fs_default = read_wav(sound) if fs is None: fs = fs_default if sound.ndim == 1: # make it stereo sound = np.array((sound, sound)) if sound.ndim != 2: - raise ValueError('sound must be 1- or 2-dimensional') + raise ValueError("sound must be 1- or 2-dimensional") if norm: m = np.abs(sound).max() * 1.000001 m = m if m != 0 else 1 sound /= m - if np.abs(sound).max() > 1.: - warnings.warn('Sound exceeds +/-1, will clip') + if np.abs(sound).max() > 1.0: + warnings.warn("Sound exceeds +/-1, will clip") # For rtmixer it's possible this will fail on some configurations if # resampling isn't built in to the backend; when we hit this we can # try/except here and do the resampling ourselves. @@ -133,12 +135,12 @@ def play_sound(sound, fs=None, norm=True, wait=False, backend='auto'): _wait_secs(dur) else: del_wait += dur - if hasattr(snd, 'delete'): # for backward compatibility + if hasattr(snd, "delete"): # for backward compatibility Timer(del_wait, snd.delete).start() return snd -def add_pad(sounds, alignment='start'): +def add_pad(sounds, alignment="start"): """Add sounds of different lengths and channel counts together Parameters @@ -162,28 +164,27 @@ def add_pad(sounds, alignment='start'): Even if the original sounds were all 0- or 1-dimensional, the output will be 2-dimensional (channels, samples). """ - if alignment not in ['start', 'center', 'end']: - raise(ValueError("alignment must be either 'start', 'center', " - "or 'end'")) + if alignment not in ["start", "center", "end"]: + raise ValueError("alignment must be either 'start', 'center', " "or 'end'") x = [np.atleast_2d(y) for y in sounds] if not np.all(y.ndim == 2 for y in x): - raise ValueError('Sound data must have no more than 2 dimensions.') + raise ValueError("Sound data must have no more than 2 dimensions.") shapes = [y.shape for y in x] ch_max, len_max = np.max(shapes, axis=0) if ch_max > 2: - raise ValueError('Only 1- and 2-channel sounds are supported.') + raise ValueError("Only 1- and 2-channel sounds are supported.") for xi, (ch, length) in enumerate(shapes): if length < len_max: - if alignment == 'start': + if alignment == "start": n_pre = 0 n_post = len_max - length - elif alignment == 'center': + elif alignment == "center": n_pre = (len_max - length) // 2 n_post = len_max - length - n_pre - elif alignment == 'end': + elif alignment == "end": n_pre = len_max - length n_post = 0 - x[xi] = np.pad(x[xi], ((0, 0), (n_pre, n_post)), 'constant') + x[xi] = np.pad(x[xi], ((0, 0), (n_pre, n_post)), "constant") if ch < ch_max: x[xi] = np.tile(x[xi], [ch_max, 1]) return np.sum(x, 0) diff --git a/expyfun/stimuli/_texture.py b/expyfun/stimuli/_texture.py index bff7c1cd..ba754754 100644 --- a/expyfun/stimuli/_texture.py +++ b/expyfun/stimuli/_texture.py @@ -1,14 +1,14 @@ #!/usr/bin/env python2 -# -*- coding: utf-8 -*- """Texture (ERB-spaced) stimulus generation functions.""" # adapted (with permission) from code by Hari Bharadwaj -import numpy as np import warnings -from ._stimuli import rms, window_edges +import numpy as np + from .._fixes import irfft +from ._stimuli import rms, window_edges def _cams(f): @@ -18,7 +18,7 @@ def _cams(f): def _inv_cams(E): """Compute cams inverse.""" - return (10 ** (E / 21.4) - 1.) / 0.00437 + return (10 ** (E / 21.4) - 1.0) / 0.00437 def _scale_sound(x): @@ -28,21 +28,30 @@ def _scale_sound(x): def _make_narrow_noise(bw, f_c, dur, fs, ramp_dur, rng): """Make narrow-band noise using FFT.""" - f_min, f_max = f_c - bw / 2., f_c + bw / 2. + f_min, f_max = f_c - bw / 2.0, f_c + bw / 2.0 t = np.arange(int(round(dur * fs))) / fs # Make Noise - f_step = 1. / dur # Frequency bin size + f_step = 1.0 / dur # Frequency bin size h_min = int(np.ceil(f_min / f_step)) h_max = int(np.floor(f_max / f_step)) + 1 phase = rng.rand(h_max - h_min) * 2 * np.pi noise = np.zeros(len(t) // 2 + 1, np.complex128) noise[h_min:h_max] = np.exp(1j * phase) - return window_edges(irfft(noise)[:len(t)], fs, ramp_dur, window='dpss') - - -def texture_ERB(n_freqs=20, n_coh=None, rho=1., seq=('inc', 'nb', 'inc', 'nb'), - fs=24414.0625, dur=1., SAM_freq=7., random_state=None, - freq_lims=(200, 8000), verbose=True): + return window_edges(irfft(noise)[: len(t)], fs, ramp_dur, window="dpss") + + +def texture_ERB( + n_freqs=20, + n_coh=None, + rho=1.0, + seq=("inc", "nb", "inc", "nb"), + fs=24414.0625, + dur=1.0, + SAM_freq=7.0, + random_state=None, + freq_lims=(200, 8000), + verbose=True, +): """Create ERB texture stimulus Parameters @@ -83,14 +92,16 @@ def texture_ERB(n_freqs=20, n_coh=None, rho=1., seq=('inc', 'nb', 'inc', 'nb'), """ from mne.time_frequency.multitaper import dpss_windows from mne.utils import check_random_state + if not isinstance(seq, (list, tuple, np.ndarray)): - raise TypeError('seq must be list, tuple, or ndarray, got %s' - % type(seq)) - known_seqs = ('inc', 'nb', 'sam') + raise TypeError("seq must be list, tuple, or ndarray, got %s" % type(seq)) + known_seqs = ("inc", "nb", "sam") for si, s in enumerate(seq): if s not in known_seqs: - raise ValueError('all entries in seq must be one of %s, got ' - 'seq[%s]=%s' % (known_seqs, si, s)) + raise ValueError( + "all entries in seq must be one of %s, got " + "seq[%s]=%s" % (known_seqs, si, s) + ) fs = float(fs) rng = check_random_state(random_state) n_coh = int(np.round(n_freqs * 0.8)) if n_coh is None else n_coh @@ -102,10 +113,12 @@ def texture_ERB(n_freqs=20, n_coh=None, rho=1., seq=('inc', 'nb', 'inc', 'nb'), del f_max spacing_ERBs = n_ERBs / float(n_freqs - 1) if verbose: - print('This stim will have successive tones separated by %2.2f ERBs' - % spacing_ERBs) + print( + "This stim will have successive tones separated by %2.2f ERBs" + % spacing_ERBs + ) if spacing_ERBs < 1.0: - warnings.warn('The spacing between tones is LESS THAN 1 ERB!') + warnings.warn("The spacing between tones is LESS THAN 1 ERB!") # Make a filter whose impulse response is purely positive (to avoid phase # jumps) so that the filtered envelope is purely positive. Use a DPSS @@ -113,49 +126,51 @@ def texture_ERB(n_freqs=20, n_coh=None, rho=1., seq=('inc', 'nb', 'inc', 'nb'), # filterlength, we need to restrict time-bandwidth product to a minimum. # Thus we need a length*bw = 2 => length = 2/bw (second). Hence filter # coefficients are calculated as follows: - b = dpss_windows(int(np.floor(2 * fs / 100.)), 1., 1)[0][0] + b = dpss_windows(int(np.floor(2 * fs / 100.0)), 1.0, 1)[0][0] b -= b[0] b /= b.sum() # Incoherent envrate = 14 bw = 20 - incoh = 0. + incoh = 0.0 for k in range(n_freqs): f = _inv_cams(_cams(f_min) + spacing_ERBs * k) env = _make_narrow_noise(bw, envrate, dur, fs, rise, rng) env[env < 0] = 0 - env = np.convolve(b, env)[:len(t)] - incoh += _scale_sound(window_edges( - env * np.sin(2 * np.pi * f * t), fs, rise, window='dpss')) + env = np.convolve(b, env)[: len(t)] + incoh += _scale_sound( + window_edges(env * np.sin(2 * np.pi * f * t), fs, rise, window="dpss") + ) incoh /= rms(incoh) # Coherent (noise band) - stims = dict(inc=0., nb=0., sam=0.) + stims = dict(inc=0.0, nb=0.0, sam=0.0) group = np.sort(rng.permutation(np.arange(n_freqs))[:n_coh]) for kind in known_seqs: - if kind == 'nb': # noise band + if kind == "nb": # noise band env_coh = _make_narrow_noise(bw, envrate, dur, fs, rise, rng) else: # 'nb' or 'inc' - env_coh = 0.5 + np.sin(2 * np.pi * SAM_freq * t) / 2. - env_coh = window_edges(env_coh, fs, rise, window='dpss') + env_coh = 0.5 + np.sin(2 * np.pi * SAM_freq * t) / 2.0 + env_coh = window_edges(env_coh, fs, rise, window="dpss") env_coh[env_coh < 0] = 0 - env_coh = np.convolve(b, env_coh)[:len(t)] - if kind == 'inc': + env_coh = np.convolve(b, env_coh)[: len(t)] + if kind == "inc": use_group = [] # no coherent ones else: # 'nb' or 'sam' use_group = group for k in range(n_freqs): f = _inv_cams(_cams(f_min) + spacing_ERBs * k) env_inc = _make_narrow_noise(bw, envrate, dur, fs, rise, rng) - env_inc[env_inc < 0] = 0. - env_inc = np.convolve(b, env_inc)[:len(t)] + env_inc[env_inc < 0] = 0.0 + env_inc = np.convolve(b, env_inc)[: len(t)] if k in use_group: - env = np.sqrt(rho) * env_coh + np.sqrt(1 - rho ** 2) * env_inc + env = np.sqrt(rho) * env_coh + np.sqrt(1 - rho**2) * env_inc else: env = env_inc - stims[kind] += _scale_sound(window_edges( - env * np.sin(2 * np.pi * f * t), fs, rise, window='dpss')) + stims[kind] += _scale_sound( + window_edges(env * np.sin(2 * np.pi * f * t), fs, rise, window="dpss") + ) stims[kind] /= rms(stims[kind]) stim = np.concatenate([stims[s] for s in seq]) stim = 0.01 * stim / rms(stim) diff --git a/expyfun/stimuli/_tracker.py b/expyfun/stimuli/_tracker.py index f14832ed..bd599ed6 100644 --- a/expyfun/stimuli/_tracker.py +++ b/expyfun/stimuli/_tracker.py @@ -1,15 +1,15 @@ -"""Adaptive tracks for psychophysics (individual, or multiple randomly dealt) -""" +"""Adaptive tracks for psychophysics (individual, or multiple randomly dealt)""" # Author: Ross Maddox # # License: BSD (3-clause) -import numpy as np -import time -from scipy.stats import binom import json +import time import warnings +import numpy as np +from scipy.stats import binom + from .. import ExperimentController @@ -17,29 +17,29 @@ # Set up the logging callback (use write_data_line or do nothing) # ============================================================================= def _callback_dummy(event_type, value=None, timestamp=None): - """Take the arguments of write_data_line, but do nothing. - """ + """Take the arguments of write_data_line, but do nothing.""" pass def _check_callback(callback): - """Check to see if the callback is of an allowable type. - """ + """Check to see if the callback is of an allowable type.""" if callback is None: callback = _callback_dummy elif isinstance(callback, ExperimentController): callback = callback.write_data_line if not callable(callback): - raise TypeError('callback must be a callable, None, or an instance of ' - 'ExperimentController.') + raise TypeError( + "callback must be a callable, None, or an instance of " + "ExperimentController." + ) return callback # ============================================================================= # Define the TrackerUD Class # ============================================================================= -class TrackerUD(object): +class TrackerUD: r"""Up-down adaptive tracker This class implements a standard up-down adaptive tracker object. Based on @@ -126,22 +126,34 @@ class TrackerUD(object): None. """ - def __init__(self, callback, up, down, step_size_up, step_size_down, - stop_reversals, stop_trials, start_value, change_indices=None, - change_rule='reversals', x_min=None, x_max=None, - repeat_limit='reversals'): + def __init__( + self, + callback, + up, + down, + step_size_up, + step_size_down, + stop_reversals, + stop_trials, + start_value, + change_indices=None, + change_rule="reversals", + x_min=None, + x_max=None, + repeat_limit="reversals", + ): self._callback = _check_callback(callback) if not isinstance(up, int): - raise ValueError('up must be an integer') + raise ValueError("up must be an integer") self._up = up if not isinstance(down, int): - raise ValueError('down must be an integer') + raise ValueError("down must be an integer") self._down = down - if stop_reversals != np.inf and type(stop_reversals) != int: - raise ValueError('stop_reversals must be an integer or np.inf') + if stop_reversals != np.inf and not isinstance(stop_reversals, int): + raise ValueError("stop_reversals must be an integer or np.inf") self._stop_reversals = stop_reversals - if stop_trials != np.inf and type(stop_trials) != int: - raise ValueError('stop_trials must be an integer or np.inf') + if stop_trials != np.inf and not isinstance(stop_trials, int): + raise ValueError("stop_trials must be an integer or np.inf") self._stop_trials = stop_trials self._start_value = start_value self._x_min = -np.inf if x_min is None else float(x_min) @@ -150,34 +162,41 @@ def __init__(self, callback, up, down, step_size_up, step_size_down, if change_indices is None: change_indices = [0] if not np.isscalar(step_size_up): - raise ValueError('If step_size_up is longer than 1, you must ' - 'specify change indices.') + raise ValueError( + "If step_size_up is longer than 1, you must " + "specify change indices." + ) if not np.isscalar(step_size_down): - raise ValueError('If step_size_down is longer than 1, you must' - ' specify change indices.') + raise ValueError( + "If step_size_down is longer than 1, you must" + " specify change indices." + ) self._change_indices = np.asarray(change_indices) - if change_rule not in ['trials', 'reversals']: - raise ValueError("change_rule must be either 'trials' or " - "'reversals'") + if change_rule not in ["trials", "reversals"]: + raise ValueError("change_rule must be either 'trials' or " "'reversals'") self._change_rule = change_rule step_size_up = np.atleast_1d(step_size_up) if change_indices != [0]: if len(step_size_up) != len(change_indices) + 1: - raise ValueError('If step_size_up is not scalar it must be one' - ' element longer than change_indices.') + raise ValueError( + "If step_size_up is not scalar it must be one" + " element longer than change_indices." + ) self._step_size_up = np.asarray(step_size_up, dtype=float) step_size_down = np.atleast_1d(step_size_down) if change_indices != [0]: if len(step_size_down) != len(change_indices) + 1: - raise ValueError('If step_size_down is not scalar it must be ' - 'one element longer than change_indices.') + raise ValueError( + "If step_size_down is not scalar it must be " + "one element longer than change_indices." + ) self._step_size_down = np.asarray(step_size_down, dtype=float) self._x = np.asarray([start_value], dtype=float) if not np.isscalar(start_value): - raise TypeError('start_value must be a scalar') + raise TypeError("start_value must be a scalar") self._x_current = float(start_value) self._responses = np.asarray([], dtype=bool) self._reversals = np.asarray([], dtype=int) @@ -193,25 +212,32 @@ def __init__(self, callback, up, down, step_size_up, step_size_down, self._limit_count = 0 # Now write the initialization data out - self._tracker_id = '%s-%s' % (id(self), int(round(time.time() * 1e6))) - self._callback('tracker_identify', json.dumps(dict( - tracker_id=self._tracker_id, - tracker_type='TrackerUD'))) - - self._callback('tracker_%s_init' % self._tracker_id, json.dumps(dict( - callback=None, - up=self._up, - down=self._down, - step_size_up=[float(s) for s in self._step_size_up], - step_size_down=[float(s) for s in self._step_size_down], - stop_reversals=self._stop_reversals, - stop_trials=self._stop_trials, - start_value=self._start_value, - change_indices=[int(s) for s in self._change_indices], - change_rule=self._change_rule, - x_min=self._x_min, - x_max=self._x_max, - repeat_limit=self._repeat_limit))) + self._tracker_id = "%s-%s" % (id(self), int(round(time.time() * 1e6))) + self._callback( + "tracker_identify", + json.dumps(dict(tracker_id=self._tracker_id, tracker_type="TrackerUD")), + ) + + self._callback( + "tracker_%s_init" % self._tracker_id, + json.dumps( + dict( + callback=None, + up=self._up, + down=self._down, + step_size_up=[float(s) for s in self._step_size_up], + step_size_down=[float(s) for s in self._step_size_down], + stop_reversals=self._stop_reversals, + stop_trials=self._stop_trials, + start_value=self._start_value, + change_indices=[int(s) for s in self._change_indices], + change_rule=self._change_rule, + x_min=self._x_min, + x_max=self._x_max, + repeat_limit=self._repeat_limit, + ) + ), + ) def respond(self, correct): """Update the tracker based on the last response. @@ -222,7 +248,7 @@ def respond(self, correct): Was the most recent subject response correct? """ if self._stopped: - raise RuntimeError('Tracker is stopped.') + raise RuntimeError("Tracker is stopped.") bound = False bad = False @@ -262,11 +288,9 @@ def respond(self, correct): if step_dir == 0: self._x = np.append(self._x, self._x[-1]) elif step_dir < 0: - self._x = np.append(self._x, self._x[-1] - - self._current_step_size_down) + self._x = np.append(self._x, self._x[-1] - self._current_step_size_down) elif step_dir > 0: - self._x = np.append(self._x, self._x[-1] + - self._current_step_size_up) + self._x = np.append(self._x, self._x[-1] + self._current_step_size_up) if self._x_min is not -np.inf: if self._x[-1] < self._x_min: @@ -274,7 +298,7 @@ def respond(self, correct): self._limit_count += 1 if bound: bad = True - if self._repeat_limit == 'reversals': + if self._repeat_limit == "reversals": reversal = True self._n_reversals += 1 if self._x_max is not np.inf: @@ -283,7 +307,7 @@ def respond(self, correct): self._limit_count += 1 if bound: bad = True - if self._repeat_limit == 'reversals': + if self._repeat_limit == "reversals": reversal = True self._n_reversals += 1 @@ -299,15 +323,19 @@ def respond(self, correct): if not self._stopped: self._x_current = self._x[-1] - self._callback('tracker_%s_respond' % self._tracker_id, - correct) + self._callback("tracker_%s_respond" % self._tracker_id, correct) else: self._x = self._x[:-1] self._callback( - 'tracker_%s_stop' % self._tracker_id, json.dumps(dict( - responses=[int(s) for s in self._responses], - reversals=[int(s) for s in self._reversals], - x=[float(s) for s in self._x]))) + "tracker_%s_stop" % self._tracker_id, + json.dumps( + dict( + responses=[int(s) for s in self._responses], + reversals=[int(s) for s in self._reversals], + x=[float(s) for s in self._x], + ) + ), + ) def check_valid(self, n_reversals): """If last reversals contain reversals exceeding x_min or x_max. @@ -323,8 +351,7 @@ def check_valid(self, n_reversals): True if none of the reversals are at x_min or x_max and False otherwise. """ - self._valid = (not self._bad_reversals[self._reversals != 0] - [-n_reversals:].any()) + self._valid = not self._bad_reversals[self._reversals != 0][-n_reversals:].any() return self._valid def _stop_here(self): @@ -335,14 +362,16 @@ def _stop_here(self): else: self._n_stop = False if self._n_stop and self._limit_count > 0: - warnings.warn('Tracker {} exceeded x_min or x_max bounds {} times.' - ''.format(self._tracker_id, self._limit_count)) + warnings.warn( + f"Tracker {self._tracker_id} exceeded x_min or x_max bounds " + f"{self._limit_count} times." + ) return self._n_stop def _step_index(self): - if self._change_rule.lower() == 'reversals': + if self._change_rule.lower() == "reversals": self._n_change = self._n_reversals - elif self._change_rule.lower() == 'trials': + elif self._change_rule.lower() == "trials": self._n_change = self._n_trials step_index = np.where(self._n_change >= self._change_indices)[0] if len(step_index) == 0 or np.array_equal(self._change_indices, [0]): @@ -404,44 +433,37 @@ def repeat_limit(self): @property def stopped(self): - """Has the tracker stopped - """ + """Has the tracker stopped""" return self._stopped @property def x(self): - """The staircase - """ + """The staircase""" return self._x @property def x_current(self): - """The current level - """ + """The current level""" return self._x_current @property def responses(self): - """The response history - """ + """The response history""" return self._responses @property def n_trials(self): - """The number of trials so far - """ + """The number of trials so far""" return self._n_trials @property def n_reversals(self): - """The number of reversals so far - """ + """The number of reversals so far""" return self._n_reversals @property def reversals(self): - """The reversal history (0 where there was no reversal) - """ + """The reversal history (0 where there was no reversal)""" return self._reversals @property @@ -475,20 +497,22 @@ def plot(self, ax=None, threshold=True, n_skip=2): The handles to the staircase line and the reversal dots. """ import matplotlib.pyplot as plt + if ax is None: fig, ax = plt.subplots(1) else: fig = ax.figure - line = ax.plot(1 + np.arange(self._n_trials), self._x, 'k.-') - line[0].set_label('Trials') - dots = ax.plot(1 + np.where(self._reversals > 0)[0], - self._x[self._reversals > 0], 'ro') - dots[0].set_label('Reversals') - ax.set(xlabel='Trial number', ylabel='Level') + line = ax.plot(1 + np.arange(self._n_trials), self._x, "k.-") + line[0].set_label("Trials") + dots = ax.plot( + 1 + np.where(self._reversals > 0)[0], self._x[self._reversals > 0], "ro" + ) + dots[0].set_label("Reversals") + ax.set(xlabel="Trial number", ylabel="Level") if threshold: thresh = self.plot_thresh(n_skip, ax) - thresh[0].set_label('Estimated Threshold') + thresh[0].set_label("Estimated Threshold") ax.legend() return fig, ax, line + dots @@ -509,10 +533,12 @@ def plot_thresh(self, n_skip=2, ax=None): The handle to the threshold line, as returned from ``plt.plot``. """ import matplotlib.pyplot as plt + if ax is None: ax = plt.gca() - h = ax.plot([1, self._n_trials], [self.threshold(n_skip)] * 2, - '--', color='gray') + h = ax.plot( + [1, self._n_trials], [self.threshold(n_skip)] * 2, "--", color="gray" + ) return h def threshold(self, n_skip=2): @@ -543,17 +569,20 @@ def threshold(self, n_skip=2): return np.nan else: if self._bad_reversals[rev_inds].any(): - raise ValueError('Cannot calculate thresholds with reversals ' - 'attempting to exceed x_min or x_max. Try ' - 'increasing n_skip.') - return (np.mean(self._x[rev_inds[0::2]]) + - np.mean(self._x[rev_inds[1::2]])) / 2 + raise ValueError( + "Cannot calculate thresholds with reversals " + "attempting to exceed x_min or x_max. Try " + "increasing n_skip." + ) + return ( + np.mean(self._x[rev_inds[0::2]]) + np.mean(self._x[rev_inds[1::2]]) + ) / 2 # ============================================================================= # Define the TrackerBinom Class # ============================================================================= -class TrackerBinom(object): +class TrackerBinom: """Binomial hypothesis testing tracker This class implements a tracker that runs a test at each trial with the @@ -608,8 +637,16 @@ class TrackerBinom(object): of following them. """ - def __init__(self, callback, alpha, chance, max_trials, min_trials=0, - stop_early=True, x_current=np.nan): + def __init__( + self, + callback, + alpha, + chance, + max_trials, + min_trials=0, + stop_early=True, + x_current=np.nan, + ): self._callback = _check_callback(callback) self._alpha = alpha self._chance = chance @@ -629,18 +666,25 @@ def __init__(self, callback, alpha, chance, max_trials, min_trials=0, # Now write the initialization data out self._tracker_id = id(self) - self._callback('tracker_identify', json.dumps(dict( - tracker_id=self._tracker_id, - tracker_type='TrackerBinom'))) - - self._callback('tracker_%s_init' % self._tracker_id, json.dumps(dict( - callback=None, - alpha=self._alpha, - chance=self._chance, - max_trials=self._max_trials, - min_trials=self._min_trials, - stop_early=self._stop_early, - x_current=self._x_current))) + self._callback( + "tracker_identify", + json.dumps(dict(tracker_id=self._tracker_id, tracker_type="TrackerBinom")), + ) + + self._callback( + "tracker_%s_init" % self._tracker_id, + json.dumps( + dict( + callback=None, + alpha=self._alpha, + chance=self._chance, + max_trials=self._max_trials, + min_trials=self._min_trials, + stop_early=self._stop_early, + x_current=self._x_current, + ) + ), + ) def respond(self, correct): """Update the tracker based on the last response. @@ -657,15 +701,16 @@ def respond(self, correct): else: self._n_correct += 1 self._pc = float(self._n_correct) / self._n_trials - self._p_val = binom.cdf(self._n_wrong, self._n_trials, - 1 - self._chance) - self._min_p_val = binom.cdf(self._n_wrong, self._max_trials, - 1 - self._chance) - self._max_p_val = binom.cdf(self._n_wrong + (self._max_trials - - self._n_trials), - self._max_trials, 1 - self._chance) - if ((self._p_val <= self._alpha) or - (self._min_p_val >= self._alpha and self._stop_early)): + self._p_val = binom.cdf(self._n_wrong, self._n_trials, 1 - self._chance) + self._min_p_val = binom.cdf(self._n_wrong, self._max_trials, 1 - self._chance) + self._max_p_val = binom.cdf( + self._n_wrong + (self._max_trials - self._n_trials), + self._max_trials, + 1 - self._chance, + ) + if (self._p_val <= self._alpha) or ( + self._min_p_val >= self._alpha and self._stop_early + ): if self._n_trials >= self._min_trials: self._stopped = True if self._n_trials == self._max_trials: @@ -673,12 +718,17 @@ def respond(self, correct): if self._stopped: self._callback( - 'tracker_%s_stop' % self._tracker_id, json.dumps(dict( - responses=[int(s) for s in self._responses], - p_val=self._p_val, - success=int(self.success)))) + "tracker_%s_stop" % self._tracker_id, + json.dumps( + dict( + responses=[int(s) for s in self._responses], + p_val=self._p_val, + success=int(self.success), + ) + ), + ) else: - self._callback('tracker_%s_respond' % self._tracker_id, correct) + self._callback("tracker_%s_respond" % self._tracker_id, correct) # ========================================================================= # Define all the public properties @@ -717,55 +767,47 @@ def n_trials(self): @property def n_wrong(self): - """The number of incorrect trials so far - """ + """The number of incorrect trials so far""" return self._n_wrong @property def n_correct(self): - """The number of correct trials so far - """ + """The number of correct trials so far""" return self._n_correct @property def pc(self): - """Proportion correct (0-1, NaN before any responses made) - """ + """Proportion correct (0-1, NaN before any responses made)""" return self._pc @property def responses(self): - """The response history - """ + """The response history""" return self._responses @property def stopped(self): - """Is the tracker stopped - """ + """Is the tracker stopped""" return self._stopped @property def success(self): - """Has the p-value reached significance - """ + """Has the p-value reached significance""" return self._p_val <= self._alpha @property def x_current(self): - """Included only for compatibility with TrackerDealer - """ + """Included only for compatibility with TrackerDealer""" return self._x_current @property def x(self): - """Included only for compatibility with TrackerDealer - """ + """Included only for compatibility with TrackerDealer""" return np.array([self._x_current for _ in range(self._n_trials)]) @property def stop_rule(self): - return 'trials' + return "trials" # ============================================================================= @@ -774,9 +816,10 @@ def stop_rule(self): # TODO: Make it so you can add a list of values for each dimension (such as the # phase in a BMLD task) and have it return that + # TODO: eventually, make a BaseTracker class so that TrackerDealer can make # sure it has the methods / properties it needs -class TrackerDealer(object): +class TrackerDealer: """Class for selecting and pacing independent simultaneous trackers Parameters @@ -817,36 +860,44 @@ class TrackerDealer(object): pace. """ - def __init__(self, callback, trackers, max_lag=1, pace_rule='reversals', - rand=None): + def __init__(self, callback, trackers, max_lag=1, pace_rule="reversals", rand=None): # dim will only be used for user output. Will be stored as 0-d self._callback = _check_callback(callback) self._trackers = np.asarray(trackers) for ti, t in enumerate(self._trackers.flat): if not isinstance(t, (TrackerUD, TrackerBinom)): - raise TypeError('trackers.ravel()[%d] is type %s, must be ' - 'TrackerUD or TrackerBinom' % (ti, type(t))) + raise TypeError( + "trackers.ravel()[%d] is type %s, must be " + "TrackerUD or TrackerBinom" % (ti, type(t)) + ) if isinstance(t, TrackerBinom) and t.stop_early: - raise ValueError('stop_early for trackers.flat[%d] must be ' - 'False to deal trials from a TrackerBinom ' - 'object' % (ti,)) + raise ValueError( + "stop_early for trackers.flat[%d] must be " + "False to deal trials from a TrackerBinom " + "object" % (ti,) + ) self._shape = self._trackers.shape self._n = np.prod(self._shape) self._max_lag = max_lag self._pace_rule = pace_rule - if any([isinstance(t, TrackerBinom) for t in - self._trackers]) and pace_rule == 'reversals': - raise ValueError('pace_rule must be ''trials'' to deal trials from' - ' a TrackerBinom object') + if ( + any([isinstance(t, TrackerBinom) for t in self._trackers]) + and pace_rule == "reversals" + ): + raise ValueError( + "pace_rule must be " + "trials" + " to deal trials from" + " a TrackerBinom object" + ) if rand is None: self._seed = int(time.time()) rand = np.random.RandomState(self._seed) else: self._seed = None if not isinstance(rand, np.random.RandomState): - raise TypeError('rand must be of type ' - 'numpy.random.RandomState') + raise TypeError("rand must be of type " "numpy.random.RandomState") self._rand = rand self._trial_complete = True self._tracker_history = np.array([], dtype=int) @@ -854,14 +905,19 @@ def __init__(self, callback, trackers, max_lag=1, pace_rule='reversals', self._x_history = np.array([], dtype=float) self._dealer_id = id(self) - self._callback('dealer_identify', json.dumps(dict( - dealer_id=self._dealer_id))) - - self._callback('dealer_%s_init' % self._dealer_id, json.dumps(dict( - trackers=[s._tracker_id for s in self._trackers.ravel()], - shape=self._shape, - max_lag=self._max_lag, - pace_rule=self._pace_rule))) + self._callback("dealer_identify", json.dumps(dict(dealer_id=self._dealer_id))) + + self._callback( + "dealer_%s_init" % self._dealer_id, + json.dumps( + dict( + trackers=[s._tracker_id for s in self._trackers.ravel()], + shape=self._shape, + max_lag=self._max_lag, + pace_rule=self._pace_rule, + ) + ), + ) def __iter__(self): return self @@ -877,15 +933,13 @@ def next(self): The level of the selected tracker. """ if self.stopped: - raise(StopIteration) + raise StopIteration if not self._trial_complete: # Chose a new tracker before responding, so record non-response - self._response_history = np.append(self._response_history, - np.nan) + self._response_history = np.append(self._response_history, np.nan) self._trial_complete = False self._current_tracker = self._pick() - self._tracker_history = np.append(self._tracker_history, - self._current_tracker) + self._tracker_history = np.append(self._tracker_history, self._current_tracker) ss = np.unravel_index(self._current_tracker, self.shape) level = self._trackers.flat[self._current_tracker].x_current self._x_history = np.append(self._x_history, level) @@ -895,15 +949,14 @@ def __next__(self): # for py3k compatibility return self.next() def _pick(self): - """Decide which tracker from which to draw a trial - """ + """Decide which tracker from which to draw a trial""" if self.stopped: - raise RuntimeError('All trackers have stopped.') + raise RuntimeError("All trackers have stopped.") active = np.where([not t.stopped for t in self._trackers.flat])[0] - if self._pace_rule == 'reversals': + if self._pace_rule == "reversals": pace = np.asarray([t.n_reversals for t in self._trackers.flat]) - elif self._pace_rule == 'trials': + elif self._pace_rule == "trials": pace = np.asarray([t.n_trials for t in self._trackers.flat]) pace = pace[active] lag = pace.max() - pace @@ -927,17 +980,21 @@ def respond(self, correct): Was the most recent subject response correct? """ if self._trial_complete: - raise RuntimeError('You must get a trial before you can respond.') + raise RuntimeError("You must get a trial before you can respond.") self._trackers.flat[self._current_tracker].respond(correct) self._trial_complete = True self._response_history = np.append(self._response_history, correct) if self.stopped: self._callback( - 'dealer_%s_stop' % self._dealer_id, json.dumps(dict( - tracker_history=[int(s) for s in self._tracker_history], - response_history=[float(s) for s in - self._response_history], - x_history=[float(s) for s in self._x_history]))) + "dealer_%s_stop" % self._dealer_id, + json.dumps( + dict( + tracker_history=[int(s) for s in self._tracker_history], + response_history=[float(s) for s in self._response_history], + x_history=[float(s) for s in self._x_history], + ) + ), + ) def history(self, include_skips=False): """The history of the dealt trials and the responses @@ -959,12 +1016,14 @@ def history(self, include_skips=False): The response history (i.e., correct or incorrect) """ if include_skips: - return (self._tracker_history, self._x_history, - self._response_history) + return (self._tracker_history, self._x_history, self._response_history) else: inds = np.invert(np.isnan(self._response_history)) - return (self._tracker_history[inds], self._x_history[inds], - self._response_history[inds].astype(bool)) + return ( + self._tracker_history[inds], + self._x_history[inds], + self._response_history[inds].astype(bool), + ) @property def shape(self): @@ -972,21 +1031,19 @@ def shape(self): @property def stopped(self): - """Are all the trackers stopped - """ + """Are all the trackers stopped""" return all(t.stopped for t in self._trackers.flat) @property def trackers(self): - """All of the tracker objects in the container - """ + """All of the tracker objects in the container""" return self._trackers # ============================================================================= # Define the TrackerMHW Class # ============================================================================= -class TrackerMHW(object): +class TrackerMHW: """Up-down adaptive tracker for the modified Hughson-Westlake Procedure This class implements a standard up-down adaptive tracker object. It is @@ -1039,9 +1096,18 @@ class TrackerMHW(object): and finding threshold. """ - def __init__(self, callback, x_min, x_max, base_step=5, factor_down=2, - factor_up_nr=4, start_value=40, n_up_stop=2, - repeat_limit='reversals'): + def __init__( + self, + callback, + x_min, + x_max, + base_step=5, + factor_down=2, + factor_up_nr=4, + start_value=40, + n_up_stop=2, + repeat_limit="reversals", + ): self._callback = _check_callback(callback) self._x_min = x_min self._x_max = x_max @@ -1052,25 +1118,25 @@ def __init__(self, callback, x_min, x_max, base_step=5, factor_down=2, self._n_up_stop = n_up_stop self._repeat_limit = repeat_limit - if type(x_min) != int and type(x_min) != float: - raise TypeError('x_min must be a float or integer') - if type(x_max) != int and type(x_max) != float: - raise TypeError('x_max must be a float or integer') + if not isinstance(x_min, (int, float)): + raise TypeError("x_min must be a float or integer") + if not isinstance(x_max, (int, float)): + raise TypeError("x_max must be a float or integer") self._x = np.asarray([start_value], dtype=float) if not np.isscalar(start_value): - raise TypeError('start_value must be a scalar') + raise TypeError("start_value must be a scalar") else: if start_value % base_step != 0: - raise ValueError('start_value must be a multiple of base_step') + raise ValueError("start_value must be a multiple of base_step") else: if (x_min - start_value) % base_step != 0: - raise ValueError('x_min must be a multiple of base_step') + raise ValueError("x_min must be a multiple of base_step") if (x_max - start_value) % base_step != 0: - raise ValueError('x_max must be a multiple of base_step') + raise ValueError("x_max must be a multiple of base_step") - if type(n_up_stop) != int: - raise TypeError('n_up_stop must be an integer') + if not isinstance(n_up_stop, int): + raise TypeError("n_up_stop must be an integer") self._x_current = float(start_value) self._responses = np.asarray([], dtype=bool) @@ -1090,21 +1156,28 @@ def __init__(self, callback, x_min, x_max, base_step=5, factor_down=2, self._threshold = np.nan # Now write the initialization data out - self._tracker_id = '%s-%s' % (id(self), int(round(time.time() * 1e6))) - self._callback('tracker_identify', json.dumps(dict( - tracker_id=self._tracker_id, - tracker_type='TrackerMHW'))) - - self._callback('tracker_%s_init' % self._tracker_id, json.dumps(dict( - callback=None, - base_step=self._base_step, - factor_down=self._factor_down, - factor_up_nr=self._factor_up_nr, - start_value=self._start_value, - x_min=self._x_min, - x_max=self._x_max, - n_up_stop=self._n_up_stop, - repeat_limit=self._repeat_limit))) + self._tracker_id = "%s-%s" % (id(self), int(round(time.time() * 1e6))) + self._callback( + "tracker_identify", + json.dumps(dict(tracker_id=self._tracker_id, tracker_type="TrackerMHW")), + ) + + self._callback( + "tracker_%s_init" % self._tracker_id, + json.dumps( + dict( + callback=None, + base_step=self._base_step, + factor_down=self._factor_down, + factor_up_nr=self._factor_up_nr, + start_value=self._start_value, + x_min=self._x_min, + x_max=self._x_max, + n_up_stop=self._n_up_stop, + repeat_limit=self._repeat_limit, + ) + ), + ) def respond(self, correct): """Update the tracker based on the last response. @@ -1115,7 +1188,7 @@ def respond(self, correct): Was the most recent subject response correct? """ if self._stopped: - raise RuntimeError('Tracker is stopped.') + raise RuntimeError("Tracker is stopped.") bound = False bad = False @@ -1153,12 +1226,14 @@ def respond(self, correct): if step_dir == 0: self._x = np.append(self._x, self._x[-1]) elif step_dir < 0: - self._x = np.append(self._x, self._x[-1] - - self._factor_down * self._base_step) + self._x = np.append( + self._x, self._x[-1] - self._factor_down * self._base_step + ) elif step_dir > 0: if self._n_correct == 0: - self._x = np.append(self._x, self._x[-1] + - self._factor_up_nr * self._base_step) + self._x = np.append( + self._x, self._x[-1] + self._factor_up_nr * self._base_step + ) else: self._x = np.append(self._x, self._x[-1] + self._base_step) @@ -1167,10 +1242,10 @@ def respond(self, correct): self._limit_count += 1 if bound: bad = True - if self._repeat_limit == 'reversals': + if self._repeat_limit == "reversals": reversal = True self._n_reversals += 1 - if self._repeat_limit == 'ignore': + if self._repeat_limit == "ignore": reversal = False self._direction = 0 if self._x[-1] >= self._x_max: @@ -1178,10 +1253,10 @@ def respond(self, correct): self._limit_count += 1 if bound: bad = True - if self._repeat_limit == 'reversals': + if self._repeat_limit == "reversals": reversal = True self._n_reversals += 1 - if self._repeat_limit == 'ignore': + if self._repeat_limit == "ignore": reversal = False self._direction = 0 @@ -1197,18 +1272,23 @@ def respond(self, correct): if not self._stopped: self._x_current = self._x[-1] - self._callback('tracker_%s_respond' % self._tracker_id, - correct) + self._callback("tracker_%s_respond" % self._tracker_id, correct) else: self._x = self._x[:-1] self._callback( - 'tracker_%s_stop' % self._tracker_id, json.dumps(dict( - responses=[int(s) for s in self._responses], - reversals=[int(s) for s in self._reversals], - x=[float(s) for s in self._x], - threshold=self._threshold, - n_correct_levels={int(k): v for k, v in - self._n_correct_levels.items()}))) + "tracker_%s_stop" % self._tracker_id, + json.dumps( + dict( + responses=[int(s) for s in self._responses], + reversals=[int(s) for s in self._reversals], + x=[float(s) for s in self._x], + threshold=self._threshold, + n_correct_levels={ + int(k): v for k, v in self._n_correct_levels.items() + }, + ) + ), + ) def check_valid(self, n_reversals): """If last reversals contain reversals exceeding x_min or x_max. @@ -1224,15 +1304,18 @@ def check_valid(self, n_reversals): True if none of the reversals are at x_min or x_max and False otherwise. """ - self._valid = (not self._bad_reversals[self._reversals != 0] - [-n_reversals:].any()) + self._valid = not self._bad_reversals[self._reversals != 0][-n_reversals:].any() return self._valid def _stop_here(self): - self._threshold_reached = [self._n_correct_levels[level] == - self._n_up_stop for level in self._levels] - if self._n_correct == 0 and self._x[ - -2] == self._x_max and self._x[-1] == self._x_max: + self._threshold_reached = [ + self._n_correct_levels[level] == self._n_up_stop for level in self._levels + ] + if ( + self._n_correct == 0 + and self._x[-2] == self._x_max + and self._x[-1] == self._x_max + ): self._n_stop = True self._threshold = np.nan elif len(self._x) > 3 and (self._x == self._x_max).sum() >= 4: @@ -1242,13 +1325,14 @@ def _stop_here(self): self._threshold = self._x_min elif self._threshold_reached.count(True) == 1: self._n_stop = True - self._threshold = int(self._levels[ - [i for i, tr in enumerate(self._threshold_reached) if tr]]) + self._threshold = self._levels[self._threshold_reached].item() else: self._n_stop = False if self._n_stop and self._limit_count > 0: - warnings.warn('Tracker {} exceeded x_min or x_max bounds {} times.' - ''.format(self._tracker_id, self._limit_count)) + warnings.warn( + f"Tracker {self._tracker_id} exceeded x_min or x_max bounds " + f"{self._limit_count} times." + ) return self._n_stop # ========================================================================= @@ -1296,44 +1380,37 @@ def threshold(self): @property def stopped(self): - """Has the tracker stopped - """ + """Has the tracker stopped""" return self._stopped @property def x(self): - """The staircase - """ + """The staircase""" return self._x @property def x_current(self): - """The current level - """ + """The current level""" return self._x_current @property def responses(self): - """The response history - """ + """The response history""" return self._responses @property def n_trials(self): - """The number of trials so far - """ + """The number of trials so far""" return self._n_trials @property def n_reversals(self): - """The number of reversals so far - """ + """The number of reversals so far""" return self._n_reversals @property def reversals(self): - """The reversal history (0 where there was no reversal) - """ + """The reversal history (0 where there was no reversal)""" return self._reversals @property @@ -1370,20 +1447,22 @@ def plot(self, ax=None, threshold=True): The handles to the staircase line and the reversal dots. """ import matplotlib.pyplot as plt + if ax is None: fig, ax = plt.subplots(1) else: fig = ax.figure - line = ax.plot(1 + np.arange(self._n_trials), self._x, 'k.-') - line[0].set_label('Trials') - dots = ax.plot(1 + np.where(self._reversals > 0)[0], - self._x[self._reversals > 0], 'ro') - dots[0].set_label('Reversals') - ax.set(xlabel='Trial number', ylabel='Level (dB)') + line = ax.plot(1 + np.arange(self._n_trials), self._x, "k.-") + line[0].set_label("Trials") + dots = ax.plot( + 1 + np.where(self._reversals > 0)[0], self._x[self._reversals > 0], "ro" + ) + dots[0].set_label("Reversals") + ax.set(xlabel="Trial number", ylabel="Level (dB)") if threshold: thresh = self.plot_thresh(ax) - thresh[0].set_label('Threshold') + thresh[0].set_label("Threshold") ax.legend() return fig, ax, line + dots @@ -1402,8 +1481,8 @@ def plot_thresh(self, ax=None): The handle to the threshold line, as returned from ``plt.plot``. """ import matplotlib.pyplot as plt + if ax is None: ax = plt.gca() - h = ax.plot([1, self._n_trials], [self._threshold] * 2, '--', - color='gray') + h = ax.plot([1, self._n_trials], [self._threshold] * 2, "--", color="gray") return h diff --git a/expyfun/stimuli/_vocoder.py b/expyfun/stimuli/_vocoder.py index a3d6271a..d68a4027 100644 --- a/expyfun/stimuli/_vocoder.py +++ b/expyfun/stimuli/_vocoder.py @@ -1,11 +1,10 @@ -# -*- coding: utf-8 -*- -"""Vocoder functions -""" +"""Vocoder functions""" -import numpy as np -from scipy.signal import butter, lfilter, filtfilt import warnings +import numpy as np +from scipy.signal import butter, filtfilt, lfilter + from .._utils import verbose_dec @@ -20,8 +19,9 @@ def _erbn_to_freq(e): @verbose_dec -def get_band_freqs(fs, n_bands=16, freq_lims=(200., 8000.), scale='erb', - verbose=None): +def get_band_freqs( + fs, n_bands=16, freq_lims=(200.0, 8000.0), scale="erb", verbose=None +): """Calculate frequency band edges. Parameters @@ -45,27 +45,26 @@ def get_band_freqs(fs, n_bands=16, freq_lims=(200., 8000.), scale='erb', """ freq_lims = np.array(freq_lims, float) fs = float(fs) - if np.any(freq_lims >= fs / 2.): - raise ValueError('frequency limits must not exceed Nyquist') + if np.any(freq_lims >= fs / 2.0): + raise ValueError("frequency limits must not exceed Nyquist") assert freq_lims.ndim == 1 and freq_lims.size == 2 - if scale not in ('erb', 'log', 'hz'): + if scale not in ("erb", "log", "hz"): raise ValueError('Frequency scale must be "erb", "hz", or "log".') - if scale == 'erb': + if scale == "erb": freq_lims_erbn = _freq_to_erbn(freq_lims) delta_erb = np.diff(freq_lims_erbn) / n_bands - cutoffs = _erbn_to_freq(freq_lims_erbn[0] + - delta_erb * np.arange(n_bands + 1)) + cutoffs = _erbn_to_freq(freq_lims_erbn[0] + delta_erb * np.arange(n_bands + 1)) assert np.allclose(cutoffs[[0, -1]], freq_lims) # should be - elif scale == 'log': + elif scale == "log": freq_lims_log = np.log2(freq_lims) delta = np.diff(freq_lims_log) / n_bands - cutoffs = 2. ** (freq_lims_log[0] + delta * np.arange(n_bands + 1)) + cutoffs = 2.0 ** (freq_lims_log[0] + delta * np.arange(n_bands + 1)) assert np.allclose(cutoffs[[0, -1]], freq_lims) # should be else: # scale == 'hz' delta = np.diff(freq_lims) / n_bands cutoffs = freq_lims[0] + delta * np.arange(n_bands + 1) edges = zip(cutoffs[:-1], cutoffs[1:]) - return(edges) + return edges def get_bands(data, fs, edges, order=2, zero_phase=False, axis=-1): @@ -100,15 +99,15 @@ def get_bands(data, fs, edges, order=2, zero_phase=False, axis=-1): filts = [] for lf, hf in edges: # band-pass - b, a = butter(order, [2 * lf / fs, 2 * hf / fs], 'bandpass') + b, a = butter(order, [2 * lf / fs, 2 * hf / fs], "bandpass") filt = filtfilt if zero_phase else lfilter band = filt(b, a, data, axis=axis) bands.append(band) filts.append((b, a)) - return(bands, filts) + return bands, filts -def get_env(data, fs, lp_order=4, lp_cutoff=160., zero_phase=False, axis=-1): +def get_env(data, fs, lp_order=4, lp_cutoff=160.0, zero_phase=False, axis=-1): """Calculate a low-pass envelope of a signal Parameters @@ -133,18 +132,17 @@ def get_env(data, fs, lp_order=4, lp_cutoff=160., zero_phase=False, axis=-1): filt : tuple The filter coefficients (numerator, denominator). """ - if lp_cutoff >= fs / 2.: - raise ValueError('frequency limits must not exceed Nyquist') + if lp_cutoff >= fs / 2.0: + raise ValueError("frequency limits must not exceed Nyquist") cutoff = 2 * lp_cutoff / float(fs) - data[data < 0] = 0. # half-wave rectify - b, a = butter(lp_order, cutoff, 'lowpass') + data[data < 0] = 0.0 # half-wave rectify + b, a = butter(lp_order, cutoff, "lowpass") filt = filtfilt if zero_phase else lfilter env = filt(b, a, data, axis=axis) - return(env, (b, a)) + return env, (b, a) -def get_carriers(data, fs, edges, order=2, axis=-1, mode='tone', rate=None, - seed=None): +def get_carriers(data, fs, edges, order=2, axis=-1, mode="tone", rate=None, seed=None): """Generate carriers for frequency bands of a signal Parameters @@ -178,9 +176,8 @@ def get_carriers(data, fs, edges, order=2, axis=-1, mode='tone', rate=None, List of numpy ndarrays of the carrier signals. """ # check args - if mode not in ('noise', 'tone', 'poisson'): - raise ValueError('mode must be "noise", "tone", or "poisson", not {0}' - ''.format(mode)) + if mode not in ("noise", "tone", "poisson"): + raise ValueError(f'mode must be "noise", "tone", or "poisson", not {mode}' "") if isinstance(seed, np.random.RandomState): rng = seed elif seed is None: @@ -188,38 +185,53 @@ def get_carriers(data, fs, edges, order=2, axis=-1, mode='tone', rate=None, elif isinstance(seed, int): rng = np.random.RandomState(seed) else: - raise TypeError('"seed" must be an int, an instance of ' - 'numpy.random.RandomState, or None.') + raise TypeError( + '"seed" must be an int, an instance of ' + "numpy.random.RandomState, or None." + ) carrs = [] fs = float(fs) n_samp = data.shape[axis] for lf, hf in edges: - if mode == 'tone': - cf = (lf + hf) / 2. + if mode == "tone": + cf = (lf + hf) / 2.0 carrier = np.sin(2 * np.pi * cf * np.arange(n_samp) / fs) carrier *= np.sqrt(2) # rms of 1 shape = np.ones_like(data.shape) shape[axis] = n_samp carrier.shape = shape else: - if mode == 'noise': + if mode == "noise": carrier = rng.rand(*data.shape) else: # mode == 'poisson' prob = rate / fs with warnings.catch_warnings(record=True): # numpy silliness - carrier = rng.choice([0., 1.], n_samp, p=[1 - prob, prob]) - b, a = butter(order, [2 * lf / fs, 2 * hf / fs], 'bandpass') + carrier = rng.choice([0.0, 1.0], n_samp, p=[1 - prob, prob]) + b, a = butter(order, [2 * lf / fs, 2 * hf / fs], "bandpass") carrier = lfilter(b, a, carrier, axis=axis) - carrier /= np.sqrt(np.mean(carrier * carrier, axis=axis, - keepdims=True)) # rms of 1 + carrier /= np.sqrt( + np.mean(carrier * carrier, axis=axis, keepdims=True) + ) # rms of 1 carrs.append(carrier) - return(carrs) + return carrs @verbose_dec -def vocode(data, fs, n_bands=16, freq_lims=(200., 8000.), scale='erb', - order=2, lp_cutoff=160., lp_order=4, mode='noise', - rate=200, seed=None, axis=-1, verbose=None): +def vocode( + data, + fs, + n_bands=16, + freq_lims=(200.0, 8000.0), + scale="erb", + order=2, + lp_cutoff=160.0, + lp_order=4, + mode="noise", + rate=200, + seed=None, + axis=-1, + verbose=None, +): """Vocode stimuli using a variety of methods Parameters @@ -268,14 +280,17 @@ def vocode(data, fs, n_bands=16, freq_lims=(200., 8000.), scale='erb', The default settings are adapted from a cochlear implant simulation algorithm described by Zachary Smith (Cochlear Corp.). """ - edges = get_band_freqs(fs, n_bands=n_bands, freq_lims=freq_lims, - scale=scale) + edges = get_band_freqs(fs, n_bands=n_bands, freq_lims=freq_lims, scale=scale) bands, filts = get_bands(data, fs, edges, order=order, axis=axis) - envs, env_filts = zip(*[get_env(x, fs, lp_order=lp_order, - lp_cutoff=lp_cutoff, axis=axis) - for x in bands]) - carrs = get_carriers(data, fs, edges, order=order, axis=axis, mode=mode, - rate=rate, seed=seed) + envs, env_filts = zip( + *[ + get_env(x, fs, lp_order=lp_order, lp_cutoff=lp_cutoff, axis=axis) + for x in bands + ] + ) + carrs = get_carriers( + data, fs, edges, order=order, axis=axis, mode=mode, rate=rate, seed=seed + ) # reconstruct voc = np.zeros_like(data) for carr, env in zip(carrs, envs): diff --git a/expyfun/stimuli/tests/test_mls.py b/expyfun/stimuli/tests/test_mls.py index e223c04b..a2644946 100644 --- a/expyfun/stimuli/tests/test_mls.py +++ b/expyfun/stimuli/tests/test_mls.py @@ -2,12 +2,11 @@ import pytest from numpy.testing import assert_allclose -from expyfun.stimuli import repeated_mls, compute_mls_impulse_response +from expyfun.stimuli import compute_mls_impulse_response, repeated_mls def test_mls_ir(): - """Test computing impulse response with MLS - """ + """Test computing impulse response with MLS""" # test simple stuff for _ in range(5): # make sure our signals have some DC @@ -17,20 +16,20 @@ def test_mls_ir(): mls, n_resp = repeated_mls(len(kernel), n_repeats) resp = np.zeros(n_resp) - resp[:len(mls) + len(kernel) - 1] = np.convolve(mls, kernel) + resp[: len(mls) + len(kernel) - 1] = np.convolve(mls, kernel) est_kernel = compute_mls_impulse_response(resp, mls, n_repeats) kernel_pad = np.zeros(len(est_kernel)) - kernel_pad[:len(kernel)] = kernel + kernel_pad[: len(kernel)] = kernel assert_allclose(kernel_pad, est_kernel, atol=1e-5, rtol=1e-5) # failure modes - pytest.raises(TypeError, repeated_mls, 'foo', n_repeats) - pytest.raises(ValueError, compute_mls_impulse_response, resp[:-1], mls, - n_repeats) - pytest.raises(ValueError, compute_mls_impulse_response, resp, mls[:-1], - n_repeats) - pytest.raises(ValueError, compute_mls_impulse_response, resp, - mls * 2. - 1., n_repeats) - pytest.raises(ValueError, compute_mls_impulse_response, resp, - mls[np.newaxis, :], n_repeats) + pytest.raises(TypeError, repeated_mls, "foo", n_repeats) + pytest.raises(ValueError, compute_mls_impulse_response, resp[:-1], mls, n_repeats) + pytest.raises(ValueError, compute_mls_impulse_response, resp, mls[:-1], n_repeats) + pytest.raises( + ValueError, compute_mls_impulse_response, resp, mls * 2.0 - 1.0, n_repeats + ) + pytest.raises( + ValueError, compute_mls_impulse_response, resp, mls[np.newaxis, :], n_repeats + ) diff --git a/expyfun/stimuli/tests/test_stimuli.py b/expyfun/stimuli/tests/test_stimuli.py index c1432c9d..3b73dfeb 100644 --- a/expyfun/stimuli/tests/test_stimuli.py +++ b/expyfun/stimuli/tests/test_stimuli.py @@ -1,92 +1,124 @@ -# -*- coding: utf-8 -*- +import os import numpy as np import pytest -from numpy.testing import (assert_array_equal, assert_array_almost_equal, - assert_allclose, assert_equal) +from numpy.testing import ( + assert_allclose, + assert_array_almost_equal, + assert_array_equal, + assert_equal, +) from scipy.signal import butter, lfilter -from expyfun._sound_controllers import _BACKENDS -from expyfun._utils import requires_lib, requires_opengl21, _check_skip_backend -from expyfun.stimuli import (rms, play_sound, convolve_hrtf, window_edges, - vocode, texture_ERB, crm_info, crm_prepare_corpus, - crm_sentence, crm_response_menu, CRMPreload, - add_pad) from expyfun import ExperimentController - - -std_kwargs = dict(output_dir=None, full_screen=False, window_size=(340, 480), - participant='foo', session='01', stim_db=0.0, noise_db=0.0, - verbose=True, version='dev') - - -@requires_lib('mne') +from expyfun._sound_controllers import _BACKENDS +from expyfun._utils import _check_skip_backend, requires_lib, requires_opengl21 +from expyfun.stimuli import ( + CRMPreload, + add_pad, + convolve_hrtf, + crm_info, + crm_prepare_corpus, + crm_response_menu, + crm_sentence, + play_sound, + rms, + texture_ERB, + vocode, + window_edges, +) + +std_kwargs = dict( + output_dir=None, + full_screen=False, + window_size=(340, 480), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + verbose=True, + version="dev", +) + + +@requires_lib("mne") def test_textures(): """Test stimulus textures.""" texture_ERB() # smoke test - pytest.raises(TypeError, texture_ERB, seq='foo') - pytest.raises(ValueError, texture_ERB, seq=('foo',)) - with pytest.warns(UserWarning, match='LESS THAN 1 ERB'): + pytest.raises(TypeError, texture_ERB, seq="foo") + pytest.raises(ValueError, texture_ERB, seq=("foo",)) + with pytest.warns(UserWarning, match="LESS THAN 1 ERB"): x = texture_ERB(freq_lims=(200, 500)) - assert_allclose(len(x) / 24414., 4., rtol=1e-5) + assert_allclose(len(x) / 24414.0, 4.0, rtol=1e-5) @pytest.mark.timeout(15) -@requires_lib('h5py') +@requires_lib("h5py") def test_hrtf_convolution(): """Test HRTF convolution.""" data = np.random.randn(2, 10000) pytest.raises(ValueError, convolve_hrtf, data, 44100, 0, interp=False) data = data[0] pytest.raises(ValueError, convolve_hrtf, data, 44100, 0.5, interp=False) - pytest.raises(ValueError, convolve_hrtf, data, 44100, 0, - source='foo', interp=False) + pytest.raises(ValueError, convolve_hrtf, data, 44100, 0, source="foo", interp=False) pytest.raises(ValueError, convolve_hrtf, data, 44100, 90.5, interp=True) - pytest.raises(ValueError, convolve_hrtf, data, 44100, 0, interp='foo') + pytest.raises(ValueError, convolve_hrtf, data, 44100, 0, interp="foo") # invalid angle when interp=False for interp in [True, False]: - for source in ['barb', 'cipic']: - if interp and source == 'barb': + for source in ["barb", "cipic"]: + if interp and source == "barb": # raise an error when trying to interp with 'barb' - pytest.raises(ValueError, convolve_hrtf, data, 44100, 2.5, - source=source, interp=interp) + pytest.raises( + ValueError, + convolve_hrtf, + data, + 44100, + 2.5, + source=source, + interp=interp, + ) else: - out = convolve_hrtf(data, 44100, 0, source=source, - interp=interp) - out_2 = convolve_hrtf(data, 24414, 0, source=source, - interp=interp) + out = convolve_hrtf(data, 44100, 0, source=source, interp=interp) + out_2 = convolve_hrtf(data, 24414, 0, source=source, interp=interp) assert_equal(out.ndim, 2) assert_equal(out.shape[0], 2) - assert (out.shape[1] > data.size) - assert (out_2.shape[1] < out.shape[1]) + assert out.shape[1] > data.size + assert out_2.shape[1] < out.shape[1] if interp: - out_3 = convolve_hrtf(data, 44100, 2.5, source=source, - interp=interp) - out_4 = convolve_hrtf(data, 44100, -2.5, source=source, - interp=interp) + out_3 = convolve_hrtf( + data, 44100, 2.5, source=source, interp=interp + ) + out_4 = convolve_hrtf( + data, 44100, -2.5, source=source, interp=interp + ) assert_equal(out_3.ndim, 2) assert_equal(out_4.ndim, 2) # ensure that, at least for zero degrees, it's close - out = convolve_hrtf(data, 44100, 0, source=source, - interp=interp)[:, 1024:-1024] + out = convolve_hrtf(data, 44100, 0, source=source, interp=interp)[ + :, 1024:-1024 + ] assert_allclose(np.mean(rms(out)), rms(data), rtol=1e-1) - out = convolve_hrtf(data, 44100, -90, source=source, - interp=interp) + out = convolve_hrtf(data, 44100, -90, source=source, interp=interp) rmss = rms(out) - assert (rmss[0] > 4 * rmss[1]) + assert rmss[0] > 4 * rmss[1] -@pytest.mark.parametrize('backend', ('auto',) + _BACKENDS) +@pytest.mark.skipif( + os.getenv("AZURE_CI_WINDOWS", "") == "true", reason="Azure CI Windows has problems" +) +@pytest.mark.parametrize("backend", ("auto",) + _BACKENDS) def test_play_sound(backend, hide_window): # only works if windowing works """Test playing a sound.""" _check_skip_backend(backend) + fs = 48000 data = np.zeros((2, 100)) - play_sound(data).stop() - play_sound(data[0], norm=False, wait=True) - pytest.raises(ValueError, play_sound, data[:, :, np.newaxis]) + play_sound(data, fs=fs).stop() + play_sound(data[0], norm=False, wait=True, fs=fs) + with pytest.raises(ValueError, match="sound must be"): + play_sound(data[:, :, np.newaxis], fs=fs) # Make sure each backend can handle a lot of sounds for _ in range(10): - snd = play_sound(data) + snd = play_sound(data, fs=fs) # we manually stop and delete here, because we don't want to # have to wait for our Timer instances to get around to doing # it... this also checks to make sure calling `delete()` more @@ -99,40 +131,40 @@ def test_window_edges(): """Test windowing signal edges.""" sig = np.ones((2, 1000)) fs = 44100 - pytest.raises(ValueError, window_edges, sig, fs, window='foo') # bad win + pytest.raises(ValueError, window_edges, sig, fs, window="foo") # bad win pytest.raises(RuntimeError, window_edges, sig, fs, dur=1.0) # too long - pytest.raises(ValueError, window_edges, sig, fs, edges='foo') # bad type - x = window_edges(sig, fs, edges='leading') - y = window_edges(sig, fs, edges='trailing') + pytest.raises(ValueError, window_edges, sig, fs, edges="foo") # bad type + x = window_edges(sig, fs, edges="leading") + y = window_edges(sig, fs, edges="trailing") z = window_edges(sig, fs) - assert (np.all(x[:, 0] < 1)) # make sure we actually reduced amp - assert (np.all(x[:, -1] == 1)) - assert (np.all(y[:, 0] == 1)) - assert (np.all(y[:, -1] < 1)) + assert np.all(x[:, 0] < 1) # make sure we actually reduced amp + assert np.all(x[:, -1] == 1) + assert np.all(y[:, 0] == 1) + assert np.all(y[:, -1] < 1) assert_allclose(x + y, z + 1) def _voc_similarity(orig, voc): """Quantify envelope similarity after vocoding.""" - return np.correlate(orig, voc, mode='full').max() + return np.correlate(orig, voc, mode="full").max() def test_vocoder(): """Test noise, tone, and click vocoding.""" data = np.random.randn(10000) env = np.random.randn(10000) - b, a = butter(4, 0.001, 'lowpass') + b, a = butter(4, 0.001, "lowpass") data *= lfilter(b, a, env) # bad limits pytest.raises(ValueError, vocode, data, 44100, freq_lims=(200, 30000)) # bad mode - pytest.raises(ValueError, vocode, data, 44100, mode='foo') + pytest.raises(ValueError, vocode, data, 44100, mode="foo") # bad seed - pytest.raises(TypeError, vocode, data, 44100, seed='foo') - pytest.raises(ValueError, vocode, data, 44100, scale='foo') - voc1 = vocode(data, 20000, mode='noise', scale='log') - voc2 = vocode(data, 20000, mode='tone', order=4, seed=0, scale='hz') - voc3 = vocode(data, 20000, mode='poisson', seed=np.random.RandomState(123)) + pytest.raises(TypeError, vocode, data, 44100, seed="foo") + pytest.raises(ValueError, vocode, data, 44100, scale="foo") + voc1 = vocode(data, 20000, mode="noise", scale="log") + voc2 = vocode(data, 20000, mode="tone", order=4, seed=0, scale="hz") + voc3 = vocode(data, 20000, mode="poisson", seed=np.random.RandomState(123)) # XXX This is about the best we can do for now... assert_array_equal(voc1.shape, data.shape) assert_array_equal(voc2.shape, data.shape) @@ -142,13 +174,13 @@ def test_vocoder(): def test_rms(): """Test RMS calculation.""" # Test a couple trivial things we know - sin = np.sin(2 * np.pi * 1000 * np.arange(10000, dtype=float) / 10000.) - assert_array_almost_equal(rms(sin), 1. / np.sqrt(2)) + sin = np.sin(2 * np.pi * 1000 * np.arange(10000, dtype=float) / 10000.0) + assert_array_almost_equal(rms(sin), 1.0 / np.sqrt(2)) assert_array_almost_equal(rms(np.ones((100, 2)) * 2, 0), [2, 2]) -@pytest.mark.timeout(60) # can be slow to load on CIs -@requires_lib('mne') +@pytest.mark.timeout(120) # can be slow to load on CIs +@requires_lib("mne") def test_crm(tmpdir): """Test CRM Corpus functions.""" fs = 40000 # native rate, to avoid large resampling delay in testing @@ -156,63 +188,58 @@ def test_crm(tmpdir): tempdir = str(tmpdir) # corpus prep - talkers = [dict(sex='f', talker_num=0)] + talkers = [dict(sex="f", talker_num=0)] - crm_prepare_corpus(fs, path_out=tempdir, talker_list=talkers, - n_jobs=1) - crm_prepare_corpus(fs, path_out=tempdir, talker_list=talkers, n_jobs=1, - overwrite=True) + crm_prepare_corpus(fs, path_out=tempdir, talker_list=talkers, n_jobs=1) + crm_prepare_corpus( + fs, path_out=tempdir, talker_list=talkers, n_jobs=1, overwrite=True + ) # no overwrite pytest.raises(RuntimeError, crm_prepare_corpus, fs, path_out=tempdir) # load sentence from hard drive - crm_sentence(fs, 'f', 0, 0, 0, 0, 0, ramp_dur=0, path=tempdir) - crm_sentence(fs, 1, '0', 'charlie', 'red', '5', stereo=True, path=tempdir) + crm_sentence(fs, "f", 0, 0, 0, 0, 0, ramp_dur=0, path=tempdir) + crm_sentence(fs, 1, "0", "charlie", "red", "5", stereo=True, path=tempdir) # bad value requested - pytest.raises(ValueError, crm_sentence, fs, 1, 0, 0, 'periwinkle', 0, - path=tempdir) + pytest.raises(ValueError, crm_sentence, fs, 1, 0, 0, "periwinkle", 0, path=tempdir) # unprepared talker - pytest.raises(RuntimeError, crm_sentence, fs, 'm', 0, 0, 0, 0, - path=tempdir) + pytest.raises(RuntimeError, crm_sentence, fs, "m", 0, 0, 0, 0, path=tempdir) # unprepared sampling rate - pytest.raises(RuntimeError, crm_sentence, fs + 1, 0, 0, 0, 0, 0, - path=tempdir) + pytest.raises(RuntimeError, crm_sentence, fs + 1, 0, 0, 0, 0, 0, path=tempdir) # CRMPreload class crm = CRMPreload(fs, path=tempdir) - crm.sentence('f', 0, 0, 0, 0) + crm.sentence("f", 0, 0, 0, 0) # unprepared sampling rate pytest.raises(RuntimeError, CRMPreload, fs + 1) # bad value requested - pytest.raises(ValueError, crm.sentence, 1, 0, 0, 'periwinkle', 0) + pytest.raises(ValueError, crm.sentence, 1, 0, 0, "periwinkle", 0) # unprepared talker - pytest.raises(RuntimeError, crm.sentence, 'm', 0, 0, 0, 0) + pytest.raises(RuntimeError, crm.sentence, "m", 0, 0, 0, 0) # try to specify parameters like fs, stereo, etc. - pytest.raises(TypeError, crm.sentence, fs, '1', '0', 'charlie', 'red', '5') + pytest.raises(TypeError, crm.sentence, fs, "1", "0", "charlie", "red", "5") # add_pad x1 = np.zeros(10) x2 = np.ones((2, 5)) x = add_pad([x1, x2]) - assert (np.sum(x[..., -1] == 0)) - x = add_pad((x1, x2), 'center') - assert (np.sum(x[..., -1] == 0) and np.sum(x[..., 0] == 0)) - x = add_pad((x1, x2), 'end') - assert (np.sum(x[..., 0] == 0)) + assert np.sum(x[..., -1] == 0) + x = add_pad((x1, x2), "center") + assert np.sum(x[..., -1] == 0) and np.sum(x[..., 0] == 0) + x = add_pad((x1, x2), "end") + assert np.sum(x[..., 0] == 0) @pytest.mark.timeout(15) @requires_opengl21 def test_crm_response_menu(hide_window): """Test the CRM Response menu function.""" - with ExperimentController('crm_menu', **std_kwargs) as ec: + with ExperimentController("crm_menu", **std_kwargs) as ec: resp = crm_response_menu(ec, max_wait=0.05) crm_response_menu(ec, numbers=[0, 1, 2], max_wait=0.05) - crm_response_menu(ec, colors=['blue'], max_wait=0.05) - crm_response_menu(ec, colors=['r'], numbers=['7'], max_wait=0.05) + crm_response_menu(ec, colors=["blue"], max_wait=0.05) + crm_response_menu(ec, colors=["r"], numbers=["7"], max_wait=0.05) assert_equal(resp, (None, None)) - pytest.raises(ValueError, crm_response_menu, ec, - max_wait=0, min_wait=1) - pytest.raises(ValueError, crm_response_menu, ec, - colors=['g', 'g']) + pytest.raises(ValueError, crm_response_menu, ec, max_wait=0, min_wait=1) + pytest.raises(ValueError, crm_response_menu, ec, colors=["g", "g"]) diff --git a/expyfun/stimuli/tests/test_tracker.py b/expyfun/stimuli/tests/test_tracker.py index b22ca737..f50cc26d 100644 --- a/expyfun/stimuli/tests/test_tracker.py +++ b/expyfun/stimuli/tests/test_tracker.py @@ -1,10 +1,10 @@ import numpy as np - -from expyfun.stimuli import TrackerUD, TrackerBinom, TrackerDealer, TrackerMHW -from expyfun import ExperimentController import pytest from numpy.testing import assert_equal + +from expyfun import ExperimentController from expyfun._utils import requires_opengl21 +from expyfun.stimuli import TrackerBinom, TrackerDealer, TrackerMHW, TrackerUD def callback(event_type, value=None, timestamp=None): @@ -12,11 +12,20 @@ def callback(event_type, value=None, timestamp=None): print(event_type, value, timestamp) -std_kwargs = dict(output_dir=None, full_screen=False, window_size=(1, 1), - participant='foo', session='01', stim_db=0.0, noise_db=0.0, - trigger_controller='dummy', response_device='keyboard', - audio_controller='sound_card', - verbose=True, version='dev') +std_kwargs = dict( + output_dir=None, + full_screen=False, + window_size=(1, 1), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + trigger_controller="dummy", + response_device="keyboard", + audio_controller="sound_card", + verbose=True, + version="dev", +) @pytest.mark.timeout(15) @@ -24,8 +33,9 @@ def callback(event_type, value=None, timestamp=None): def test_tracker_ud(hide_window): """Test TrackerUD""" import matplotlib.pyplot as plt + tr = TrackerUD(callback, 3, 1, 1, 1, np.inf, 10, 1) - with ExperimentController('test', **std_kwargs) as ec: + with ExperimentController("test", **std_kwargs) as ec: tr = TrackerUD(ec, 3, 1, 1, 1, np.inf, 10, 1) tr = TrackerUD(None, 3, 1, 1, 1, 10, np.inf, 1) rand = np.random.RandomState(0) @@ -70,95 +80,106 @@ def test_tracker_ud(hide_window): tr.check_valid(2) # bad callback type - with pytest.raises(TypeError, - match="callback must be a callable, None, or an"): - TrackerUD('foo', 3, 1, 1, 1, 10, np.inf, 1) + with pytest.raises(TypeError, match="callback must be a callable, None, or an"): + TrackerUD("foo", 3, 1, 1, 1, 10, np.inf, 1) # test dynamic step size and error conditions - tr = TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1, - change_indices=[2]) + tr = TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1, change_indices=[2]) tr.respond(True) - tr = TrackerUD(None, 1, 1, 0.75, 0.75, np.inf, 9, 1, - x_min=0, x_max=2) + tr = TrackerUD(None, 1, 1, 0.75, 0.75, np.inf, 9, 1, x_min=0, x_max=2) responses = [True, True, True, False, False, False, False, True, False] - with pytest.warns(UserWarning, match='exceeded x_min'): + with pytest.warns(UserWarning, match="exceeded x_min"): for r in responses: # run long enough to encounter change_indices tr.respond(r) - assert(tr.check_valid(1)) # make sure checking validity is good - assert(not tr.check_valid(3)) - with pytest.raises(ValueError, - match="with reversals attempting to exceed x_min"): + assert tr.check_valid(1) # make sure checking validity is good + assert not tr.check_valid(3) + with pytest.raises(ValueError, match="with reversals attempting to exceed x_min"): tr.threshold(1) tr.threshold(3) assert_equal(tr.n_trials, tr.stop_trials) # run tests with ignore too--should generate warnings, but no error - tr = TrackerUD(None, 1, 1, 0.75, 0.25, np.inf, 8, 1, - x_min=0, x_max=2, repeat_limit='ignore') + tr = TrackerUD( + None, 1, 1, 0.75, 0.25, np.inf, 8, 1, x_min=0, x_max=2, repeat_limit="ignore" + ) responses = [False, True, False, False, True, True, False, True] - with pytest.warns(UserWarning, match='exceeded x_min'): + with pytest.warns(UserWarning, match="exceeded x_min"): for r in responses: # run long enough to encounter change_indices tr.respond(r) tr.threshold(0) # bad stop_trials - with pytest.raises(ValueError, - match="stop_trials must be an integer or np.inf"): - TrackerUD(None, 3, 1, 1, 1, 10, 'foo', 1) + with pytest.raises(ValueError, match="stop_trials must be an integer or np.inf"): + TrackerUD(None, 3, 1, 1, 1, 10, "foo", 1) # bad stop_reversals - with pytest.raises(ValueError, - match="stop_reversals must be an integer or np.inf"): - TrackerUD(None, 3, 1, 1, 1, 'foo', 10, 1) + with pytest.raises(ValueError, match="stop_reversals must be an integer or np.inf"): + TrackerUD(None, 3, 1, 1, 1, "foo", 10, 1) # change_indices too long - with pytest.raises(ValueError, - match="one element longer than change_indices"): - TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1, - change_indices=[1, 2]) + with pytest.raises(ValueError, match="one element longer than change_indices"): + TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1, change_indices=[1, 2]) # step_size_up length mismatch - with pytest.raises(ValueError, - match="step_size_up is not scalar it must be one"): + with pytest.raises(ValueError, match="step_size_up is not scalar it must be one"): TrackerUD(None, 3, 1, [1], [1, 0.5], 10, np.inf, 1, change_indices=[2]) # step_size_down length mismatch - with pytest.raises(ValueError, - match="If step_size_down is not scalar it must be one"): + with pytest.raises( + ValueError, match="If step_size_down is not scalar it must be one" + ): TrackerUD(None, 3, 1, [1, 0.5], [1], 10, np.inf, 1, change_indices=[2]) # bad change_rule - with pytest.raises(ValueError, - match="must be either 'trials' or 'reversals"): - TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1, - change_indices=[2], change_rule='foo') + with pytest.raises(ValueError, match="must be either 'trials' or 'reversals"): + TrackerUD( + None, + 3, + 1, + [1, 0.5], + [1, 0.5], + 10, + np.inf, + 1, + change_indices=[2], + change_rule="foo", + ) # no change_indices (i.e. change_indices=None) - with pytest.raises(ValueError, - match="If step_size_up is longer than 1, you must"): + with pytest.raises(ValueError, match="If step_size_up is longer than 1, you must"): TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, 1) # start_value scalar type checking with pytest.raises(TypeError, match="start_value must be a scalar"): - TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, [9, 5], - change_indices=[2]) + TrackerUD( + None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, [9, 5], change_indices=[2] + ) with pytest.raises(TypeError, match="start_value must be a scalar"): - TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, None, - change_indices=[2]) + TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, None, change_indices=[2]) # test with multiple change_indices - tr = TrackerUD(None, 3, 1, [3, 2, 1], [3, 2, 1], 10, np.inf, 1, - change_indices=[2, 4], change_rule='reversals') + tr = TrackerUD( + None, + 3, + 1, + [3, 2, 1], + [3, 2, 1], + 10, + np.inf, + 1, + change_indices=[2, 4], + change_rule="reversals", + ) @requires_opengl21 def test_tracker_binom(hide_window): """Test TrackerBinom""" tr = TrackerBinom(callback, 0.05, 0.1, 5) - with ExperimentController('test', **std_kwargs) as ec: + with ExperimentController("test", **std_kwargs) as ec: tr = TrackerBinom(ec, 0.05, 0.1, 5) tr = TrackerBinom(None, 0.05, 0.5, 2, stop_early=False) while not tr.stopped: tr.respond(False) - assert(tr.n_trials == 2) - assert(not tr.success) + assert tr.n_trials == 2 + assert not tr.success tr = TrackerBinom(None, 0.05, 0.5, 1000) while not tr.stopped: @@ -167,7 +188,7 @@ def test_tracker_binom(hide_window): tr = TrackerBinom(None, 0.05, 0.5, 1000, 100) while not tr.stopped: tr.respond(True) - assert(tr.n_trials == 100) + assert tr.n_trials == 100 tr.alpha tr.chance @@ -191,29 +212,38 @@ def test_tracker_binom(hide_window): def test_tracker_dealer(): """Test TrackerDealer.""" # test TrackerDealer with TrackerUD - trackers = [[TrackerUD(None, 1, 1, 0.06, 0.02, 20, np.inf, - 1) for _ in range(2)] for _ in range(3)] + trackers = [ + [TrackerUD(None, 1, 1, 0.06, 0.02, 20, np.inf, 1) for _ in range(2)] + for _ in range(3) + ] dealer_ud = TrackerDealer(callback, trackers) # can't respond to a trial twice dealer_ud.next() dealer_ud.respond(True) - with pytest.raises(RuntimeError, - match="You must get a trial before you can respond."): + with pytest.raises( + RuntimeError, match="You must get a trial before you can respond." + ): dealer_ud.respond(True) dealer_ud = TrackerDealer(callback, np.array(trackers)) # can't respond before you pick a tracker and get a trial - with pytest.raises(RuntimeError, - match="You must get a trial before you can respond."): + with pytest.raises( + RuntimeError, match="You must get a trial before you can respond." + ): dealer_ud.respond(True) rand = np.random.RandomState(0) for sub, x_current in dealer_ud: dealer_ud.respond(rand.rand() < x_current) - assert(np.abs(dealer_ud.trackers[0, 0].n_reversals - - dealer_ud.trackers[1, 0].n_reversals) <= 1) + assert ( + np.abs( + dealer_ud.trackers[0, 0].n_reversals + - dealer_ud.trackers[1, 0].n_reversals + ) + <= 1 + ) # test array-like indexing dealer_ud.trackers[0] @@ -225,39 +255,41 @@ def test_tracker_dealer(): dealer_ud.history(True) # bad rand type - trackers = [TrackerUD(None, 1, 1, 0.06, 0.02, 20, 50, 1) - for _ in range(2)] + trackers = [TrackerUD(None, 1, 1, 0.06, 0.02, 20, 50, 1) for _ in range(2)] with pytest.raises(TypeError, match="argument"): TrackerDealer(trackers, rand=1) # test TrackerDealer with TrackerBinom - trackers = [TrackerBinom(None, 0.05, 0.5, 50, stop_early=False) - for _ in range(2)] # start_value scalar type checking + trackers = [ + TrackerBinom(None, 0.05, 0.5, 50, stop_early=False) for _ in range(2) + ] # start_value scalar type checking with pytest.raises(TypeError, match="start_value must be a scalar"): - TrackerUD(None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, [9, 5], - change_indices=[2]) - dealer_binom = TrackerDealer(callback, trackers, pace_rule='trials') + TrackerUD( + None, 3, 1, [1, 0.5], [1, 0.5], 10, np.inf, [9, 5], change_indices=[2] + ) + dealer_binom = TrackerDealer(callback, trackers, pace_rule="trials") for sub, x_current in dealer_binom: dealer_binom.respond(True) # if you're dealing from TrackerBinom, you can't use stop_early feature - trackers = [TrackerBinom(None, 0.05, 0.5, 50, stop_early=True, x_current=3) - for _ in range(2)] - with pytest.raises(ValueError, - match="be False to deal trials from a TrackerBinom"): - TrackerDealer(callback, trackers, 1, 'trials') + trackers = [ + TrackerBinom(None, 0.05, 0.5, 50, stop_early=True, x_current=3) + for _ in range(2) + ] + with pytest.raises(ValueError, match="be False to deal trials from a TrackerBinom"): + TrackerDealer(callback, trackers, 1, "trials") # if you're dealing from TrackerBinom, you can't use reversals to pace - with pytest.raises(ValueError, - match="be False to deal trials from a TrackerBinom"): + with pytest.raises(ValueError, match="be False to deal trials from a TrackerBinom"): TrackerDealer(callback, trackers, 1) def test_tracker_mhw(hide_window): """Test TrackerMHW""" import matplotlib.pyplot as plt + tr = TrackerMHW(callback, 0, 120) - with ExperimentController('test', **std_kwargs) as ec: + with ExperimentController("test", **std_kwargs) as ec: tr = TrackerMHW(ec, 0, 120) tr = TrackerMHW(None, 0, 120) rand = np.random.RandomState(0) @@ -268,16 +300,32 @@ def test_tracker_mhw(hide_window): rand = np.random.RandomState(0) while not tr.stopped: tr.respond(int(rand.rand() * 100) < tr.x_current) - assert(tr.check_valid(1)) # make sure checking validity is good + assert tr.check_valid(1) # make sure checking validity is good # test responding after stopped with pytest.raises(RuntimeError, match="Tracker is stopped."): tr.respond(0) - for key in ('base_step', 'factor_down', 'factor_up_nr', 'start_value', - 'x_min', 'x_max', 'n_up_stop', 'repeat_limit', - 'n_correct_levels', 'threshold', 'stopped', 'x', 'x_current', - 'responses', 'n_trials', 'n_reversals', 'reversals', - 'reversal_inds', 'threshold_reached'): + for key in ( + "base_step", + "factor_down", + "factor_up_nr", + "start_value", + "x_min", + "x_max", + "n_up_stop", + "repeat_limit", + "n_correct_levels", + "threshold", + "stopped", + "x", + "x_current", + "responses", + "n_trials", + "n_reversals", + "reversals", + "reversal_inds", + "threshold_reached", + ): assert hasattr(tr, key) fig, ax, lines = tr.plot() @@ -289,45 +337,42 @@ def test_tracker_mhw(hide_window): plt.close(fig) # start_value scalar type checking - with pytest.raises(TypeError, match='start_value must be a scalar'): + with pytest.raises(TypeError, match="start_value must be a scalar"): TrackerMHW(None, 0, 120, 5, 2, 4, [5, 4], 2) # n_up_stop integer check - with pytest.raises(TypeError, match='n_up_stop must be an integer'): + with pytest.raises(TypeError, match="n_up_stop must be an integer"): TrackerMHW(None, 0, 120, 5, 2, 4, 40, 1.5) # x_min integer or float check - with pytest.raises(TypeError, match='x_min must be a float or integer'): - TrackerMHW(None, '5', 120, 5, 2, 4, 40, 2) + with pytest.raises(TypeError, match="x_min must be a float or integer"): + TrackerMHW(None, "5", 120, 5, 2, 4, 40, 2) # x_max integer or float check - with pytest.raises(TypeError, match='x_max must be a float or integer'): - TrackerMHW(None, 0, '90', 5, 2, 4, 40, 2) + with pytest.raises(TypeError, match="x_max must be a float or integer"): + TrackerMHW(None, 0, "90", 5, 2, 4, 40, 2) # start_value is a multiple of base_step - with pytest.raises(ValueError, - match='start_value must be a multiple of base_step'): + with pytest.raises(ValueError, match="start_value must be a multiple of base_step"): TrackerMHW(None, 0, 120, 5, 2, 4, 41, 2) # x_min factor check - with pytest.raises(ValueError, - match='x_min must be a multiple of base_step'): + with pytest.raises(ValueError, match="x_min must be a multiple of base_step"): TrackerMHW(None, 2, 120, 5, 2, 4, 40, 2) # x_max factor check - with pytest.raises(ValueError, - match='x_max must be a multiple of base_step'): + with pytest.raises(ValueError, match="x_max must be a multiple of base_step"): TrackerMHW(None, 0, 93, 5, 2, 4, 40, 2) tr = TrackerMHW(None, 0, 120, 5, 2, 4, 10, 2) responses = [True, True, True, True] - with pytest.warns(UserWarning, match='exceeded x_min or x_max bounds'): + with pytest.warns(UserWarning, match="exceeded x_min or x_max bounds"): for r in responses: tr.respond(r) tr = TrackerMHW(None, 0, 120, 5, 2, 4, 40, 2) responses = [False, False, False, False, False] - with pytest.warns(UserWarning, match='exceeded x_min or x_max bounds'): + with pytest.warns(UserWarning, match="exceeded x_min or x_max bounds"): for r in responses: tr.respond(r) - assert(not tr.check_valid(3)) + assert not tr.check_valid(3) tr = TrackerMHW(None, 0, 120, 5, 2, 4, 40, 2) responses = [False, False, False, False, True, False, False, True] - with pytest.warns(UserWarning, match='exceeded x_min or x_max bounds'): + with pytest.warns(UserWarning, match="exceeded x_min or x_max bounds"): for r in responses: tr.respond(r) diff --git a/expyfun/tests/test_docstring_parameters.py b/expyfun/tests/test_docstring_parameters.py index db509a2e..0a8eeea7 100644 --- a/expyfun/tests/test_docstring_parameters.py +++ b/expyfun/tests/test_docstring_parameters.py @@ -1,11 +1,11 @@ import inspect -from inspect import getsource import os.path as op -from pkgutil import walk_packages import re import sys -from unittest import SkipTest import warnings +from inspect import getsource +from pkgutil import walk_packages +from unittest import SkipTest import pytest @@ -14,12 +14,12 @@ public_modules = [ # the list of modules users need to access for all functionality - 'expyfun', - 'expyfun.stimuli', - 'expyfun.io', - 'expyfun.visual', - 'expyfun.codeblocks', - 'expyfun.analyze', + "expyfun", + "expyfun.stimuli", + "expyfun.io", + "expyfun.visual", + "expyfun.codeblocks", + "expyfun.analyze", ] @@ -31,7 +31,7 @@ def requires_numpydoc(fun): have = False else: have = True - return pytest.mark.skipif(not have, reason='Requires numpydoc')(fun) + return pytest.mark.skipif(not have, reason="Requires numpydoc")(fun) def get_name(func, cls=None): @@ -43,70 +43,72 @@ def get_name(func, cls=None): if cls is not None: parts.append(cls.__name__) parts.append(func.__name__) - return '.'.join(parts) + return ".".join(parts) # functions to ignore args / docstring of -docstring_ignores = [ -] +docstring_ignores = [] char_limit = 800 # XX eventually we should probably get this lower -docstring_length_ignores = [ -] -tab_ignores = [ -] +docstring_length_ignores = [] +tab_ignores = [] _doc_special_members = [] def check_parameters_match(func, doc=None, cls=None): """Check docstring, return list of incorrect results.""" from numpydoc import docscrape + incorrect = [] name_ = get_name(func, cls=cls) - if not name_.startswith('expyfun.') or \ - name_.startswith('expyfun._externals'): + if not name_.startswith("expyfun.") or name_.startswith("expyfun._externals"): return incorrect if inspect.isdatadescriptor(func): return incorrect args = _get_args(func) # drop self - if len(args) > 0 and args[0] == 'self': + if len(args) > 0 and args[0] == "self": args = args[1:] if doc is None: with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") try: doc = docscrape.FunctionDoc(func) except Exception as exp: - incorrect += [name_ + ' parsing error: ' + str(exp)] + incorrect += [name_ + " parsing error: " + str(exp)] return incorrect if len(w): - raise RuntimeError('Error for %s:\n%s' % (name_, w[0])) + raise RuntimeError("Error for %s:\n%s" % (name_, w[0])) # check set - parameters = doc['Parameters'] + parameters = doc["Parameters"] # clean up some docscrape output: - parameters = [[p[0].split(':')[0].strip('` '), p[2]] - for p in parameters] - parameters = [p for p in parameters if '*' not in p[0]] + parameters = [[p[0].split(":")[0].strip("` "), p[2]] for p in parameters] + parameters = [p for p in parameters if "*" not in p[0]] param_names = [p[0] for p in parameters] if len(param_names) != len(args): - bad = str(sorted(list(set(param_names) - set(args)) + - list(set(args) - set(param_names)))) - if not any(re.match(d, name_) for d in docstring_ignores) and \ - 'deprecation_wrapped' not in func.__code__.co_name: - incorrect += [name_ + ' arg mismatch: ' + bad] + bad = str( + sorted( + list(set(param_names) - set(args)) + list(set(args) - set(param_names)) + ) + ) + if ( + not any(re.match(d, name_) for d in docstring_ignores) + and "deprecation_wrapped" not in func.__code__.co_name + ): + incorrect += [name_ + " arg mismatch: " + bad] else: for n1, n2 in zip(param_names, args): if n1 != n2: - incorrect += [name_ + ' ' + n1 + ' != ' + n2] + incorrect += [name_ + " " + n1 + " != " + n2] for param_name, desc in parameters: - desc = '\n'.join(desc) - full_name = name_ + '::' + param_name + desc = "\n".join(desc) + full_name = name_ + "::" + param_name if full_name in docstring_length_ignores: assert len(desc) > char_limit # assert it actually needs to be elif len(desc) > char_limit: - incorrect += ['%s too long (%d > %d chars)' - % (full_name, len(desc), char_limit)] + incorrect += [ + "%s too long (%d > %d chars)" % (full_name, len(desc), char_limit) + ] return incorrect @@ -114,37 +116,39 @@ def check_parameters_match(func, doc=None, cls=None): def test_docstring_parameters(): """Test module docstring formatting.""" from numpydoc import docscrape + incorrect = [] for name in public_modules: with warnings.catch_warnings(record=True): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") module = __import__(name, globals()) - for submod in name.split('.')[1:]: + for submod in name.split(".")[1:]: module = getattr(module, submod) classes = inspect.getmembers(module, inspect.isclass) for cname, cls in classes: - if cname.startswith('_') and cname not in _doc_special_members: + if cname.startswith("_") and cname not in _doc_special_members: continue with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") cdoc = docscrape.ClassDoc(cls) for ww in w: - if 'Using or importing the ABCs' not in str(ww.message): - raise RuntimeError('Error for __init__ of %s in %s:\n%s' - % (cls, name, ww)) - if hasattr(cls, '__init__'): + if "Using or importing the ABCs" not in str(ww.message): + raise RuntimeError( + "Error for __init__ of %s in %s:\n%s" % (cls, name, ww) + ) + if hasattr(cls, "__init__"): incorrect += check_parameters_match(cls.__init__, cdoc, cls) for method_name in cdoc.methods: method = getattr(cls, method_name) incorrect += check_parameters_match(method, cls=cls) - if hasattr(cls, '__call__'): + if hasattr(cls, "__call__"): incorrect += check_parameters_match(cls.__call__, cls=cls) functions = inspect.getmembers(module, inspect.isfunction) for fname, func in functions: - if fname.startswith('_'): + if fname.startswith("_"): continue incorrect += check_parameters_match(func) - msg = '\n' + '\n'.join(sorted(list(set(incorrect)))) + msg = "\n" + "\n".join(sorted(list(set(incorrect)))) if len(incorrect) > 0: raise AssertionError(msg) @@ -153,28 +157,27 @@ def test_tabs(): """Test that there are no tabs in our source files.""" # avoid importing modules that require mayavi if mayavi is not installed ignore = tab_ignores[:] - for importer, modname, ispkg in walk_packages(expyfun.__path__, - prefix='expyfun.'): + for importer, modname, ispkg in walk_packages(expyfun.__path__, prefix="expyfun."): if not ispkg and modname not in ignore: # mod = importlib.import_module(modname) # not py26 compatible! try: with warnings.catch_warnings(record=True): - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") __import__(modname) except Exception: # can't import properly continue mod = sys.modules[modname] try: source = getsource(mod) - except IOError: # user probably should have run "make clean" + except OSError: # user probably should have run "make clean" continue - assert '\t' not in source, ('"%s" has tabs, please remove them ' - 'or add it to the ignore list' - % modname) + assert "\t" not in source, ( + '"%s" has tabs, please remove them ' + "or add it to the ignore list" % modname + ) -documented_ignored_mods = ( -) +documented_ignored_mods = () documented_ignored_names = """ add_pad fetch_data_file @@ -187,48 +190,51 @@ def test_tabs(): run_subprocess set_log_file verbose -""".split('\n') +""".split("\n") def test_documented(): """Test that public functions and classes are documented.""" # skip modules that require mayavi if mayavi is not installed public_modules_ = public_modules[:] - doc_file = op.abspath(op.join(op.dirname(__file__), '..', '..', 'doc', - 'python_reference.rst')) + doc_file = op.abspath( + op.join(op.dirname(__file__), "..", "..", "doc", "python_reference.rst") + ) if not op.isfile(doc_file): - raise SkipTest('Documentation file not found: %s' % doc_file) + raise SkipTest("Documentation file not found: %s" % doc_file) known_names = list() - with open(doc_file, 'rb') as fid: + with open(doc_file, "rb") as fid: for line in fid: - line = line.decode('utf-8') - if not line.startswith(' '): # at least two spaces + line = line.decode("utf-8") + if not line.startswith(" "): # at least two spaces continue line = line.split() - if len(line) == 1 and line[0] != ':': - known_names.append(line[0].split('.')[-1]) + if len(line) == 1 and line[0] != ":": + known_names.append(line[0].split(".")[-1]) known_names = set(known_names) missing = [] for name in public_modules_: with warnings.catch_warnings(record=True): # traits warnings - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") module = __import__(name, globals()) - for submod in name.split('.')[1:]: + for submod in name.split(".")[1:]: module = getattr(module, submod) classes = inspect.getmembers(module, inspect.isclass) functions = inspect.getmembers(module, inspect.isfunction) checks = list(classes) + list(functions) for name, cf in checks: - if not name.startswith('_') and name not in known_names: + if not name.startswith("_") and name not in known_names: from_mod = inspect.getmodule(cf).__name__ - if (from_mod.startswith('expyfun') and - not from_mod.startswith('expyfun._externals') and - not any(from_mod.startswith(x) - for x in documented_ignored_mods) and - name not in documented_ignored_names): - missing.append('%s (%s.%s)' % (name, from_mod, name)) + if ( + from_mod.startswith("expyfun") + and not from_mod.startswith("expyfun._externals") + and not any(from_mod.startswith(x) for x in documented_ignored_mods) + and name not in documented_ignored_names + ): + missing.append("%s (%s.%s)" % (name, from_mod, name)) if len(missing) > 0: - raise AssertionError('\n\nFound new public members missing from ' - 'doc/python_reference.rst:\n\n* ' + - '\n* '.join(sorted(set(missing)))) + raise AssertionError( + "\n\nFound new public members missing from " + "doc/python_reference.rst:\n\n* " + "\n* ".join(sorted(set(missing))) + ) diff --git a/expyfun/tests/test_experiment_controller.py b/expyfun/tests/test_experiment_controller.py index 22ef01d2..a9332133 100644 --- a/expyfun/tests/test_experiment_controller.py +++ b/expyfun/tests/test_experiment_controller.py @@ -1,27 +1,43 @@ +import sys +import warnings from contextlib import contextmanager from copy import deepcopy from functools import partial -import sys -import warnings import numpy as np -from numpy.testing import assert_equal import pytest -from numpy.testing import assert_allclose +from numpy.testing import assert_allclose, assert_equal -from expyfun import ExperimentController, visual, _experiment_controller +from expyfun import ExperimentController, _experiment_controller, visual from expyfun._experiment_controller import _get_dev_db -from expyfun._utils import (_TempDir, fake_button_press, _check_skip_backend, - fake_mouse_click, requires_opengl21, - _wait_secs as wait_secs, known_config_types, - _new_pyglet) from expyfun._sound_controllers._sound_controller import _SOUND_CARD_KEYS +from expyfun._utils import ( + _check_skip_backend, + _new_pyglet, + _TempDir, + fake_button_press, + fake_mouse_click, + known_config_types, + requires_opengl21, +) +from expyfun._utils import ( + _wait_secs as wait_secs, +) from expyfun.stimuli import get_tdt_rates -std_args = ['test'] # experiment name -std_kwargs = dict(output_dir=None, full_screen=False, window_size=(8, 8), - participant='foo', session='01', stim_db=0.0, noise_db=0.0, - verbose=True, version='dev') +std_args = ["test"] # experiment name +std_kwargs = dict( + output_dir=None, + full_screen=False, + window_size=(8, 8), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + verbose=True, + version="dev", +) +SAFE_DELAY = 0.5 if sys.platform.startswith("win") else 0.2 def dummy_print(string): @@ -29,44 +45,43 @@ def dummy_print(string): print(string) -@pytest.mark.parametrize('ws', [(2, 1), (1, 1)]) +@pytest.mark.parametrize("ws", [(2, 1), (1, 1)]) def test_unit_conversions(hide_window, ws): """Test unit conversions.""" kwargs = deepcopy(std_kwargs) - kwargs['stim_fs'] = 44100 - kwargs['window_size'] = ws + kwargs["stim_fs"] = 44100 + kwargs["window_size"] = ws with ExperimentController(*std_args, **kwargs) as ec: verts = np.random.rand(2, 4) - for to in ['norm', 'pix', 'deg', 'cm']: - for fro in ['norm', 'pix', 'deg', 'cm']: + for to in ["norm", "pix", "deg", "cm"]: + for fro in ["norm", "pix", "deg", "cm"]: v2 = ec._convert_units(verts, fro, to) v2 = ec._convert_units(v2, to, fro) assert_allclose(verts, v2) # test that degrees yield equiv. pixels in both directions verts = np.ones((2, 1)) - v0 = ec._convert_units(verts, 'deg', 'pix') + v0 = ec._convert_units(verts, "deg", "pix") verts = np.zeros((2, 1)) - v1 = ec._convert_units(verts, 'deg', 'pix') + v1 = ec._convert_units(verts, "deg", "pix") v2 = v0 - v1 # must check deviation from zero position assert_allclose(v2[0], v2[1]) - pytest.raises(ValueError, ec._convert_units, verts, 'deg', 'nothing') - pytest.raises(RuntimeError, ec._convert_units, verts[0], 'deg', 'pix') + pytest.raises(ValueError, ec._convert_units, verts, "deg", "nothing") + pytest.raises(RuntimeError, ec._convert_units, verts[0], "deg", "pix") def test_validate_audio(hide_window): """Test that validate_audio can pass through samples.""" - with ExperimentController(*std_args, suppress_resamp=True, - **std_kwargs) as ec: + with ExperimentController(*std_args, suppress_resamp=True, **std_kwargs) as ec: ec.set_stim_db(_get_dev_db(ec.audio_type) - 40) # 0.01 RMS - assert ec._stim_scaler == 1. + assert ec._stim_scaler == 1.0 for shape in ((1000,), (1, 1000), (2, 1000)): samples_in = np.zeros(shape) samples_out = ec._validate_audio(samples_in) assert samples_out.shape == (1000, 2) assert samples_out.dtype == np.float32 assert samples_out is not samples_in - for order in 'CF': + for order in "CF": samples_in = np.zeros((2, 1000), dtype=np.float32, order=order) samples_out = ec._validate_audio(samples_in) assert samples_out.shape == samples_in.shape[::-1] @@ -77,17 +92,13 @@ def test_validate_audio(hide_window): def test_data_line(hide_window): """Test writing of data lines.""" - entries = [['foo'], - ['bar', 'bar\tbar'], - ['bar2', r'bar\tbar'], - ['fb', None, -0.5]] + entries = [["foo"], ["bar", "bar\tbar"], ["bar2", r"bar\tbar"], ["fb", None, -0.5]] # this is what should be written to the file for each one - goal_vals = ['None', 'bar\\tbar', 'bar\\\\tbar', 'None'] + goal_vals = ["None", "bar\\tbar", "bar\\\\tbar", "None"] assert_equal(len(entries), len(goal_vals)) temp_dir = _TempDir() with std_kwargs_changed(output_dir=temp_dir): - with ExperimentController(*std_args, stim_fs=44100, - **std_kwargs) as ec: + with ExperimentController(*std_args, stim_fs=44100, **std_kwargs) as ec: for ent in entries: ec.write_data_line(*ent) fname = ec._data_file.name @@ -95,18 +106,17 @@ def test_data_line(hide_window): lines = fid.readlines() # check the header assert_equal(len(lines), len(entries) + 4) # header, colnames, flip, stop - assert_equal(lines[0][0], '#') # first line is a comment - for x in ['timestamp', 'event', 'value']: # second line is col header - assert (x in lines[1]) - assert ('flip' in lines[2]) # ec.__init__ ends with a flip - assert ('stop' in lines[-1]) # last line is stop (from __exit__) - outs = lines[1].strip().split('\t') - assert (all(l1 == l2 for l1, l2 in zip(outs, ['timestamp', - 'event', 'value']))) + assert_equal(lines[0][0], "#") # first line is a comment + for x in ["timestamp", "event", "value"]: # second line is col header + assert x in lines[1] + assert "flip" in lines[2] # ec.__init__ ends with a flip + assert "stop" in lines[-1] # last line is stop (from __exit__) + outs = lines[1].strip().split("\t") + assert all(l1 == l2 for l1, l2 in zip(outs, ["timestamp", "event", "value"])) # check the entries ts = [] for line, ent, gv in zip(lines[3:], entries, goal_vals): - outs = line.strip().split('\t') + outs = line.strip().split("\t") assert_equal(len(outs), 3) # check timestamping if len(ent) == 3 and ent[2] is not None: @@ -119,7 +129,7 @@ def test_data_line(hide_window): assert_equal(outs[2], gv) # make sure we got monotonically increasing timestamps ts = np.array(ts) - assert (np.all(ts[1:] >= ts[:-1])) + assert np.all(ts[1:] >= ts[:-1]) @contextmanager @@ -138,133 +148,196 @@ def std_kwargs_changed(**kwargs): def test_degenerate(): """Test degenerate EC conditions.""" - pytest.raises(TypeError, ExperimentController, *std_args, - audio_controller=1, stim_fs=44100, **std_kwargs) - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller='foo', stim_fs=44100, **std_kwargs) - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller=dict(TYPE='foo'), stim_fs=44100, - **std_kwargs) + pytest.raises( + TypeError, + ExperimentController, + *std_args, + audio_controller=1, + stim_fs=44100, + **std_kwargs, + ) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller="foo", + stim_fs=44100, + **std_kwargs, + ) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller=dict(TYPE="foo"), + stim_fs=44100, + **std_kwargs, + ) # monitor, etc. - pytest.raises(TypeError, ExperimentController, *std_args, - monitor='foo', **std_kwargs) - pytest.raises(KeyError, ExperimentController, *std_args, - monitor=dict(), **std_kwargs) - pytest.raises(ValueError, ExperimentController, *std_args, - response_device='foo', **std_kwargs) - with std_kwargs_changed(window_size=10.): - pytest.raises(ValueError, ExperimentController, *std_args, - **std_kwargs) - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller='sound_card', response_device='tdt', - **std_kwargs) - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller='pyglet', response_device='keyboard', - trigger_controller='sound_card', **std_kwargs) + pytest.raises( + TypeError, ExperimentController, *std_args, monitor="foo", **std_kwargs + ) + pytest.raises( + KeyError, ExperimentController, *std_args, monitor=dict(), **std_kwargs + ) + pytest.raises( + ValueError, ExperimentController, *std_args, response_device="foo", **std_kwargs + ) + with std_kwargs_changed(window_size=10.0): + pytest.raises(ValueError, ExperimentController, *std_args, **std_kwargs) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller="sound_card", + response_device="tdt", + **std_kwargs, + ) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller="pyglet", + response_device="keyboard", + trigger_controller="sound_card", + **std_kwargs, + ) # test type checking for 'session' with std_kwargs_changed(session=1): - pytest.raises(TypeError, ExperimentController, *std_args, - audio_controller='sound_card', stim_fs=44100, - **std_kwargs) + pytest.raises( + TypeError, + ExperimentController, + *std_args, + audio_controller="sound_card", + stim_fs=44100, + **std_kwargs, + ) # test value checking for trigger controller - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller='sound_card', trigger_controller='foo', - stim_fs=44100, **std_kwargs) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller="sound_card", + trigger_controller="foo", + stim_fs=44100, + **std_kwargs, + ) # test value checking for RMS checker - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller='sound_card', check_rms=True, stim_fs=44100, - **std_kwargs) - - -@pytest.mark.timeout(20) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller="sound_card", + check_rms=True, + stim_fs=44100, + **std_kwargs, + ) + + +@pytest.mark.timeout(120) def test_ec(ac, hide_window, monkeypatch): """Test EC methods.""" - if ac == 'tdt': - rd, tc, fs = 'tdt', 'tdt', get_tdt_rates()['25k'] - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller=dict(TYPE=ac, TDT_MODEL='foo'), - **std_kwargs) + if ac == "tdt": + rd, tc, fs = "tdt", "tdt", get_tdt_rates()["25k"] + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller=dict(TYPE=ac, TDT_MODEL="foo"), + **std_kwargs, + ) else: _check_skip_backend(ac) - rd, tc, fs = 'keyboard', 'dummy', 44100 + rd, tc, fs = "keyboard", "dummy", 44100 for suppress in (True, False): with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") with ExperimentController( - *std_args, audio_controller=ac, response_device=rd, - trigger_controller=tc, stim_fs=100., - suppress_resamp=suppress, **std_kwargs) as ec: + *std_args, + audio_controller=ac, + response_device=rd, + trigger_controller=tc, + stim_fs=100.0, + suppress_resamp=suppress, + **std_kwargs, + ) as ec: pass - w = [ww for ww in w if 'TDT is in dummy mode' in str(ww.message)] - assert len(w) == (1 if ac == 'tdt' else 0) - SAFE_DELAY = 0.2 + w = [ww for ww in w if "TDT is in dummy mode" in str(ww.message)] + assert len(w) == (1 if ac == "tdt" else 0) with ExperimentController( - *std_args, audio_controller=ac, response_device=rd, - trigger_controller=tc, stim_fs=fs, **std_kwargs) as ec: - assert (ec.participant == std_kwargs['participant']) - assert (ec.session == std_kwargs['session']) - assert (ec.exp_name == std_args[0]) + *std_args, + audio_controller=ac, + response_device=rd, + trigger_controller=tc, + stim_fs=fs, + **std_kwargs, + ) as ec: + assert ec.participant == std_kwargs["participant"] + assert ec.session == std_kwargs["session"] + assert ec.exp_name == std_args[0] stamp = ec.current_time - ec.write_data_line('hello') + ec.write_data_line("hello") ec.wait_until(stamp + 0.02) - ec.screen_prompt('test', 0.01, 0, None) - ec.screen_prompt('test', 0.01, 0, ['1']) - ec.screen_prompt(['test', 'ing'], 0.01, 0, ['1']) - ec.screen_prompt('test', 1e-3, click=True) - pytest.raises(ValueError, ec.screen_prompt, 'foo', np.inf, 0, []) + ec.screen_prompt("test", 0.01, 0, None) + ec.screen_prompt("test", 0.01, 0, ["1"]) + ec.screen_prompt(["test", "ing"], 0.01, 0, ["1"]) + ec.screen_prompt("test", 1e-3, click=True) + pytest.raises(ValueError, ec.screen_prompt, "foo", np.inf, 0, []) pytest.raises(TypeError, ec.screen_prompt, 3, 0.01, 0, None) assert_equal(ec.wait_one_press(0.01), (None, None)) - assert (ec.wait_one_press(0.01, timestamp=False) is None) + assert ec.wait_one_press(0.01, timestamp=False) is None assert_equal(ec.wait_for_presses(0.01), []) assert_equal(ec.wait_for_presses(0.01, timestamp=False), []) pytest.raises(ValueError, ec.get_presses) ec.listen_presses() assert_equal(ec.get_presses(), []) - assert_equal(ec.get_presses(kind='presses'), []) - pytest.raises(ValueError, ec.get_presses, kind='foo') - if rd == 'tdt': + assert_equal(ec.get_presses(kind="presses"), []) + pytest.raises(ValueError, ec.get_presses, kind="foo") + if rd == "tdt": # TDT does not have key release events, so should raise an # exception if asked for them: - pytest.raises(RuntimeError, ec.get_presses, kind='releases') - pytest.raises(RuntimeError, ec.get_presses, kind='both') + pytest.raises(RuntimeError, ec.get_presses, kind="releases") + pytest.raises(RuntimeError, ec.get_presses, kind="both") else: - assert_equal(ec.get_presses(kind='both'), []) - assert_equal(ec.get_presses(kind='releases'), []) + assert_equal(ec.get_presses(kind="both"), []) + assert_equal(ec.get_presses(kind="releases"), []) ec.set_noise_db(0) ec.set_stim_db(20) # test buffer data handling ec.set_rms_checking(None) ec.load_buffer([0, 0, 0, 0, 0, 0]) + ec.wait_secs(SAFE_DELAY) ec.load_buffer([]) + ec.wait_secs(SAFE_DELAY) pytest.raises(ValueError, ec.load_buffer, [0, 2, 0, 0, 0, 0]) ec.load_buffer(np.zeros((100,))) - with pytest.raises(ValueError, match='100 did not match .* count 2'): + with pytest.raises(ValueError, match="100 did not match .* count 2"): ec.load_buffer(np.zeros((100, 1))) - with pytest.raises(ValueError, match='100 did not match .* count 2'): + with pytest.raises(ValueError, match="100 did not match .* count 2"): ec.load_buffer(np.zeros((100, 2))) + ec.wait_secs(SAFE_DELAY) ec.load_buffer(np.zeros((1, 100))) ec.load_buffer(np.zeros((2, 100))) data = np.zeros(int(5e6), np.float32) # too long for TDT - if fs == get_tdt_rates()['25k']: + if fs == get_tdt_rates()["25k"]: pytest.raises(RuntimeError, ec.load_buffer, data) else: ec.load_buffer(data) ec.load_buffer(np.zeros(2)) del data - pytest.raises(ValueError, ec.stamp_triggers, 'foo') + pytest.raises(ValueError, ec.stamp_triggers, "foo") pytest.raises(ValueError, ec.stamp_triggers, 0) pytest.raises(ValueError, ec.stamp_triggers, 3) - pytest.raises(ValueError, ec.stamp_triggers, 1, check='foo') + pytest.raises(ValueError, ec.stamp_triggers, 1, check="foo") print(ec._tc) # test __repr__ - if tc == 'dummy': + if tc == "dummy": assert_equal(ec._tc._trigger_list, []) - ec.stamp_triggers(3, check='int4') + ec.stamp_triggers(3, check="int4") ec.stamp_triggers(2) ec.stamp_triggers([2, 4, 8]) - if tc == 'dummy': + if tc == "dummy": assert_equal(ec._tc._trigger_list, [3, 2, 2, 4, 8]) ec._tc._trigger_list = list() pytest.raises(ValueError, ec.load_buffer, np.zeros((100, 3))) @@ -272,38 +345,39 @@ def test_ec(ac, hide_window, monkeypatch): pytest.raises(ValueError, ec.load_buffer, np.zeros((1, 1, 1))) # test RMS checking - pytest.raises(ValueError, ec.set_rms_checking, 'foo') + pytest.raises(ValueError, ec.set_rms_checking, "foo") # click: RMS 0.0135, should pass 'fullfile' and fail 'windowed' click = np.zeros((int(ec.fs / 4),)) # 250 ms - click[len(click) // 2] = 1. - click[len(click) // 2 + 1] = -1. + click[len(click) // 2] = 1.0 + click[len(click) // 2 + 1] = -1.0 # noise: RMS 0.03, should fail both 'fullfile' and 'windowed' noise = np.random.normal(scale=0.03, size=(int(ec.fs / 4),)) ec.set_rms_checking(None) ec.load_buffer(click) # should go unchecked ec.load_buffer(noise) # should go unchecked - ec.set_rms_checking('wholefile') + ec.set_rms_checking("wholefile") ec.load_buffer(click) # should pass - with pytest.warns(UserWarning, match='exceeds stated'): + with pytest.warns(UserWarning, match="exceeds stated"): ec.load_buffer(noise) ec.wait_secs(SAFE_DELAY) - ec.set_rms_checking('windowed') - with pytest.warns(UserWarning, match='exceeds stated'): + ec.set_rms_checking("windowed") + with pytest.warns(UserWarning, match="exceeds stated"): ec.load_buffer(click) ec.wait_secs(SAFE_DELAY) - with pytest.warns(UserWarning, match='exceeds stated'): + with pytest.warns(UserWarning, match="exceeds stated"): ec.load_buffer(noise) - if ac != 'tdt': # too many samples there - monkeypatch.setattr(_experiment_controller, '_SLOW_LIMIT', 1) - with pytest.warns(UserWarning, match='samples is slow'): + if ac != "tdt": # too many samples there + monkeypatch.setattr(_experiment_controller, "_SLOW_LIMIT", 1) + with pytest.warns(UserWarning, match="samples is slow"): ec.load_buffer(np.zeros(2, dtype=np.float32)) - monkeypatch.setattr(_experiment_controller, '_SLOW_LIMIT', 1e7) + monkeypatch.setattr(_experiment_controller, "_SLOW_LIMIT", 1e7) ec.stop() ec.set_visible() ec.set_visible(False) - ec.call_on_every_flip(partial(dummy_print, 'called start stimuli')) + ec.call_on_every_flip(partial(dummy_print, "called start stimuli")) ec.wait_secs(SAFE_DELAY) + ec._ac_flush() # Note: we put some wait_secs in here because otherwise the delay in # play start (e.g. for trigdel and onsetdel) can @@ -316,43 +390,43 @@ def test_ec(ac, hide_window, monkeypatch): noise = np.random.normal(scale=0.01, size=(int(ec.fs),)) ec.load_buffer(noise) pytest.raises(RuntimeError, ec.start_stimulus) # order violation - assert (ec._playing is False) - if tc == 'dummy': + assert ec._playing is False + if tc == "dummy": assert_equal(ec._tc._trigger_list, []) - ec.start_stimulus(start_of_trial=False) # should work - if tc == 'dummy': + ec.start_stimulus(start_of_trial=False) # should work + if tc == "dummy": assert_equal(ec._tc._trigger_list, [1]) ec.wait_secs(SAFE_DELAY) - assert (ec._playing is True) - pytest.raises(RuntimeError, ec.trial_ok) # order violation + assert ec._playing is True + pytest.raises(RuntimeError, ec.trial_ok) # order violation ec.stop() - assert (ec._playing is False) + assert ec._playing is False # only binary for TTL - pytest.raises(KeyError, ec.identify_trial, ec_id='foo') # need ttl_id - pytest.raises(TypeError, ec.identify_trial, ec_id='foo', ttl_id='bar') - pytest.raises(ValueError, ec.identify_trial, ec_id='foo', ttl_id=[2]) - assert (ec._playing is False) - if tc == 'dummy': + pytest.raises(KeyError, ec.identify_trial, ec_id="foo") # need ttl_id + pytest.raises(TypeError, ec.identify_trial, ec_id="foo", ttl_id="bar") + pytest.raises(ValueError, ec.identify_trial, ec_id="foo", ttl_id=[2]) + assert ec._playing is False + if tc == "dummy": ec._tc._trigger_list = list() - ec.identify_trial(ec_id='foo', ttl_id=[0, 1]) - assert (ec._playing is False) + ec.identify_trial(ec_id="foo", ttl_id=[0, 1]) + assert ec._playing is False # # Second: start_stimuli # - pytest.raises(RuntimeError, ec.identify_trial, ec_id='foo', ttl_id=[0]) - assert (ec._playing is False) - pytest.raises(RuntimeError, ec.trial_ok) # order violation - assert (ec._playing is False) + pytest.raises(RuntimeError, ec.identify_trial, ec_id="foo", ttl_id=[0]) + assert ec._playing is False + pytest.raises(RuntimeError, ec.trial_ok) # order violation + assert ec._playing is False ec.start_stimulus(flip=False, when=-1) - if tc == 'dummy': + if tc == "dummy": assert_equal(ec._tc._trigger_list, [4, 8, 1]) - if ac != 'tdt': + if ac != "tdt": # dummy TDT version won't do this check properly, as # ec._ac._playing -> GetTagVal('playing') always gives False pytest.raises(RuntimeError, ec.play) # already played, must stop ec.wait_secs(SAFE_DELAY) ec.stop() - assert (ec._playing is False) + assert ec._playing is False # # Third: trial_ok # @@ -361,28 +435,28 @@ def test_ec(ac, hide_window, monkeypatch): ec.trial_ok() # double-check pytest.raises(RuntimeError, ec.start_stimulus) # order violation - ec.start_stimulus(start_of_trial=False) # should work - pytest.raises(RuntimeError, ec.trial_ok) # order violation + ec.start_stimulus(start_of_trial=False) # should work + pytest.raises(RuntimeError, ec.trial_ok) # order violation ec.wait_secs(SAFE_DELAY) ec.stop() - assert (ec._playing is False) + assert ec._playing is False ec.flip(-np.inf) - assert (ec._playing is False) + assert ec._playing is False ec.estimate_screen_fs() - assert (ec._playing is False) + assert ec._playing is False ec.play() ec.wait_secs(SAFE_DELAY) - assert (ec._playing is True) + assert ec._playing is True ec.call_on_every_flip(None) # something funny with the ring buffer in testing on OSX - if sys.platform != 'darwin': + if sys.platform != "darwin": ec.call_on_next_flip(ec.start_noise()) ec.flip() ec.wait_secs(SAFE_DELAY) ec.stop_noise() ec.stop() - assert (ec._playing is False) + assert ec._playing is False ec.stop_noise() ec.wait_secs(SAFE_DELAY) ec.start_stimulus(start_of_trial=False) @@ -404,150 +478,169 @@ def test_ec(ac, hide_window, monkeypatch): # we need to monkey-patch for old Pyglet try: from PIL import Image + Image.fromstring except AttributeError: Image.fromstring = None data = ec.screenshot() # HiDPI - sizes = [tuple(std_kwargs['window_size']), - tuple(np.array(std_kwargs['window_size']) * 2)] + sizes = [ + tuple(std_kwargs["window_size"]), + tuple(np.array(std_kwargs["window_size"]) * 2), + ] assert data.shape[:2] in sizes print(ec.fs) # test fs support wait_secs(0.01) test_pix = (11.3, 0.5, 110003) print(test_pix) # test __repr__ - assert all([x in repr(ec) for x in ['foo', '"test"', '01']]) + assert all([x in repr(ec) for x in ["foo", '"test"', "01"]]) ec.refocus() # smoke test for refocusing del ec -@pytest.mark.parametrize('screen_num', (None, 0)) -@pytest.mark.parametrize('monitor', ( - None, - dict(SCREEN_WIDTH=10, SCREEN_DISTANCE=10, SCREEN_SIZE_PIX=(1000, 1000)), -)) +@pytest.mark.parametrize("screen_num", (None, 0)) +@pytest.mark.parametrize( + "monitor", + ( + None, + dict(SCREEN_WIDTH=10, SCREEN_DISTANCE=10, SCREEN_SIZE_PIX=(1000, 1000)), + ), +) def test_screen_monitor(screen_num, monitor, hide_window): """Test screen and monitor option support.""" with ExperimentController( - *std_args, screen_num=screen_num, monitor=monitor, - **std_kwargs): + *std_args, screen_num=screen_num, monitor=monitor, **std_kwargs + ): pass full_kwargs = deepcopy(std_kwargs) - full_kwargs['full_screen'] = True - with pytest.raises(RuntimeError, match='resolution set incorrectly'): + full_kwargs["full_screen"] = True + with pytest.raises(RuntimeError, match="resolution set incorrectly"): ExperimentController(*std_args, **full_kwargs) - with pytest.raises(TypeError, match='must be a dict'): + with pytest.raises(TypeError, match="must be a dict"): ExperimentController(*std_args, monitor=1, **std_kwargs) - with pytest.raises(KeyError, match='is missing required keys'): + with pytest.raises(KeyError, match="is missing required keys"): ExperimentController(*std_args, monitor={}, **std_kwargs) def test_tdtpy_failure(hide_window): """Test that failed TDTpy import raises ImportError.""" try: - from tdt.util import connect_rpcox # noqa, analysis:ignore + from tdt.util import connect_rpcox # noqa: F401 except ImportError: pass else: - pytest.skip('Cannot test TDT import failure') - ac = dict(TYPE='tdt', TDT_MODEL='RP2') - with pytest.raises(ImportError, match='No module named'): + pytest.skip("Cannot test TDT import failure") + ac = dict(TYPE="tdt", TDT_MODEL="RP2") + with pytest.raises(ImportError, match="No module named"): ExperimentController( - *std_args, audio_controller=ac, response_device='keyboard', - trigger_controller='tdt', stim_fs=100., - suppress_resamp=True, **std_kwargs) + *std_args, + audio_controller=ac, + response_device="keyboard", + trigger_controller="tdt", + stim_fs=100.0, + suppress_resamp=True, + **std_kwargs, + ) @pytest.mark.timeout(10) def test_button_presses_and_window_size(hide_window): """Test EC window_size=None and button press capture.""" - with ExperimentController(*std_args, audio_controller='sound_card', - response_device='keyboard', window_size=None, - output_dir=None, full_screen=False, session='01', - participant='foo', trigger_controller='dummy', - force_quit='escape', version='dev') as ec: + with ExperimentController( + *std_args, + audio_controller="sound_card", + response_device="keyboard", + window_size=None, + output_dir=None, + full_screen=False, + session="01", + participant="foo", + trigger_controller="dummy", + force_quit="escape", + version="dev", + ) as ec: ec.listen_presses() ec.get_presses() assert_equal(ec.get_presses(), []) - fake_button_press(ec, '1', 0.5) - assert_equal(ec.screen_prompt('press 1', live_keys=['1'], - max_wait=1.5), '1') + fake_button_press(ec, "1", 0.5) + assert_equal(ec.screen_prompt("press 1", live_keys=["1"], max_wait=1.5), "1") ec.listen_presses() assert_equal(ec.get_presses(), []) - fake_button_press(ec, '1') - assert_equal(ec.get_presses(timestamp=False), [('1',)]) + fake_button_press(ec, "1") + assert_equal(ec.get_presses(timestamp=False), [("1",)]) ec.listen_presses() - fake_button_press(ec, '1') + fake_button_press(ec, "1") presses = ec.get_presses(timestamp=True, relative_to=0.2) assert_equal(len(presses), 1) assert_equal(len(presses[0]), 2) - assert_equal(presses[0][0], '1') - assert (isinstance(presses[0][1], float)) + assert_equal(presses[0][0], "1") + assert isinstance(presses[0][1], float) ec.listen_presses() - fake_button_press(ec, '1') - presses = ec.get_presses(timestamp=True, relative_to=0.1, - return_kinds=True) + fake_button_press(ec, "1") + presses = ec.get_presses(timestamp=True, relative_to=0.1, return_kinds=True) assert_equal(len(presses), 1) assert_equal(len(presses[0]), 3) - assert_equal(presses[0][::2], ('1', 'press')) - assert (isinstance(presses[0][1], float)) + assert_equal(presses[0][::2], ("1", "press")) + assert isinstance(presses[0][1], float) ec.listen_presses() - fake_button_press(ec, '1') + fake_button_press(ec, "1") presses = ec.get_presses(timestamp=False, return_kinds=True) - assert_equal(presses, [('1', 'press')]) + assert_equal(presses, [("1", "press")]) ec.listen_presses() - ec.screen_text('press 1 again') + ec.screen_text("press 1 again") ec.flip() - fake_button_press(ec, '1', 0.3) - assert_equal(ec.wait_one_press(1.5, live_keys=[1])[0], '1') - ec.screen_text('press 1 one last time') + fake_button_press(ec, "1", 0.3) + assert_equal(ec.wait_one_press(1.5, live_keys=[1])[0], "1") + ec.screen_text("press 1 one last time") ec.flip() - fake_button_press(ec, '1', 0.3) - out = ec.wait_for_presses(1.5, live_keys=['1'], timestamp=False) - assert_equal(out[0], '1') - fake_button_press(ec, 'a', 0.3) - fake_button_press(ec, 'return', 0.5) - assert ec.text_input() == 'A' - fake_button_press(ec, 'a', 0.3) - fake_button_press(ec, 'space', 0.35) - fake_button_press(ec, 'backspace', 0.4) - fake_button_press(ec, 'comma', 0.45) - fake_button_press(ec, 'return', 0.5) + fake_button_press(ec, "1", 0.3) + out = ec.wait_for_presses(1.5, live_keys=["1"], timestamp=False) + assert_equal(out[0], "1") + fake_button_press(ec, "a", 0.3) + fake_button_press(ec, "return", 0.5) + assert ec.text_input() == "A" + fake_button_press(ec, "a", 0.3) + fake_button_press(ec, "space", 0.35) + fake_button_press(ec, "backspace", 0.4) + fake_button_press(ec, "comma", 0.45) + fake_button_press(ec, "return", 0.5) # XXX this fails on OSX travis for some reason new_pyglet = _new_pyglet() - bad = sys.platform == 'darwin' - bad |= sys.platform == 'win32' and new_pyglet + bad = sys.platform == "darwin" + bad |= sys.platform == "win32" and new_pyglet if not bad: - assert ec.text_input(all_caps=False).strip() == 'a' + assert ec.text_input(all_caps=False).strip() == "a" @pytest.mark.timeout(10) @requires_opengl21 def test_mouse_clicks(hide_window): """Test EC mouse click support.""" - with ExperimentController(*std_args, participant='foo', session='01', - output_dir=None, version='dev') as ec: + with ExperimentController( + *std_args, participant="foo", session="01", output_dir=None, version="dev" + ) as ec: rect = visual.Rectangle(ec, [0, 0, 2, 2]) fake_mouse_click(ec, [1, 2], delay=0.3) - assert_equal(ec.wait_for_click_on(rect, 1.5, timestamp=False)[0], - ('left', 1, 2)) + assert_equal( + ec.wait_for_click_on(rect, 1.5, timestamp=False)[0], ("left", 1, 2) + ) pytest.raises(TypeError, ec.wait_for_click_on, (rect, rect), 1.5) - fake_mouse_click(ec, [2, 1], 'middle', delay=0.3) - out = ec.wait_one_click(1.5, 0., ['middle'], timestamp=True) - assert (out[3] < 1.5) - assert_equal(out[:3], ('middle', 2, 1)) - fake_mouse_click(ec, [3, 2], 'left', delay=0.3) - fake_mouse_click(ec, [4, 5], 'right', delay=0.3) + fake_mouse_click(ec, [2, 1], "middle", delay=0.3) + out = ec.wait_one_click(1.5, 0.0, ["middle"], timestamp=True) + assert out[3] < 1.5 + assert_equal(out[:3], ("middle", 2, 1)) + fake_mouse_click(ec, [3, 2], "left", delay=0.3) + fake_mouse_click(ec, [4, 5], "right", delay=0.3) out = ec.wait_for_clicks(1.5, timestamp=False) assert_equal(len(out), 2) - assert (any(o == ('left', 3, 2) for o in out)) - assert (any(o == ('right', 4, 5) for o in out)) + assert any(o == ("left", 3, 2) for o in out) + assert any(o == ("right", 4, 5) for o in out) out = ec.wait_for_clicks(0.1) assert_equal(len(out), 0) @@ -556,136 +649,133 @@ def test_mouse_clicks(hide_window): @pytest.mark.timeout(30) def test_background_color(hide_window): """Test setting background color""" - with ExperimentController(*std_args, participant='foo', session='01', - output_dir=None, version='dev') as ec: + with ExperimentController( + *std_args, participant="foo", session="01", output_dir=None, version="dev" + ) as ec: print((ec.window.width, ec.window.height)) - ec.set_background_color('red') + ec.set_background_color("red") ss = ec.screenshot()[:, :, :3] red_mask = (ss == [255, 0, 0]).all(axis=-1) - assert (red_mask.all()) - ec.set_background_color('white') + assert red_mask.all() + ec.set_background_color("white") ss = ec.screenshot()[:, :, :3] white_mask = (ss == [255] * 3).all(axis=-1) - assert (white_mask.all()) + assert white_mask.all() ec.flip() - ec.set_background_color('0.5') - visual.Rectangle(ec, [0, 0, 1, 1], fill_color='black').draw() + ec.set_background_color("0.5") + visual.Rectangle(ec, [0, 0, 1, 1], fill_color="black").draw() ss = ec.screenshot()[:, :, :3] - gray_mask = ((ss == [127] * 3).all(axis=-1) | - (ss == [128] * 3).all(axis=-1)) - assert (gray_mask.any()) + gray_mask = (ss == [127] * 3).all(axis=-1) | (ss == [128] * 3).all(axis=-1) + assert gray_mask.any() black_mask = (ss == [0] * 3).all(axis=-1) - assert (black_mask.any()) - assert (np.logical_or(gray_mask, black_mask).all()) + assert black_mask.any() + assert np.logical_or(gray_mask, black_mask).all() def test_tdt_delay(hide_window): """Test the tdt_delay parameter.""" - with ExperimentController(*std_args, - audio_controller=dict(TYPE='tdt', TDT_DELAY=0), - **std_kwargs) as ec: - assert_equal(ec._ac._used_params['TDT_DELAY'], 0) - with ExperimentController(*std_args, - audio_controller=dict(TYPE='tdt', TDT_DELAY=1), - **std_kwargs) as ec: - assert_equal(ec._ac._used_params['TDT_DELAY'], 1) - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller=dict(TYPE='tdt', TDT_DELAY='foo'), - **std_kwargs) - pytest.raises(OverflowError, ExperimentController, *std_args, - audio_controller=dict(TYPE='tdt', TDT_DELAY=np.inf), - **std_kwargs) - pytest.raises(TypeError, ExperimentController, *std_args, - audio_controller=dict(TYPE='tdt', TDT_DELAY=np.ones(2)), - **std_kwargs) - pytest.raises(ValueError, ExperimentController, *std_args, - audio_controller=dict(TYPE='tdt', TDT_DELAY=-1), - **std_kwargs) + with ExperimentController( + *std_args, audio_controller=dict(TYPE="tdt", TDT_DELAY=0), **std_kwargs + ) as ec: + assert_equal(ec._ac._used_params["TDT_DELAY"], 0) + with ExperimentController( + *std_args, audio_controller=dict(TYPE="tdt", TDT_DELAY=1), **std_kwargs + ) as ec: + assert_equal(ec._ac._used_params["TDT_DELAY"], 1) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller=dict(TYPE="tdt", TDT_DELAY="foo"), + **std_kwargs, + ) + pytest.raises( + OverflowError, + ExperimentController, + *std_args, + audio_controller=dict(TYPE="tdt", TDT_DELAY=np.inf), + **std_kwargs, + ) + pytest.raises( + TypeError, + ExperimentController, + *std_args, + audio_controller=dict(TYPE="tdt", TDT_DELAY=np.ones(2)), + **std_kwargs, + ) + pytest.raises( + ValueError, + ExperimentController, + *std_args, + audio_controller=dict(TYPE="tdt", TDT_DELAY=-1), + **std_kwargs, + ) def test_sound_card_triggering(hide_window): """Test using the sound card as a trigger controller.""" - audio_controller = dict(TYPE='sound_card', SOUND_CARD_TRIGGER_CHANNELS='0') - with pytest.raises(ValueError, match='SOUND_CARD_TRIGGER_CHANNELS is zer'): - ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - suppress_resamp=True, - **std_kwargs) - audio_controller.update(SOUND_CARD_TRIGGER_CHANNELS='1') + audio_controller = dict(TYPE="sound_card", SOUND_CARD_TRIGGER_CHANNELS="0") + kwargs = std_kwargs.copy() + kwargs.update( + stim_fs=44100, + suppress_resamp=True, + audio_controller=audio_controller, + trigger_controller="sound_card", + ) + with pytest.raises(ValueError, match="SOUND_CARD_TRIGGER_CHANNELS is zer"): + ExperimentController(*std_args, **kwargs) + audio_controller.update(SOUND_CARD_TRIGGER_CHANNELS="1") # Use 1 trigger ch and 1 output ch because this should work on all systems - with ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - n_channels=1, - suppress_resamp=True, - **std_kwargs) as ec: - ec.identify_trial(ttl_id=[1, 0], ec_id='') + with ExperimentController(*std_args, n_channels=1, **kwargs) as ec: + ec.identify_trial(ttl_id=[1, 0], ec_id="") ec.load_buffer([1e-2]) ec.start_stimulus() ec.stop() # Test the drift triggers audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=0.001) - with ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - n_channels=1, - **std_kwargs) as ec: - ec.identify_trial(ttl_id=[1, 0], ec_id='') - with pytest.warns(UserWarning, match='Drift triggers overlap with ' - 'onset triggers.'): + with ExperimentController(*std_args, n_channels=1, **kwargs) as ec: + ec.identify_trial(ttl_id=[1, 0], ec_id="") + with pytest.warns( + UserWarning, match="Drift triggers overlap with " "onset triggers." + ): ec.load_buffer(np.zeros(ec.stim_fs)) ec.start_stimulus() ec.stop() - audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=[1.1, 0.3, -0.3, - 'end']) - with ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - n_channels=1, - **std_kwargs) as ec: - ec.identify_trial(ttl_id=[1, 0], ec_id='') - with pytest.warns(UserWarning, match='Drift trigger at 1.1 seconds ' - 'occurs outside stimulus window, not stamping ' - 'trigger.'): + audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=[1.1, 0.3, -0.3, "end"]) + with ExperimentController(*std_args, n_channels=1, **kwargs) as ec: + ec.identify_trial(ttl_id=[1, 0], ec_id="") + with pytest.warns( + UserWarning, + match="Drift trigger at 1.1 seconds " + "occurs outside stimulus window, not stamping " + "trigger.", + ): ec.load_buffer(np.zeros(ec.stim_fs)) ec.start_stimulus() ec.stop() - audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=[0.5, 0.501]) - with ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - n_channels=1, - **std_kwargs) as ec: - ec.identify_trial(ttl_id=[1, 0], ec_id='') - with pytest.warns(UserWarning, match='Some 2-triggers overlap.*'): + audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=[0.5, 0.505]) + with ExperimentController(*std_args, n_channels=1, **kwargs) as ec: + ec.identify_trial(ttl_id=[1, 0], ec_id="") + with pytest.warns(UserWarning, match="Some 2-triggers overlap.*"): ec.load_buffer(np.zeros(ec.stim_fs)) ec.start_stimulus() ec.stop() audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=[]) - with ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - n_channels=1, - **std_kwargs) as ec: - ec.identify_trial(ttl_id=[1, 0], ec_id='') + with ExperimentController(*std_args, n_channels=1, **kwargs) as ec: + ec.identify_trial(ttl_id=[1, 0], ec_id="") ec.load_buffer(np.zeros(ec.stim_fs)) ec.start_stimulus() ec.stop() audio_controller.update(SOUND_CARD_DRIFT_TRIGGER=[0.2, 0.5, -0.3]) - with ExperimentController(*std_args, - audio_controller=audio_controller, - trigger_controller='sound_card', - n_channels=1, - **std_kwargs) as ec: - ec.identify_trial(ttl_id=[1, 0], ec_id='') - ec.load_buffer(np.zeros(ec.stim_fs)) + with ExperimentController(*std_args, n_channels=1, **kwargs) as ec: + ec.identify_trial(ttl_id=[1, 0], ec_id="") + ec.load_buffer(np.zeros(ec.stim_fs * 2)) ec.start_stimulus() ec.stop() -class _FakeJoystick(object): - device = 'FakeJoystick' +class _FakeJoystick: + device = "FakeJoystick" on_joybutton_press = lambda self, joystick, button: None # noqa open = lambda self, window, exclusive: None # noqa x = 0.125 @@ -694,19 +784,20 @@ class _FakeJoystick(object): def test_joystick(hide_window, monkeypatch): """Test joystick support.""" import pyglet + fake = _FakeJoystick() - monkeypatch.setattr(pyglet.input, 'get_joysticks', lambda: [fake]) + monkeypatch.setattr(pyglet.input, "get_joysticks", lambda: [fake]) with ExperimentController(*std_args, joystick=True, **std_kwargs) as ec: ec.listen_joystick_button_presses() fake.on_joybutton_press(fake, 1) presses = ec.get_joystick_button_presses() assert len(presses) == 1 - assert presses[0][0] == '1' - assert ec.get_joystick_value('x') == 0.125 + assert presses[0][0] == "1" + assert ec.get_joystick_value("x") == 0.125 def test_sound_card_params(): """Test that sound card params are known keys.""" for key in _SOUND_CARD_KEYS: - if key != 'TYPE': + if key != "TYPE": assert key in known_config_types, key diff --git a/expyfun/tests/test_eyelink_controller.py b/expyfun/tests/test_eyelink_controller.py index ed333c9c..ae37a557 100644 --- a/expyfun/tests/test_eyelink_controller.py +++ b/expyfun/tests/test_eyelink_controller.py @@ -1,12 +1,19 @@ import pytest -from expyfun import EyelinkController, ExperimentController +from expyfun import ExperimentController, EyelinkController from expyfun._utils import _TempDir, requires_opengl21 -std_args = ['test'] +std_args = ["test"] temp_dir = _TempDir() -std_kwargs = dict(output_dir=temp_dir, full_screen=False, window_size=(1, 1), - participant='foo', session='01', noise_db=0, version='dev') +std_kwargs = dict( + output_dir=temp_dir, + full_screen=False, + window_size=(1, 1), + participant="foo", + session="01", + noise_db=0, + version="dev", +) @requires_opengl21 @@ -16,52 +23,58 @@ def test_eyelink_methods(hide_window): pytest.raises(ValueError, EyelinkController, ec, fs=999) el = EyelinkController(ec) pytest.raises(RuntimeError, EyelinkController, ec) # can't have 2 open - pytest.raises(ValueError, el.custom_calibration, ctype='hey') - el.custom_calibration('H3') - el.custom_calibration('HV9') - el.custom_calibration('HV13') - pytest.raises(ValueError, el.custom_calibration, ctype='custom', - coordinates='foo') - pytest.raises(ValueError, el.custom_calibration, ctype='custom', - coordinates=[[0, 1], 0]) - pytest.raises(ValueError, el.custom_calibration, ctype='custom', - coordinates=[[0, 1], [0]]) + pytest.raises(ValueError, el.custom_calibration, ctype="hey") + el.custom_calibration("H3") + el.custom_calibration("HV9") + el.custom_calibration("HV13") + pytest.raises( + ValueError, el.custom_calibration, ctype="custom", coordinates="foo" + ) + pytest.raises( + ValueError, el.custom_calibration, ctype="custom", coordinates=[[0, 1], 0] + ) + pytest.raises( + ValueError, el.custom_calibration, ctype="custom", coordinates=[[0, 1], [0]] + ) el._open_file() pytest.raises(RuntimeError, el._open_file) el._start_recording() el.get_eye_position() pytest.raises(ValueError, el.wait_for_fix, [1]) x = el.wait_for_fix([-10000, -10000], max_wait=0.1) - assert (x is False) + assert x is False assert el.eye_used print(el.file_list) - assert (len(el.file_list) > 0) + assert len(el.file_list) > 0 print(el.fs) x = el.maintain_fix([-10000, -10000], 0.1, period=0.01) - assert (x is False) + assert x is False # run much of the calibration code, but don't *actually* do it el._fake_calibration = True el.calibrate(beep=False, prompt=False) el._fake_calibration = False # missing el_id - pytest.raises(KeyError, ec.identify_trial, ec_id='foo', ttl_id=[0]) - ec.identify_trial(ec_id='foo', ttl_id=[0], el_id=[1]) + pytest.raises(KeyError, ec.identify_trial, ec_id="foo", ttl_id=[0]) + ec.identify_trial(ec_id="foo", ttl_id=[0], el_id=[1]) ec.start_stimulus() ec.stop() ec.trial_ok() - ec.identify_trial(ec_id='foo', ttl_id=[0], el_id=[1, 1]) + ec.identify_trial(ec_id="foo", ttl_id=[0], el_id=[1, 1]) ec.start_stimulus() ec.stop() ec.trial_ok() - pytest.raises(ValueError, ec.identify_trial, ec_id='foo', ttl_id=[0], - el_id=[1, dict()]) - pytest.raises(ValueError, ec.identify_trial, ec_id='foo', ttl_id=[0], - el_id=[0] * 13) - pytest.raises(TypeError, ec.identify_trial, ec_id='foo', ttl_id=[0], - el_id=dict()) + pytest.raises( + ValueError, ec.identify_trial, ec_id="foo", ttl_id=[0], el_id=[1, dict()] + ) + pytest.raises( + ValueError, ec.identify_trial, ec_id="foo", ttl_id=[0], el_id=[0] * 13 + ) + pytest.raises( + TypeError, ec.identify_trial, ec_id="foo", ttl_id=[0], el_id=dict() + ) pytest.raises(TypeError, el._message, 1) el.stop() el.transfer_remote_file(el.file_list[0]) - assert (not el._closed) + assert not el._closed # ec.close() auto-calls el.close() - assert (el._closed) + assert el._closed diff --git a/expyfun/tests/test_logging.py b/expyfun/tests/test_logging.py index 2fb4a973..9a698fa7 100644 --- a/expyfun/tests/test_logging.py +++ b/expyfun/tests/test_logging.py @@ -2,45 +2,59 @@ import os import pytest + from expyfun import ExperimentController from expyfun._utils import _check_skip_backend, requires_lib -std_args = ['test'] -std_kwargs = dict(participant='foo', session='01', full_screen=False, - window_size=(1, 1), verbose=True, noise_db=0, version='dev') +std_args = ["test"] +std_kwargs = dict( + participant="foo", + session="01", + full_screen=False, + window_size=(1, 1), + verbose=True, + noise_db=0, + version="dev", +) -@requires_lib('mne') +@requires_lib("mne") def test_logging(ac, tmpdir, hide_window): """Test logging to file (Pyglet).""" - if ac != 'tdt': + if ac != "tdt": _check_skip_backend(ac) orig_dir = os.getcwd() os.chdir(str(tmpdir)) try: - with ExperimentController(*std_args, audio_controller=ac, - response_device='keyboard', - trigger_controller='dummy', - **std_kwargs) as ec: + with ExperimentController( + *std_args, + audio_controller=ac, + response_device="keyboard", + trigger_controller="dummy", + **std_kwargs, + ) as ec: test_name = ec._log_file stamp = ec.current_time ec.wait_until(stamp) # wait_until called w/already passed timest. - with pytest.warns(UserWarning, match='RMS'): - ec.load_buffer([1., -1., 1., -1., 1., -1.]) # RMS warning + with pytest.warns(UserWarning, match="RMS"): + ec.load_buffer([1.0, -1.0, 1.0, -1.0, 1.0, -1.0]) # RMS warning with open(test_name) as fid: - data = '\n'.join(fid.readlines()) + data = "\n".join(fid.readlines()) # check for various expected log messages (TODO: add more) - should_have = ['Participant: foo', 'Session: 01', - 'wait_until was called', - 'Stimulus max RMS ('] - if ac == 'tdt': - should_have.append('TDT') + should_have = [ + "Participant: foo", + "Session: 01", + "wait_until was called", + "Stimulus max RMS (", + ] + if ac == "tdt": + should_have.append("TDT") else: - should_have.append('sound card') - if ac != 'auto' and ac['SOUND_CARD_BACKEND'] != 'auto': - should_have.append(ac['SOUND_CARD_BACKEND']) + should_have.append("sound card") + if ac != "auto" and ac["SOUND_CARD_BACKEND"] != "auto": + should_have.append(ac["SOUND_CARD_BACKEND"]) assert_have_all(data, should_have) finally: os.chdir(orig_dir) @@ -48,8 +62,7 @@ def test_logging(ac, tmpdir, hide_window): def assert_have_all(data, should_have): """Assert all substrings are in the logging output.""" - __tracebackhide__ = operator.methodcaller('errisinstance', AssertionError) + __tracebackhide__ = operator.methodcaller("errisinstance", AssertionError) for s in should_have: if s not in data: - raise AssertionError('Missing data: "{0}" in:\n{1}' - ''.format(s, data)) + raise AssertionError(f'Missing data: "{s}" in:\n{data}' "") diff --git a/expyfun/tests/test_parallel.py b/expyfun/tests/test_parallel.py index 05c041fa..9f4d6c6b 100644 --- a/expyfun/tests/test_parallel.py +++ b/expyfun/tests/test_parallel.py @@ -1,10 +1,8 @@ -# -*- coding: utf-8 -*- - import numpy as np import pytest from numpy.testing import assert_array_equal -from expyfun._parallel import parallel_func, _check_n_jobs +from expyfun._parallel import _check_n_jobs, parallel_func from expyfun._utils import requires_lib @@ -12,10 +10,11 @@ def _identity(x): return x -@requires_lib('joblib') +@pytest.mark.timeout(15) +@requires_lib("joblib") def test_parallel(): """Test parallel support.""" - pytest.raises(TypeError, _check_n_jobs, 'foo') + pytest.raises(TypeError, _check_n_jobs, "foo") parallel, p_fun, _ = parallel_func(_identity, 1) a = np.array(parallel(p_fun(x) for x in range(10))) parallel, p_fun, _ = parallel_func(_identity, 2) diff --git a/expyfun/tests/test_trigger_conversion.py b/expyfun/tests/test_trigger_conversion.py index 99de3e1a..c4df2111 100644 --- a/expyfun/tests/test_trigger_conversion.py +++ b/expyfun/tests/test_trigger_conversion.py @@ -1,12 +1,11 @@ -from numpy.testing import assert_array_equal import pytest +from numpy.testing import assert_array_equal -from expyfun import decimals_to_binary, binary_to_decimals +from expyfun import binary_to_decimals, decimals_to_binary def test_conversion(): - """Test decimal<->binary conversion - """ + """Test decimal<->binary conversion""" pytest.raises(ValueError, decimals_to_binary, [1], [0]) pytest.raises(ValueError, decimals_to_binary, [-1], [1]) pytest.raises(ValueError, decimals_to_binary, [1, 1], [1]) @@ -18,21 +17,24 @@ def test_conversion(): pytest.raises(ValueError, binary_to_decimals, [1], [-1]) pytest.raises(ValueError, binary_to_decimals, [1], [2]) # test cases - decs = [[1], - [1, 0, 1, 4, 5], - [0, 3], - [3, 0], - ] - bits = [[1], - [1, 1, 2, 4, 4], - [2, 2], - [2, 2], - ] - bins = [[1], - [1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], - [0, 0, 1, 1], - [1, 1, 0, 0], - ] + decs = [ + [1], + [1, 0, 1, 4, 5], + [0, 3], + [3, 0], + ] + bits = [ + [1], + [1, 1, 2, 4, 4], + [2, 2], + [2, 2], + ] + bins = [ + [1], + [1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1], + [0, 0, 1, 1], + [1, 1, 0, 0], + ] for d, n, b in zip(decs, bits, bins): assert_array_equal(decimals_to_binary(d, n), b) assert_array_equal(binary_to_decimals(b, n), d) diff --git a/expyfun/tests/test_utils.py b/expyfun/tests/test_utils.py index 20abf148..fc18e96f 100644 --- a/expyfun/tests/test_utils.py +++ b/expyfun/tests/test_utils.py @@ -1,28 +1,28 @@ -import pytest import os import warnings import numpy as np +import pytest -from expyfun._utils import get_config, set_config, deprecated, _fix_audio_dims +from expyfun._utils import _fix_audio_dims, deprecated, get_config, set_config -warnings.simplefilter('always') +warnings.simplefilter("always") def test_config(): """Test expyfun config file support.""" - key = '_EXPYFUN_CONFIG_TESTING' - value = '123456' + key = "_EXPYFUN_CONFIG_TESTING" + value = "123456" old_val = os.getenv(key, None) os.environ[key] = value - assert (get_config(key) == value) + assert get_config(key) == value del os.environ[key] # catch the warning about it being a non-standard config key with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") # warnings raised only when setting key set_config(key, None) - assert (get_config(key) is None) + assert get_config(key) is None pytest.raises(KeyError, get_config, key, raise_error=True) set_config(key, value) assert get_config(key) == value @@ -32,17 +32,17 @@ def test_config(): os.environ[key] = old_val pytest.raises(ValueError, get_config, 1) get_config(None) - set_config(None, '0') + set_config(None, "0") -@deprecated('message') +@deprecated("message") def deprecated_func(): """Deprecated function.""" pass -@deprecated('message') -class deprecated_class(object): +@deprecated("message") +class deprecated_class: """Deprecated class.""" def __init__(self): @@ -52,13 +52,13 @@ def __init__(self): def test_deprecated(): """Test deprecated function.""" with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") deprecated_func() - assert (len(w) == 1) + assert len(w) == 1 with warnings.catch_warnings(record=True) as w: - warnings.simplefilter('always') + warnings.simplefilter("always") deprecated_class() - assert (len(w) == 1) + assert len(w) == 1 def test_audio_dims(): @@ -73,13 +73,12 @@ def test_audio_dims(): y = _fix_audio_dims(y, 2) assert y.shape == (2, n_samples) # no tiling for >2 channel output - with pytest.raises(ValueError, match='channel count 1 did not .* 3'): + with pytest.raises(ValueError, match="channel count 1 did not .* 3"): _fix_audio_dims(x, 3) for dim in (1, 3): - want = ('signal channel count 2 did not match required channel ' - 'count %s' % dim) + want = "signal channel count 2 did not match required channel " "count %s" % dim with pytest.raises(ValueError, match=want): _fix_audio_dims(y, dim) for n_channels in (1, 2, 3): - with pytest.raises(ValueError, match='must have one or two dimension'): + with pytest.raises(ValueError, match="must have one or two dimension"): _fix_audio_dims(np.zeros((2, 2, 2)), n_channels) diff --git a/expyfun/tests/test_version.py b/expyfun/tests/test_version.py index 98346128..cf7d5af5 100644 --- a/expyfun/tests/test_version.py +++ b/expyfun/tests/test_version.py @@ -1,58 +1,56 @@ -# -*- coding: utf-8 -*- import os -from os import path as op import warnings +from os import path as op import pytest -from expyfun import (ExperimentController, assert_version, download_version, - __version__) -from expyfun._utils import _TempDir +from expyfun import ExperimentController, __version__, assert_version, download_version from expyfun._git import _has_git +from expyfun._utils import _TempDir +@pytest.mark.filterwarnings("ignore:Package 'expyfun.data' is absent.*") @pytest.mark.timeout(60) # can be slow to download -def test_version_assertions(): +# old, broken, new +@pytest.mark.parametrize("want_version", ["090948e", "cae6bc3", "b6e8a81"]) +def test_version_assertions(want_version): """Test version assertions.""" pytest.raises(TypeError, assert_version, 1) - pytest.raises(TypeError, assert_version, '1' * 8) - pytest.raises(AssertionError, assert_version, 'x' * 7) + pytest.raises(TypeError, assert_version, "1" * 8) + pytest.raises(AssertionError, assert_version, "x" * 7) assert_version(__version__[-7:]) - # old, broken, new - for wi, want_version in enumerate(('090948e', 'cae6bc3', 'b6e8a81')): - print('Running %s' % want_version) - tempdir = _TempDir() - if not _has_git: - pytest.raises(ImportError, download_version, want_version, tempdir) - else: - pytest.raises(IOError, download_version, want_version, - op.join(tempdir, 'foo')) - pytest.raises(RuntimeError, download_version, 'x' * 7, tempdir) - ex_dir = op.join(tempdir, 'expyfun') - assert not op.isdir(ex_dir) - with warnings.catch_warnings(record=True): # Sometimes warns - warnings.simplefilter('ignore') - download_version(want_version, tempdir) - assert op.isdir(ex_dir) - assert op.isfile(op.join(ex_dir, '__init__.py')) - got_fname = op.join(ex_dir, '_version.py') - with open(got_fname) as fid: - line1 = fid.readline().strip() - got_version = line1.split(' = ')[1][-8:-1] - ex = want_version - if want_version == 'cae6bc3': - ex = (ex, '.dev0+c') - assert got_version in ex, got_fname + print("Running %s" % want_version) + tempdir = _TempDir() + if not _has_git: + pytest.raises(ImportError, download_version, want_version, tempdir) + else: + pytest.raises(IOError, download_version, want_version, op.join(tempdir, "foo")) + pytest.raises(RuntimeError, download_version, "x" * 7, tempdir) + ex_dir = op.join(tempdir, "expyfun") + assert not op.isdir(ex_dir) + with warnings.catch_warnings(record=True): # Sometimes warns + warnings.simplefilter("ignore") + download_version(want_version, tempdir) + assert op.isdir(ex_dir) + assert op.isfile(op.join(ex_dir, "__init__.py")) + got_fname = op.join(ex_dir, "_version.py") + with open(got_fname) as fid: + line1 = fid.readline().strip() + got_version = line1.split(" = ")[1][-8:-1] + ex = want_version + if want_version == "cae6bc3": + ex = (ex, ".dev0+c") + assert got_version in ex, got_fname - # auto dir determination - orig_dir = os.getcwd() - os.chdir(tempdir) - try: - assert op.isdir('expyfun') - pytest.raises(IOError, download_version, want_version) - finally: - os.chdir(orig_dir) + # auto dir determination + orig_dir = os.getcwd() + os.chdir(tempdir) + try: + assert op.isdir("expyfun") + pytest.raises(IOError, download_version, want_version) + finally: + os.chdir(orig_dir) # make sure we can get latest version tempdir_2 = _TempDir() if _has_git: @@ -63,11 +61,18 @@ def test_version_assertions(): def test_integrated_version_checking(): """Test EC version checking during init.""" tempdir = _TempDir() - args = ['test'] # experiment name - kwargs = dict(output_dir=tempdir, full_screen=False, window_size=(1, 1), - participant='foo', session='01', stim_db=0.0, noise_db=0.0, - verbose=True) - pytest.raises(RuntimeError, ExperimentController, *args, version=None, - **kwargs) - pytest.raises(AssertionError, ExperimentController, *args, - version='59f3f5b', **kwargs) # the very first commit + args = ["test"] # experiment name + kwargs = dict( + output_dir=tempdir, + full_screen=False, + window_size=(1, 1), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + verbose=True, + ) + pytest.raises(RuntimeError, ExperimentController, *args, version=None, **kwargs) + pytest.raises( + AssertionError, ExperimentController, *args, version="59f3f5b", **kwargs + ) # the very first commit diff --git a/expyfun/visual/__init__.py b/expyfun/visual/__init__.py index 9d98389e..9f784a01 100644 --- a/expyfun/visual/__init__.py +++ b/expyfun/visual/__init__.py @@ -1,3 +1,15 @@ -from ._visual import (Text, Line, Triangle, Rectangle, Circle, RawImage, - Diamond, ConcentricCircles, FixationDot, ProgressBar, - _convert_color, _Triangular, Video) +from ._visual import ( + Text, + Line, + Triangle, + Rectangle, + Circle, + RawImage, + Diamond, + ConcentricCircles, + FixationDot, + ProgressBar, + _convert_color, + _Triangular, + Video, +) diff --git a/expyfun/visual/_visual.py b/expyfun/visual/_visual.py index 0973afa6..669decb3 100644 --- a/expyfun/visual/_visual.py +++ b/expyfun/visual/_visual.py @@ -1,6 +1,4 @@ -""" -Visual stimulus design -====================== +"""Visual stimulus design. Tools for drawing shapes and text on the screen. """ @@ -11,41 +9,45 @@ # # License: BSD (3-clause) -from ctypes import (cast, pointer, POINTER, create_string_buffer, c_char, - c_int, c_float) -from functools import partial import re - import warnings +from ctypes import POINTER, c_char, c_float, c_int, cast, create_string_buffer, pointer +from functools import partial + import numpy as np + try: from PyOpenGL import gl except ImportError: from pyglet import gl -from .._utils import check_units, string_types, logger, _new_pyglet +from .._utils import _new_pyglet, check_units, logger def _convert_color(color, byte=True): - """Convert 3- or 4-element color into OpenGL usable color""" + """Convert 3- or 4-element color into OpenGL usable color.""" from matplotlib.colors import colorConverter - color = (0., 0., 0., 0.) if color is None else color + + color = (0.0, 0.0, 0.0, 0.0) if color is None else color color = 255 * np.array(colorConverter.to_rgba(color)) color = color.astype(np.uint8) if not byte: - color = (color / 255.).astype(np.float32) - return tuple(color) + color = tuple((color / 255.0).astype(np.float32)) + else: + color = tuple(int(c) for c in color) + return color def _replicate_color(color, pts): - """Convert single color to color array for OpenGL trianglulations""" + """Convert single color to color array for OpenGL triangulations.""" return np.tile(color, len(pts) // 2) ############################################################################## # Text -class Text(object): + +class Text: """A text object. Parameters @@ -93,31 +95,47 @@ class Text(object): The text object. """ - def __init__(self, ec, text, pos=(0, 0), color='white', - font_name='Arial', font_size=24, height=None, - width='auto', anchor_x='center', anchor_y='center', - units='norm', wrap=False, attr=True): + def __init__( + self, + ec, + text, + pos=(0, 0), + color="white", + font_name="Arial", + font_size=24, + height=None, + width="auto", + anchor_x="center", + anchor_y="center", + units="norm", + wrap=False, + attr=True, + ): import pyglet + pos = np.array(pos)[:, np.newaxis] - pos = ec._convert_units(pos, units, 'pix') - if width == 'auto': + pos = ec._convert_units(pos, units, "pix")[:, 0] + if width == "auto": width = float(ec.window_size_pix[0]) * 0.8 - elif isinstance(width, string_types): + elif isinstance(width, str): raise ValueError('"width", if str, must be "auto"') self._attr = attr if wrap: - text = text + '\n ' # weird Pyglet bug + text = text + "\n " # weird Pyglet bug if self._attr: - preamble = ('{{font_name \'{}\'}}{{font_size {}}}{{color {}}}' - '').format(font_name, font_size, _convert_color(color)) + preamble = ( + f"{{font_name '{font_name}'}}" + f"{{font_size {font_size}}}" + f"{{color {_convert_color(color)}}}" + ) doc = pyglet.text.decode_attributed(preamble + text) self._text = pyglet.text.layout.TextLayout( - doc, width=width, height=height, multiline=wrap, - dpi=int(ec.dpi)) + doc, width=width, height=height, multiline=wrap, dpi=int(ec.dpi) + ) else: self._text = pyglet.text.Label( - text, width=width, height=height, multiline=wrap, - dpi=int(ec.dpi)) + text, width=width, height=height, multiline=wrap, dpi=int(ec.dpi) + ) self._text.color = _convert_color(color) self._text.font_name = font_name self._text.font_size = font_size @@ -135,8 +153,9 @@ def set_color(self, color): The color. Use None for no color. """ if self._attr: - self._text.document.set_style(0, len(self._text.document.text), - {'color': _convert_color(color)}) + self._text.document.set_style( + 0, len(self._text.document.text), {"color": _convert_color(color)} + ) else: self._text.color = _convert_color(color) @@ -178,9 +197,11 @@ def _check_log(obj, func): func(obj, 4096, pointer(c_int()), ptr) message = log.value message = message.decode() - if message.startswith('No errors') or \ - re.match('.*shader was successfully compiled.*', message) or \ - message == 'Vertex shader(s) linked, fragment shader(s) linked.\n': + if ( + message.startswith("No errors") + or re.match(".*shader was successfully compiled.*", message) + or message == "Vertex shader(s) linked, fragment shader(s) linked.\n" + ): pass elif message: raise RuntimeError(message) @@ -190,14 +211,14 @@ def _create_program(ec, vert, frag): program = gl.glCreateProgram() vertex = gl.glCreateShader(gl.GL_VERTEX_SHADER) - buf = create_string_buffer(vert.encode('ASCII')) + buf = create_string_buffer(vert.encode("ASCII")) ptr = cast(pointer(pointer(buf)), POINTER(POINTER(c_char))) gl.glShaderSource(vertex, 1, ptr, None) gl.glCompileShader(vertex) _check_log(vertex, gl.glGetShaderInfoLog) fragment = gl.glCreateShader(gl.GL_FRAGMENT_SHADER) - buf = create_string_buffer(frag.encode('ASCII')) + buf = create_string_buffer(frag.encode("ASCII")) ptr = cast(pointer(pointer(buf)), POINTER(POINTER(c_char))) gl.glShaderSource(fragment, 1, ptr, None) gl.glCompileShader(fragment) @@ -213,9 +234,9 @@ def _create_program(ec, vert, frag): # Set the view matrix gl.glUseProgram(program) - loc = gl.glGetUniformLocation(program, b'u_view') + loc = gl.glGetUniformLocation(program, b"u_view") view = ec.window_size_pix - view = np.diag([2. / view[0], 2. / view[1], 1., 1.]) + view = np.diag([2.0 / view[0], 2.0 / view[1], 1.0, 1.0]) view[-1, :2] = -1 view = view.astype(np.float32).ravel() gl.glUniformMatrix4fv(loc, 1, False, (c_float * 16)(*view)) @@ -223,7 +244,7 @@ def _create_program(ec, vert, frag): return program -class _Triangular(object): +class _Triangular: """Super class for objects that use triangulations and/or lines""" def __init__(self, ec, fill_color, line_color, line_width, line_loop): @@ -240,13 +261,13 @@ def __init__(self, ec, fill_color, line_color, line_width, line_loop): self._buffers = dict() self._points = dict() self._tris = dict() - for kind in ('line', 'fill'): + for kind in ("line", "fill"): self._counts[kind] = 0 - self._colors[kind] = (0., 0., 0., 0.) + self._colors[kind] = (0.0, 0.0, 0.0, 0.0) self._buffers[kind] = dict(array=gl.GLuint()) - gl.glGenBuffers(1, pointer(self._buffers[kind]['array'])) - self._buffers['fill']['index'] = gl.GLuint() - gl.glGenBuffers(1, pointer(self._buffers['fill']['index'])) + gl.glGenBuffers(1, pointer(self._buffers[kind]["array"])) + self._buffers["fill"]["index"] = gl.GLuint() + gl.glGenBuffers(1, pointer(self._buffers["fill"]["index"])) gl.glUseProgram(0) self.set_fill_color(fill_color) @@ -256,12 +277,12 @@ def _set_points(self, points, kind, tris): """Set fill and line points.""" if points is None: self._counts[kind] = 0 - points = np.asarray(points, dtype=np.float32, order='C') + points = np.asarray(points, dtype=np.float32, order="C") assert points.ndim == 2 and points.shape[1] == 2 - array_count = points.size // 2 if kind == 'line' else points.size - if kind == 'fill': + array_count = points.size // 2 if kind == "line" else points.size + if kind == "fill": assert tris is not None - tris = np.asarray(tris, dtype=np.uint32, order='C') + tris = np.asarray(tris, dtype=np.uint32, order="C") assert tris.ndim == 1 and tris.size % 3 == 0 tris.shape = (-1, 3) assert (tris < len(points)).all() @@ -271,29 +292,33 @@ def _set_points(self, points, kind, tris): del points gl.glUseProgram(self._program) - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers[kind]['array']) - gl.glBufferData(gl.GL_ARRAY_BUFFER, self._points[kind].size * 4, - self._points[kind].tobytes(), - gl.GL_STATIC_DRAW) - if kind == 'line': + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers[kind]["array"]) + gl.glBufferData( + gl.GL_ARRAY_BUFFER, + self._points[kind].size * 4, + self._points[kind].tobytes(), + gl.GL_STATIC_DRAW, + ) + if kind == "line": self._counts[kind] = array_count - if kind == 'fill': + if kind == "fill": self._counts[kind] = self._tris[kind].size - gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, - self._buffers[kind]['index']) - gl.glBufferData(gl.GL_ELEMENT_ARRAY_BUFFER, - self._tris[kind].size * 4, - self._tris[kind].tobytes(), - gl.GL_STATIC_DRAW) + gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, self._buffers[kind]["index"]) + gl.glBufferData( + gl.GL_ELEMENT_ARRAY_BUFFER, + self._tris[kind].size * 4, + self._tris[kind].tobytes(), + gl.GL_STATIC_DRAW, + ) gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, 0) gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0) gl.glUseProgram(0) def _set_fill_points(self, points, tris): - self._set_points(points, 'fill', tris) + self._set_points(points, "fill", tris) def _set_line_points(self, points): - self._set_points(points, 'line', None) + self._set_points(points, "line", None) def set_fill_color(self, fill_color): """Set the object color @@ -303,7 +328,7 @@ def set_fill_color(self, fill_color): fill_color : matplotlib Color | None The fill color. Use None for no fill. """ - self._colors['fill'] = _convert_color(fill_color, byte=False) + self._colors["fill"] = _convert_color(fill_color, byte=False) def set_line_color(self, line_color): """Set the object color @@ -313,7 +338,7 @@ def set_line_color(self, line_color): line_color : matplotlib Color | None The fill color. Use None for no fill. """ - self._colors['line'] = _convert_color(line_color, byte=False) + self._colors["line"] = _convert_color(line_color, byte=False) def set_line_width(self, line_width): """Set the line width in pixels @@ -326,15 +351,15 @@ def set_line_width(self, line_width): """ line_width = float(line_width) if not (0.0 <= line_width <= 10.0): - raise ValueError('line_width must be between 0 and 10') + raise ValueError("line_width must be between 0 and 10") self._line_width = line_width def draw(self): """Draw the object to the display buffer.""" gl.glUseProgram(self._program) - for kind in ('fill', 'line'): + for kind in ("fill", "line"): if self._counts[kind] > 0: - if kind == 'line': + if kind == "line": if self._line_width <= 0.0: continue gl.glLineWidth(self._line_width) @@ -344,22 +369,26 @@ def draw(self): mode = gl.GL_LINE_STRIP cmd = partial(gl.glDrawArrays, mode, 0, self._counts[kind]) else: - gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, - self._buffers[kind]['index']) - cmd = partial(gl.glDrawElements, gl.GL_TRIANGLES, - self._counts[kind], gl.GL_UNSIGNED_INT, 0) - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, - self._buffers[kind]['array']) - loc_pos = gl.glGetAttribLocation(self._program, b'a_position') + gl.glBindBuffer( + gl.GL_ELEMENT_ARRAY_BUFFER, self._buffers[kind]["index"] + ) + cmd = partial( + gl.glDrawElements, + gl.GL_TRIANGLES, + self._counts[kind], + gl.GL_UNSIGNED_INT, + 0, + ) + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers[kind]["array"]) + loc_pos = gl.glGetAttribLocation(self._program, b"a_position") gl.glEnableVertexAttribArray(loc_pos) - gl.glVertexAttribPointer(loc_pos, 2, gl.GL_FLOAT, gl.GL_FALSE, - 0, 0) - loc_col = gl.glGetUniformLocation(self._program, b'u_color') + gl.glVertexAttribPointer(loc_pos, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, 0) + loc_col = gl.glGetUniformLocation(self._program, b"u_color") gl.glUniform4f(loc_col, *self._colors[kind]) cmd() # cleanup gl.glDisableVertexAttribArray(loc_pos) - if kind != 'line': + if kind != "line": gl.glBindBuffer(gl.GL_ELEMENT_ARRAY_BUFFER, 0) gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0) gl.glUseProgram(0) @@ -390,14 +419,27 @@ class Line(_Triangular): The line object. """ - def __init__(self, ec, coords, units='norm', line_color='white', - line_width=1.0, line_loop=False): - _Triangular.__init__(self, ec, fill_color=None, line_color=line_color, - line_width=line_width, line_loop=line_loop) + def __init__( + self, + ec, + coords, + units="norm", + line_color="white", + line_width=1.0, + line_loop=False, + ): + _Triangular.__init__( + self, + ec, + fill_color=None, + line_color=line_color, + line_width=line_width, + line_loop=line_loop, + ) self.set_coords(coords, units) self.set_line_color(line_color) - def set_coords(self, coords, units='norm'): + def set_coords(self, coords, units="norm"): """Set line coordinates Parameters @@ -412,10 +454,12 @@ def set_coords(self, coords, units='norm'): if coords.ndim == 1: coords = coords[:, np.newaxis] if coords.ndim != 2 or coords.shape[0] != 2: - raise ValueError('coords must be a vector of length 2, or an ' - 'array with 2 dimensions (with first dimension ' - 'having length 2') - self._set_line_points(self._ec._convert_units(coords, units, 'pix').T) + raise ValueError( + "coords must be a vector of length 2, or an " + "array with 2 dimensions (with first dimension " + "having length 2" + ) + self._set_line_points(self._ec._convert_units(coords, units, "pix").T) class Triangle(_Triangular): @@ -443,15 +487,27 @@ class Triangle(_Triangular): The triangle object. """ - def __init__(self, ec, coords, units='norm', fill_color='white', - line_color=None, line_width=1.0): - _Triangular.__init__(self, ec, fill_color=fill_color, - line_color=line_color, line_width=line_width, - line_loop=True) + def __init__( + self, + ec, + coords, + units="norm", + fill_color="white", + line_color=None, + line_width=1.0, + ): + _Triangular.__init__( + self, + ec, + fill_color=fill_color, + line_color=line_color, + line_width=line_width, + line_loop=True, + ) self.set_coords(coords, units) self.set_fill_color(fill_color) - def set_coords(self, coords, units='norm'): + def set_coords(self, coords, units="norm"): """Set triangle coordinates Parameters @@ -464,9 +520,10 @@ def set_coords(self, coords, units='norm'): check_units(units) coords = np.array(coords, dtype=float) if coords.shape != (2, 3): - raise ValueError('coords must be an array of shape (2, 3), got %s' - % (coords.shape,)) - points = self._ec._convert_units(coords, units, 'pix') + raise ValueError( + "coords must be an array of shape (2, 3), got %s" % (coords.shape,) + ) + points = self._ec._convert_units(coords, units, "pix") points = points.T self._set_fill_points(points, [0, 1, 2]) self._set_line_points(points) @@ -498,14 +555,20 @@ class Rectangle(_Triangular): The rectangle object. """ - def __init__(self, ec, pos, units='norm', fill_color='white', - line_color=None, line_width=1.0): - _Triangular.__init__(self, ec, fill_color=fill_color, - line_color=line_color, line_width=line_width, - line_loop=True) + def __init__( + self, ec, pos, units="norm", fill_color="white", line_color=None, line_width=1.0 + ): + _Triangular.__init__( + self, + ec, + fill_color=fill_color, + line_color=line_color, + line_width=line_width, + line_loop=True, + ) self.set_pos(pos, units) - def set_pos(self, pos, units='norm'): + def set_pos(self, pos, units="norm"): """Set the position of the rectangle Parameters @@ -519,16 +582,20 @@ def set_pos(self, pos, units='norm'): # do this in normalized units, then convert pos = np.array(pos) if not (pos.ndim == 1 and pos.size == 4): - raise ValueError('pos must be a 4-element array-like vector') + raise ValueError("pos must be a 4-element array-like vector") self._pos = pos w = self._pos[2] h = self._pos[3] - points = np.array([[-w / 2., -h / 2.], - [-w / 2., h / 2.], - [w / 2., h / 2.], - [w / 2., -h / 2.]]).T + points = np.array( + [ + [-w / 2.0, -h / 2.0], + [-w / 2.0, h / 2.0], + [w / 2.0, h / 2.0], + [w / 2.0, -h / 2.0], + ] + ).T points += np.array(self._pos[:2])[:, np.newaxis] - points = self._ec._convert_units(points, units, 'pix') + points = self._ec._convert_units(points, units, "pix") points = points.T self._set_fill_points(points, [0, 1, 2, 0, 2, 3]) self._set_line_points(points) # all 4 points used for line drawing @@ -560,14 +627,20 @@ class Diamond(_Triangular): The rectangle object. """ - def __init__(self, ec, pos, units='norm', fill_color='white', - line_color=None, line_width=1.0): - _Triangular.__init__(self, ec, fill_color=fill_color, - line_color=line_color, line_width=line_width, - line_loop=True) + def __init__( + self, ec, pos, units="norm", fill_color="white", line_color=None, line_width=1.0 + ): + _Triangular.__init__( + self, + ec, + fill_color=fill_color, + line_color=line_color, + line_width=line_width, + line_loop=True, + ) self.set_pos(pos, units) - def set_pos(self, pos, units='norm'): + def set_pos(self, pos, units="norm"): """Set the position of the rectangle Parameters @@ -581,16 +654,15 @@ def set_pos(self, pos, units='norm'): # do this in normalized units, then convert pos = np.array(pos) if not (pos.ndim == 1 and pos.size == 4): - raise ValueError('pos must be a 4-element array-like vector') + raise ValueError("pos must be a 4-element array-like vector") self._pos = pos w = self._pos[2] h = self._pos[3] - points = np.array([[w / 2., 0.], - [0., h / 2.], - [-w / 2., 0.], - [0., -h / 2.]]).T + points = np.array( + [[w / 2.0, 0.0], [0.0, h / 2.0], [-w / 2.0, 0.0], [0.0, -h / 2.0]] + ).T points += np.array(self._pos[:2])[:, np.newaxis] - points = self._ec._convert_units(points, units, 'pix') + points = self._ec._convert_units(points, units, "pix") points = points.T self._set_fill_points(points, [0, 1, 2, 0, 2, 3]) self._set_line_points(points) @@ -626,16 +698,29 @@ class Circle(_Triangular): The circle object. """ - def __init__(self, ec, radius=1, pos=(0, 0), units='norm', - n_edges=200, fill_color='white', line_color=None, - line_width=1.0): - _Triangular.__init__(self, ec, fill_color=fill_color, - line_color=line_color, line_width=line_width, - line_loop=True) + def __init__( + self, + ec, + radius=1, + pos=(0, 0), + units="norm", + n_edges=200, + fill_color="white", + line_color=None, + line_width=1.0, + ): + _Triangular.__init__( + self, + ec, + fill_color=fill_color, + line_color=line_color, + line_width=line_width, + line_loop=True, + ) if not isinstance(n_edges, int): - raise TypeError('n_edges must be an int') + raise TypeError("n_edges must be an int") if n_edges < 4: - raise ValueError('n_edges must be >= 4 for a reasonable circle') + raise ValueError("n_edges must be >= 4 for a reasonable circle") self._n_edges = n_edges # construct triangulation (never changes so long as n_edges is fixed) @@ -645,11 +730,11 @@ def __init__(self, ec, radius=1, pos=(0, 0), units='norm', self._orig_tris = tris # need to set a dummy value here so recalculation doesn't fail - self._radius = np.array([1., 1.]) + self._radius = np.array([1.0, 1.0]) self.set_pos(pos, units) self.set_radius(radius, units) - def set_radius(self, radius, units='norm'): + def set_radius(self, radius, units="norm"): """Set the position and radius of the circle Parameters @@ -663,19 +748,19 @@ def set_radius(self, radius, units='norm'): check_units(units) radius = np.atleast_1d(radius).astype(float) if radius.ndim != 1 or radius.size > 2: - raise ValueError('radius must be a 1- or 2-element ' - 'array-like vector') + raise ValueError("radius must be a 1- or 2-element " "array-like vector") if radius.size == 1: radius = np.r_[radius, radius] # convert to pixel (OpenGL) units - self._radius = self._ec._convert_units(radius[:, np.newaxis], - units, 'pix')[:, 0] + self._radius = self._ec._convert_units(radius[:, np.newaxis], units, "pix")[ + :, 0 + ] # need to subtract center position - ctr = self._ec._convert_units(np.zeros((2, 1)), units, 'pix')[:, 0] + ctr = self._ec._convert_units(np.zeros((2, 1)), units, "pix")[:, 0] self._radius -= ctr self._recalculate() - def set_pos(self, pos, units='norm'): + def set_pos(self, pos, units="norm"): """Set the position and radius of the circle Parameters @@ -688,18 +773,18 @@ def set_pos(self, pos, units='norm'): check_units(units) pos = np.array(pos, dtype=float) if not (pos.ndim == 1 and pos.size == 2): - raise ValueError('pos must be a 2-element array-like vector') + raise ValueError("pos must be a 2-element array-like vector") # convert to pixel (OpenGL) units - self._pos = self._ec._convert_units(pos[:, np.newaxis], - units, 'pix')[:, 0] + self._pos = self._ec._convert_units(pos[:, np.newaxis], units, "pix")[:, 0] self._recalculate() def _recalculate(self): """Helper to recalculate point coordinates""" edges = self._n_edges arg = 2 * np.pi * (np.arange(edges) / float(edges)) - points = np.array([self._radius[0] * np.cos(arg), - self._radius[1] * np.sin(arg)]) + points = np.array( + [self._radius[0] * np.cos(arg), self._radius[1] * np.sin(arg)] + ) points = np.c_[np.zeros((2, 1)), points] # prepend the center points += np.array(self._pos[:2], dtype=float)[:, np.newaxis] points = points.T @@ -707,7 +792,7 @@ def _recalculate(self): self._set_line_points(points[1:]) # omit center point for lines -class ConcentricCircles(object): +class ConcentricCircles: """A set of filled concentric circles drawn without edges. Parameters @@ -732,23 +817,26 @@ class ConcentricCircles(object): The circle object. """ - def __init__(self, ec, radii=(0.2, 0.05), pos=(0, 0), units='norm', - colors=('w', 'k')): + def __init__( + self, ec, radii=(0.2, 0.05), pos=(0, 0), units="norm", colors=("w", "k") + ): radii = np.array(radii, float) if radii.ndim != 1: - raise ValueError('radii must be 1D') + raise ValueError("radii must be 1D") if not isinstance(colors, (tuple, list)): - raise TypeError('colors must be a tuple, list, or array') + raise TypeError("colors must be a tuple, list, or array") if len(colors) != len(radii): - raise ValueError('colors and radii must be the same length') + raise ValueError("colors and radii must be the same length") # need to set a dummy value here so recalculation doesn't fail - self._circles = [Circle(ec, r, pos, units, fill_color=c, line_width=0) - for r, c in zip(radii, colors)] + self._circles = [ + Circle(ec, r, pos, units, fill_color=c, line_width=0) + for r, c in zip(radii, colors) + ] def __len__(self): return len(self._circles) - def set_pos(self, pos, units='norm'): + def set_pos(self, pos, units="norm"): """Set the position of the circles Parameters @@ -761,7 +849,7 @@ def set_pos(self, pos, units='norm'): for circle in self._circles: circle.set_pos(pos, units) - def set_radius(self, radius, idx, units='norm'): + def set_radius(self, radius, idx, units="norm"): """Set the radius of one of the circles Parameters @@ -775,7 +863,7 @@ def set_radius(self, radius, idx, units='norm'): """ self._circles[idx].set_radius(radius, units) - def set_radii(self, radii, units='norm'): + def set_radii(self, radii, units="norm"): """Set the color of each circle Parameters @@ -788,8 +876,7 @@ def set_radii(self, radii, units='norm'): """ radii = np.array(radii, float) if radii.ndim != 1 or radii.size != len(self): - raise ValueError('radii must contain exactly {0} radii' - ''.format(len(self))) + raise ValueError(f"radii must contain exactly {len(self)} radii" "") for idx, radius in enumerate(radii): self.set_radius(radius, idx, units) @@ -815,8 +902,9 @@ def set_colors(self, colors): colors as the number of circles. """ if not isinstance(colors, (tuple, list)) or len(colors) != len(self): - raise ValueError('colors must be a list or tuple with {0} colors' - ''.format(len(self))) + raise ValueError( + f"colors must be a list or tuple with {len(self)} colors" "" + ) for idx, color in enumerate(colors): self.set_color(color, idx) @@ -846,16 +934,14 @@ class FixationDot(ConcentricCircles): The fixation dot. """ - def __init__(self, ec, colors=('w', 'k')): + def __init__(self, ec, colors=("w", "k")): if len(colors) != 2: - raise ValueError('colors must have length 2') - super(FixationDot, self).__init__(ec, radii=[0.2, 0.2], - pos=[0, 0], units='deg', - colors=colors) - self.set_radius(1, 1, units='pix') + raise ValueError("colors must have length 2") + super().__init__(ec, radii=[0.2, 0.2], pos=[0, 0], units="deg", colors=colors) + self.set_radius(1, 1, units="pix") -class ProgressBar(object): +class ProgressBar: """A progress bar that can be displayed between sections. This uses two rectangles, one outline, and one solid to show how much @@ -876,12 +962,12 @@ class ProgressBar(object): white. """ - def __init__(self, ec, pos, units='norm', colors=('g', 'w')): + def __init__(self, ec, pos, units="norm", colors=("g", "w")): self._ec = ec if len(colors) != 2: - raise ValueError('colors must have length 2') - if units not in ['norm', 'pix']: - raise ValueError('units must be either \'norm\' or \'pix\'') + raise ValueError("colors must have length 2") + if units not in ["norm", "pix"]: + raise ValueError("units must be either 'norm' or 'pix'") pos = np.array(pos, dtype=float) self._pos = pos @@ -894,9 +980,10 @@ def __init__(self, ec, pos, units='norm', colors=('g', 'w')): self._init_x = self._pos_bar[0] self._pos_bar[2] = 0 - self._rectangles = [Rectangle(ec, self._pos_bar, units, colors[0], - None), - Rectangle(ec, self._pos, units, None, colors[1])] + self._rectangles = [ + Rectangle(ec, self._pos_bar, units, colors[0], None), + Rectangle(ec, self._pos, units, None, colors[1]), + ] def update_bar(self, percent): """Update the progress of the bar. @@ -907,8 +994,8 @@ def update_bar(self, percent): The percentage of the bar to be filled. Must be between 0 and 1. """ if percent > 100 or percent < 0: - raise ValueError('percent must be a float between 0 and 100') - self._pos_bar[2] = percent * self._width / 100. + raise ValueError("percent must be a float between 0 and 100") + self._pos_bar[2] = percent * self._width / 100.0 self._pos_bar[0] = self._init_x + self._pos_bar[2] * 0.5 self._rectangles[0].set_pos(self._pos_bar, self._units) @@ -921,7 +1008,8 @@ def draw(self): ############################################################################## # Image display -class RawImage(object): + +class RawImage: """Create image from array for on-screen display. Parameters @@ -944,7 +1032,7 @@ class RawImage(object): The image object. """ - def __init__(self, ec, image_buffer, pos=(0, 0), scale=1., units='norm'): + def __init__(self, ec, image_buffer, pos=(0, 0), scale=1.0, units="norm"): self._ec = ec self._img = None self.set_image(image_buffer) @@ -962,26 +1050,28 @@ def set_image(self, image_buffer): ``np.uint8`` is slightly more efficient. """ from pyglet import image, sprite + image_buffer = np.ascontiguousarray(image_buffer) if image_buffer.dtype not in (np.float64, np.uint8): - raise TypeError('image_buffer must be np.float64 or np.uint8') + raise TypeError("image_buffer must be np.float64 or np.uint8") if image_buffer.dtype == np.float64: if image_buffer.max() > 1 or image_buffer.min() < 0: - raise ValueError('all float values must be between 0 and 1') - image_buffer = (image_buffer * 255).astype('uint8') + raise ValueError("all float values must be between 0 and 1") + image_buffer = (image_buffer * 255).astype("uint8") if image_buffer.ndim == 2: # grayscale image_buffer = np.tile(image_buffer[..., np.newaxis], (1, 1, 3)) if not image_buffer.ndim == 3 or image_buffer.shape[2] not in [3, 4]: - raise RuntimeError('image_buffer incorrect size: {}' - ''.format(image_buffer.shape)) + raise RuntimeError(f"image_buffer incorrect size: {image_buffer.shape}" "") # add alpha channel if necessary dims = image_buffer.shape - fmt = 'RGB' if dims[2] == 3 else 'RGBA' - self._sprite = sprite.Sprite(image.ImageData(dims[1], dims[0], fmt, - image_buffer.tobytes(), - -dims[1] * dims[2])) - - def set_pos(self, pos, units='norm'): + fmt = "RGB" if dims[2] == 3 else "RGBA" + self._sprite = sprite.Sprite( + image.ImageData( + dims[1], dims[0], fmt, image_buffer.tobytes(), -dims[1] * dims[2] + ) + ) + + def set_pos(self, pos, units="norm"): """Set image position. Parameters @@ -993,17 +1083,16 @@ def set_pos(self, pos, units='norm'): """ pos = np.array(pos, float) if pos.ndim != 1 or pos.size != 2: - raise ValueError('pos must be a 2-element array') + raise ValueError("pos must be a 2-element array") pos = np.reshape(pos, (2, 1)) - self._pos = self._ec._convert_units(pos, units, 'pix').ravel() + self._pos = self._ec._convert_units(pos, units, "pix").ravel() @property def bounds(self): """Left, Right, Bottom, Top (in pixels) of the image.""" pos = np.array(self._pos, float) - size = np.array([self._sprite.width, - self._sprite.height], float) - bounds = np.concatenate((pos - size / 2., pos + size / 2.)) + size = np.array([self._sprite.width, self._sprite.height], float) + bounds = np.concatenate((pos - size / 2.0, pos + size / 2.0)) return bounds[[0, 2, 1, 3]] @property @@ -1026,14 +1115,14 @@ def set_scale(self, scale): def draw(self): """Draw the image to the buffer""" self._sprite.scale = self._scale - pos = self._pos - [self._sprite.width / 2., self._sprite.height / 2.] + pos = self._pos - [self._sprite.width / 2.0, self._sprite.height / 2.0] try: self._sprite.position = (pos[0], pos[1]) except AttributeError: self._sprite.set_position(pos[0], pos[1]) self._sprite.draw() - def get_rect(self, units='norm'): + def get_rect(self, units="norm"): """X, Y center, Width, Height of image. Parameters @@ -1047,15 +1136,13 @@ def get_rect(self, units='norm'): The rect. """ # left,right,bottom,top - lrbt = self._ec._convert_units(self.bounds.reshape(2, -1), - fro='pix', to=units) - center = self._ec._convert_units(self._pos.reshape(2, -1), - fro='pix', to=units) + lrbt = self._ec._convert_units(self.bounds.reshape(2, -1), fro="pix", to=units) + center = self._ec._convert_units(self._pos.reshape(2, -1), fro="pix", to=units) width_height = np.diff(lrbt, axis=-1) return np.squeeze(np.concatenate([center, width_height])) -tex_vert = ''' +tex_vert = """ #version 120 attribute vec2 a_position; @@ -1068,9 +1155,9 @@ def get_rect(self, units='norm'): gl_Position = u_view * vec4(a_position, 0.0, 1.0); v_texcoord = a_texcoord; } -''' +""" -tex_frag = ''' +tex_frag = """ #version 120 #extension GL_ARB_texture_rectangle : enable @@ -1082,10 +1169,10 @@ def get_rect(self, units='norm'): gl_FragColor = texture2DRect(u_texture, v_texcoord); gl_FragColor.a = 1.0; } -''' +""" -class Video(object): +class Video: """Read video file and draw it to the screen. Parameters @@ -1123,20 +1210,31 @@ class Video(object): entertainment for the participant during a passive auditory task). """ - def __init__(self, ec, file_name, pos=(0, 0), units='norm', scale=1., - center=True, visible=True): - from pyglet.media import load, Player + def __init__( + self, + ec, + file_name, + pos=(0, 0), + units="norm", + scale=1.0, + center=True, + visible=True, + ): + from pyglet.media import Player, load + self._ec = ec # On Windows, the default is unaccelerated WMF, which is terribly slow. decoder = None if _new_pyglet(): try: from pyglet.media.codecs.ffmpeg import FFmpegDecoder + decoder = FFmpegDecoder() except Exception as exc: warnings.warn( - 'FFmpeg decoder could not be instantiated, decoding ' - f'performance could be compromised:\n{exc}') + "FFmpeg decoder could not be instantiated, decoding " + f"performance could be compromised:\n{exc}" + ) self._source = load(file_name, decoder=decoder) self._player = Player() with warnings.catch_warnings(record=True): # deprecated eos_action @@ -1144,9 +1242,9 @@ def __init__(self, ec, file_name, pos=(0, 0), units='norm', scale=1., self._player._audio_player = None frame_rate = self.frame_rate if frame_rate is None: - logger.warning('Frame rate could not be determined') - frame_rate = 60. - self._dt = 1. / frame_rate + logger.warning("Frame rate could not be determined") + frame_rate = 60.0 + self._dt = 1.0 / frame_rate self._playing = False self._finished = False self._pos = pos @@ -1159,14 +1257,15 @@ def __init__(self, ec, file_name, pos=(0, 0), units='norm', scale=1., self._program = _create_program(ec, tex_vert, tex_frag) gl.glUseProgram(self._program) self._buffers = dict() - for key in ('position', 'texcoord'): + for key in ("position", "texcoord"): self._buffers[key] = gl.GLuint(0) gl.glGenBuffers(1, pointer(self._buffers[key])) w, h = self.source_width, self.source_height tex = np.array([(0, h), (w, h), (w, 0), (0, 0)], np.float32) - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers['texcoord']) - gl.glBufferData(gl.GL_ARRAY_BUFFER, tex.nbytes, tex.tobytes(), - gl.GL_DYNAMIC_DRAW) + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers["texcoord"]) + gl.glBufferData( + gl.GL_ARRAY_BUFFER, tex.nbytes, tex.tobytes(), gl.GL_DYNAMIC_DRAW + ) gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0) gl.glUseProgram(0) @@ -1190,8 +1289,9 @@ def play(self, auto_draw=True): self._player.play() self._playing = True else: - warnings.warn('ExperimentController.video.play() called when ' - 'already playing.') + warnings.warn( + "ExperimentController.video.play() called when " "already playing." + ) return self._ec.get_time() def pause(self): @@ -1213,8 +1313,9 @@ def pause(self): self._player.pause() self._playing = False else: - warnings.warn('ExperimentController.video.pause() called when ' - 'already paused.') + warnings.warn( + "ExperimentController.video.pause() called when " "already paused." + ) return self._ec.get_time() def _delete(self): @@ -1223,7 +1324,7 @@ def _delete(self): self.pause() self._player.delete() - def set_scale(self, scale=1.): + def set_scale(self, scale=1.0): """Set video scale. Parameters @@ -1237,20 +1338,20 @@ def set_scale(self, scale=1.): while ensuring none of the video is offscreen, which may result in letterboxing). """ - if isinstance(scale, string_types): - _scale = self._ec.window_size_pix / np.array((self.source_width, - self.source_height), - dtype=float) - if scale == 'fit': + if isinstance(scale, str): + _scale = self._ec.window_size_pix / np.array( + (self.source_width, self.source_height), dtype=float + ) + if scale == "fit": scale = _scale.min() - elif scale == 'fill': + elif scale == "fill": scale = _scale.max() self._scale = float(scale) # allows [1, 1., '1']; others: ValueError if self._scale <= 0: - raise ValueError('Video scale factor must be strictly positive.') + raise ValueError("Video scale factor must be strictly positive.") self.set_pos(self._pos, self._units, self._center) - def set_pos(self, pos, units='norm', center=True): + def set_pos(self, pos, units="norm", center=True): """Set video position. Parameters @@ -1266,9 +1367,9 @@ def set_pos(self, pos, units='norm', center=True): """ pos = np.array(pos, float) if pos.size != 2: - raise ValueError('pos must be a 2-element array') + raise ValueError("pos must be a 2-element array") pos = np.reshape(pos, (2, 1)) - pix = self._ec._convert_units(pos, units, 'pix').ravel() + pix = self._ec._convert_units(pos, units, "pix").ravel() offset = np.array((self.width, self.height)) // 2 if center else 0 self._pos = pos self._actual_pos = pix - offset @@ -1284,16 +1385,16 @@ def _draw(self): x, y = self._actual_pos w = self.source_width * self._scale h = self.source_height * self._scale - pos = np.array( - [(x, y), (x + w, y), (x + w, y + h), (x, y + h)], np.float32) - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers['position']) - gl.glBufferData(gl.GL_ARRAY_BUFFER, pos.nbytes, pos.tobytes(), - gl.GL_DYNAMIC_DRAW) - loc_pos = gl.glGetAttribLocation(self._program, b'a_position') + pos = np.array([(x, y), (x + w, y), (x + w, y + h), (x, y + h)], np.float32) + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers["position"]) + gl.glBufferData( + gl.GL_ARRAY_BUFFER, pos.nbytes, pos.tobytes(), gl.GL_DYNAMIC_DRAW + ) + loc_pos = gl.glGetAttribLocation(self._program, b"a_position") gl.glEnableVertexAttribArray(loc_pos) gl.glVertexAttribPointer(loc_pos, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, 0) - gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers['texcoord']) - loc_tex = gl.glGetAttribLocation(self._program, b'a_texcoord') + gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self._buffers["texcoord"]) + loc_tex = gl.glGetAttribLocation(self._program, b"a_texcoord") gl.glEnableVertexAttribArray(loc_tex) gl.glVertexAttribPointer(loc_tex, 2, gl.GL_FLOAT, gl.GL_FALSE, 0, 0) gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0) @@ -1345,9 +1446,11 @@ def _eos(self): return self._eos_fun() def _eos_old(self): - return (self._player._last_video_timestamp is not None and - self._player._last_video_timestamp == - self._source.get_next_video_timestamp()) + return ( + self._player._last_video_timestamp is not None + and self._player._last_video_timestamp + == self._source.get_next_video_timestamp() + ) def _eos_new(self): done = self._player.source is None diff --git a/expyfun/visual/tests/test_visuals.py b/expyfun/visual/tests/test_visuals.py index fbae8ca2..1619f2ad 100644 --- a/expyfun/visual/tests/test_visuals.py +++ b/expyfun/visual/tests/test_visuals.py @@ -2,18 +2,26 @@ import pytest from numpy.testing import assert_equal -from expyfun import ExperimentController, visual, fetch_data_file +from expyfun import ExperimentController, fetch_data_file, visual from expyfun._utils import requires_opengl21, requires_video -std_kwargs = dict(output_dir=None, full_screen=False, window_size=(1, 1), - participant='foo', session='01', stim_db=0.0, noise_db=0.0, - verbose=True, version='dev') +std_kwargs = dict( + output_dir=None, + full_screen=False, + window_size=(1, 1), + participant="foo", + session="01", + stim_db=0.0, + noise_db=0.0, + verbose=True, + version="dev", +) @requires_opengl21 def test_visuals(hide_window): """Test EC visual methods.""" - with ExperimentController('test', **std_kwargs) as ec: + with ExperimentController("test", **std_kwargs) as ec: pytest.raises(TypeError, visual.Circle, ec, n_edges=3.5) pytest.raises(ValueError, visual.Circle, ec, n_edges=3) circ = visual.Circle(ec) @@ -21,29 +29,28 @@ def test_visuals(hide_window): pytest.raises(ValueError, circ.set_radius, [1, 2, 3]) pytest.raises(ValueError, circ.set_pos, [1]) pytest.raises(ValueError, visual.Triangle, ec, [5, 6]) - tri = visual.Triangle(ec, [[-1, 0, 1], [-1, 1, -1]], units='deg', - line_width=1.0) + tri = visual.Triangle( + ec, [[-1, 0, 1], [-1, 1, -1]], units="deg", line_width=1.0 + ) tri.draw() rect = visual.Rectangle(ec, [0, 0, 1, 1], line_width=1.0) rect.draw() diamond = visual.Diamond(ec, [0, 0, 1, 1], line_width=1.0) diamond.draw() pytest.raises(TypeError, visual.ConcentricCircles, ec, colors=dict()) - pytest.raises(TypeError, visual.ConcentricCircles, ec, - colors=np.array([])) + pytest.raises(TypeError, visual.ConcentricCircles, ec, colors=np.array([])) pytest.raises(ValueError, visual.ConcentricCircles, ec, radii=[[1]]) pytest.raises(ValueError, visual.ConcentricCircles, ec, radii=[1]) - fix = visual.ConcentricCircles(ec, radii=[1, 2, 3], - colors=['w', 'k', 'y']) + fix = visual.ConcentricCircles(ec, radii=[1, 2, 3], colors=["w", "k", "y"]) fix.set_pos([0.5, 0.5]) fix.set_radius(0.1, 1) fix.set_radii([0.1, 0.2, 0.3]) - fix.set_color('w', 1) - fix.set_colors(['w', 'k', 'k']) - fix.set_colors(('w', 'k', 'k')) - pytest.raises(IndexError, fix.set_color, 'w', 3) - pytest.raises(ValueError, fix.set_colors, ['w', 'k']) - pytest.raises(ValueError, fix.set_colors, np.array(['w', 'k', 'k'])) + fix.set_color("w", 1) + fix.set_colors(["w", "k", "k"]) + fix.set_colors(("w", "k", "k")) + pytest.raises(IndexError, fix.set_color, "w", 3) + pytest.raises(ValueError, fix.set_colors, ["w", "k"]) + pytest.raises(ValueError, fix.set_colors, np.array(["w", "k", "k"])) pytest.raises(IndexError, fix.set_radius, 0.1, 3) pytest.raises(ValueError, fix.set_radii, [0.1, 0.2]) fix.draw() @@ -60,8 +67,9 @@ def test_visuals(hide_window): assert_equal(img.scale, 1) # test get_rect imgrect = visual.Rectangle(ec, img.get_rect()) - assert_equal(imgrect._points['fill'][(0, 2, 0, 1), (0, 0, 1, 1)], - img.bounds) + assert_equal( + imgrect._points["fill"][(0, 2, 0, 1), (0, 0, 1, 1)], img.bounds + ) img.draw() line = visual.Line(ec, [[0, 1], [1, 0]]) line.draw() @@ -70,25 +78,29 @@ def test_visuals(hide_window): line.draw() pytest.raises(ValueError, line.set_coords, [0]) line.set_coords([0, 1]) - ec.set_background_color('black') - text = visual.Text(ec, 'Hello {color (255, 0, 0, 255)}Everybody!', - pos=[0, 0], color=[1, 1, 1], wrap=False) + ec.set_background_color("black") + text = visual.Text( + ec, + "Hello {color (255, 0, 0, 255)}Everybody!", + pos=[0, 0], + color=[1, 1, 1], + wrap=False, + ) text.draw() text.set_color(None) text.draw() - text = visual.Text(ec, 'Thank you, come again.', pos=[0, 0], - color='white', attr=False) + text = visual.Text( + ec, "Thank you, come again.", pos=[0, 0], color="white", attr=False + ) text.draw() - text.set_color('red') + text.set_color("red") text.draw() - bar = visual.ProgressBar(ec, [0, 0, 1, .2]) - bar = visual.ProgressBar(ec, [0, 0, 1, 1], units='pix') - bar.update_bar(.5) + bar = visual.ProgressBar(ec, [0, 0, 1, 0.2]) + bar = visual.ProgressBar(ec, [0, 0, 1, 1], units="pix") + bar.update_bar(0.5) bar.draw() - pytest.raises(ValueError, visual.ProgressBar, ec, [0, 0, 1, .1], - units='deg') - pytest.raises(ValueError, visual.ProgressBar, ec, [0, 0, 1, .1], - colors=['w']) + pytest.raises(ValueError, visual.ProgressBar, ec, [0, 0, 1, 0.1], units="deg") + pytest.raises(ValueError, visual.ProgressBar, ec, [0, 0, 1, 0.1], colors=["w"]) pytest.raises(ValueError, bar.update_bar, 500) @@ -96,21 +108,21 @@ def test_visuals(hide_window): def test_video(hide_window): """Test EC video methods.""" std_kwargs.update(dict(window_size=(640, 480))) - video_path = fetch_data_file('video/example-video.mp4') - with ExperimentController('test', **std_kwargs) as ec: + video_path = fetch_data_file("video/example-video.mp4") + with ExperimentController("test", **std_kwargs) as ec: ec.load_video(video_path) ec.video.play() pytest.raises(ValueError, ec.video.set_pos, [1, 2, 3]) - pytest.raises(ValueError, ec.video.set_scale, 'foo') + pytest.raises(ValueError, ec.video.set_scale, "foo") pytest.raises(ValueError, ec.video.set_scale, -1) ec.wait_secs(0.1) ec.video.set_visible(False) ec.wait_secs(0.1) ec.video.set_visible(True) - ec.video.set_scale('fill') - ec.video.set_scale('fit') - ec.video.set_scale('0.5') - ec.video.set_pos(pos=(0.1, 0), units='norm') + ec.video.set_scale("fill") + ec.video.set_scale("fit") + ec.video.set_scale("0.5") + ec.video.set_pos(pos=(0.1, 0), units="norm") ec.video.pause() ec.video.draw() ec.delete_video() diff --git a/git_flow.rst b/git_flow.rst index eb64d86d..fe35dd29 100644 --- a/git_flow.rst +++ b/git_flow.rst @@ -25,7 +25,7 @@ Users will want to take the "official" version of the software, make a copy of it on their own computer, and run the code from there. Using ``expyfun`` software as an example, this is done on the command line like this:: - $ git clone git://github.com/LABSN/expyfun.git + $ git clone https://github.com/LABSN/expyfun.git $ cd expyfun $ python setup.py install @@ -48,8 +48,8 @@ command sets up a relationship between that folder on your computer and the "origin" of the code. You can see this by typing:: $ git remote -v - origin git://github.com/LABSN/expyfun.git (fetch) - origin git://github.com/LABSN/expyfun.git (push) + origin https://github.com/LABSN/expyfun.git (fetch) + origin https://github.com/LABSN/expyfun.git (push) This tells you that :bash:`git` knows about two "remote" addresses of ``expyfun``: one to ``fetch`` new changes from (if the source code gets updated @@ -117,7 +117,7 @@ connect to the official remote repo with the name ``upstream``. So after forking $ git clone git@github.com:/rkmaddox/expyfun.git $ cd expyfun - $ git remote add upstream git://github.com/LABSN/expyfun.git + $ git remote add upstream https://github.com/LABSN/expyfun.git Now this user has the standard ``origin``/``upstream`` configuration, as seen below. Note the difference in the URIs between ``origin`` and ``upstream``:: @@ -125,12 +125,12 @@ below. Note the difference in the URIs between ``origin`` and ``upstream``:: $ git remote -v origin git@github.com:/rkmaddox/expyfun.git (fetch) origin git@github.com:/rkmaddox/expyfun.git (push) - upstream git://github.com/LABSN/expyfun.git (fetch) - upstream git://github.com/LABSN/expyfun.git (push) + upstream https://github.com/LABSN/expyfun.git (fetch) + upstream https://github.com/LABSN/expyfun.git (push) $ git branch * master -URIs beginning with ``git://`` are read-only connections, so ``rkmaddox`` can +URIs beginning with ``https://`` are read-only connections, so ``rkmaddox`` can pull down new changes from ``upstream``, but won't be able to directly push his local changes to upstream. Instead, he would have to push to his fork (``origin``) first, and create a @@ -244,7 +244,7 @@ Maintainers ^^^^^^^^^^^ Maintainers start out with a similar set up as Developers_. However, they might want to be able to push directly to the ``upstream`` repo as well as pushing to -their fork. Having a repo set up with :bash:`git://` access instead of +their fork. Having a repo set up with :bash:`https://` access instead of :bash:`git@github.com` or :bash:`https://` access will not allow pushing. So starting from scratch, a maintainer ``Eric89GXL`` might fork the upstream repo and then do:: @@ -252,7 +252,7 @@ and then do:: $ git clone git@github.com:/Eric89GXL/expyfun.git $ cd expyfun $ git remote add upstream git@github.com:/LABSN/expyfun.git - $ git remote add ross git://github.com/rkmaddox/expyfun.git + $ git remote add ross https://github.com/rkmaddox/expyfun.git Now the maintainer's local repository has push/pull access to their own personal development fork and the upstream repo, and has read-only access to @@ -261,8 +261,8 @@ development fork and the upstream repo, and has read-only access to $ git remote -v origin git@github.com:/Eric89GXL/expyfun.git (fetch) origin git@github.com:/Eric89GXL/expyfun.git (push) - ross git://github.com/rkmaddox/expyfun.git (fetch) - ross git://github.com/rkmaddox/expyfun.git (push) + ross https://github.com/rkmaddox/expyfun.git (fetch) + ross https://github.com/rkmaddox/expyfun.git (push) upstream git@github.com:/LABSN/expyfun.git (fetch) upstream git@github.com:/LABSN/expyfun.git (push) diff --git a/ignore_words.txt b/ignore_words.txt index 26ce2021..abef3133 100644 --- a/ignore_words.txt +++ b/ignore_words.txt @@ -5,3 +5,6 @@ ang sinc fied hist +bu +master +blacklist diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..b292de58 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,124 @@ +[tool.codespell] +ignore-words = "ignore_words.txt" +builtin = "clear,rare,informal,names,usage" +skip = "doc/references.bib" + +[tool.ruff] +exclude = ["__init__.py"] + +[tool.ruff.lint] +select = ["A", "B006", "D", "E", "F", "I", "W", "UP"] # , "UP031"] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D400", # First line should end with a period + "D401", # First line should be in imperative mood + "D413", # Missing blank line after last section + "UP031", # Use format specifiers instead of percent format + "UP030", # Use implicit references for positional format fields +] + +[tool.ruff.lint.pydocstyle] +convention = "numpy" +ignore-decorators = [ + "property", + "setter", + "mne.utils.copy_function_doc_to_method_doc", + "mne.utils.copy_doc", + "mne.utils.deprecated", +] + +[tool.ruff.lint.per-file-ignores] +"examples/**.py" = [ + "D205", # 1 blank line required between summary line and description +] + +[tool.pytest.ini_options] +# -r f (failed), E (error), s (skipped), x (xfail), X (xpassed), w (warnings) +# don't put in xfail for pytest 8.0+ because then it prints the tracebacks, +# which look like real errors +addopts = """--durations=20 --doctest-modules -rfEXs --cov-report= --tb=short \ + --cov-branch --doctest-ignore-import-errors --junit-xml=junit-results.xml \ + --ignore=doc --ignore=examples --ignore=tools \ + --color=yes --capture=sys""" +junit_family = "xunit2" +# Set this pretty low to ensure we do not by default add really long tests, +# or make changes that make things a lot slower +timeout = 5 +usefixtures = "matplotlib_config" +# Once SciPy updates not to have non-integer and non-tuple errors (1.2.0) we +# should remove them from here. +# This list should also be considered alongside reset_warnings in doc/conf.py +filterwarnings = ''' + error:: + ignore::ImportWarning + ignore:TDT is in dummy mode:UserWarning + ignore:generator 'ZipRunIterator.ranges' raised StopIteration:DeprecationWarning + ignore:size changed:RuntimeWarning + ignore:Using or importing the ABCs:DeprecationWarning + ignore:joblib not installed:RuntimeWarning + ignore:Matplotlib is building the font cache using fc-list:UserWarning + ignore:.*clock has been deprecated.*:DeprecationWarning + ignore:the imp module is deprecated.*:DeprecationWarning + ignore:.*eos_action is deprecated.*:DeprecationWarning + ignore:.*Vertex attribute shorthand.*: + ignore:.*ufunc size changed.*:RuntimeWarning + ignore:.*doc-files.*: + ignore:.*include is ignored because.*: + always:.*unclosed file.*:ResourceWarning + always:.*may indicate binary incompatibility.*: + ignore:.*Cannot change thread mode after it is set.*:UserWarning + ignore:.*distutils Version classes are deprecated.*:DeprecationWarning + ignore:.*distutils\.sysconfig module is deprecated.*:DeprecationWarning + ignore:.*isSet\(\) is deprecated.*:DeprecationWarning + ignore:`product` is deprecated as of NumPy.*:DeprecationWarning + ignore:Invalid dash-separated options.*: + always:.*Exception ignored in.*__del__.*: +''' + +[tool.rstcheck] +report_level = "WARNING" +ignore_roles = [ + "attr", + "class", + "doc", + "eq", + "exc", + "file", + "footcite", + "footcite:t", + "func", + "gh", + "kbd", + "meth", + "mod", + "newcontrib", + "py:mod", + "py:obj", + "obj", + "ref", + "samp", + "term", +] +ignore_directives = [ + "autoclass", + "autofunction", + "automodule", + "autosummary", + "bibliography", + "cssclass", + "currentmodule", + "dropdown", + "footbibliography", + "glossary", + "graphviz", + "grid", + "highlight", + "minigallery", + "tabularcolumns", + "toctree", + "rst-class", + "tab-set", + "towncrier-draft-entries", +] +ignore_messages = "^.*(Unknown target name|Undefined substitution referenced)[^`]*$" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 5a1b0467..00000000 --- a/setup.cfg +++ /dev/null @@ -1,65 +0,0 @@ -[aliases] -release = egg_info -RDb '' -# Make sure the sphinx docs are built each time we do a dist. -# bdist = build_sphinx bdist -# sdist = build_sphinx sdist -# Make sure a zip file is created each time we build the sphinx docs -# build_sphinx = generate_help build_sphinx zip_help -# Make sure the docs are uploaded when we do an upload -# upload = upload upload_help - -[egg_info] -# tag_build = .dev - -[bdist_rpm] -doc-files = doc - -[tool:pytest] -addopts = - --durations=20 --doctest-modules -ra --cov-report= --tb=short - --doctest-ignore-import-errors --junit-xml=junit-results.xml - --ignore=examples --ignore=tutorials --ignore=doc --ignore=make - --capture=sys -usefixtures = matplotlib_config -junit_family = xunit2 -# Set this pretty low to ensure we do not by default add really long tests, -# or make changes that make things a lot slower -timeout = 5 -# Once SciPy updates not to have non-integer and non-tuple errors (1.2.0) we -# should remove them from here. -# This list should also be considered alongside reset_warnings in doc/conf.py -filterwarnings = - error:: - ignore::ImportWarning - ignore:TDT is in dummy mode:UserWarning - ignore:generator 'ZipRunIterator.ranges' raised StopIteration:DeprecationWarning - ignore:size changed:RuntimeWarning - ignore:Using or importing the ABCs:DeprecationWarning - ignore:joblib not installed:RuntimeWarning - ignore:Matplotlib is building the font cache using fc-list:UserWarning - ignore:.*clock has been deprecated.*:DeprecationWarning - ignore:the imp module is deprecated.*:DeprecationWarning - ignore:.*eos_action is deprecated.*:DeprecationWarning - ignore:.*Vertex attribute shorthand.*: - ignore:.*ufunc size changed.*:RuntimeWarning - ignore:.*doc-files.*: - ignore:.*include is ignored because.*: - always:.*unclosed file.*:ResourceWarning - always:.*may indicate binary incompatibility.*: - ignore:.*Cannot change thread mode after it is set.*:UserWarning - ignore:.*distutils Version classes are deprecated.*:DeprecationWarning - ignore:.*distutils\.sysconfig module is deprecated.*:DeprecationWarning - ignore:.*isSet\(\) is deprecated.*:DeprecationWarning - always:.*Exception ignored in.*__del__.*: - -[flake8] -exclude = __init__.py,decorator.py,ndarraysource.py -ignore = E226,E241,E242,E265,W504 - -[pydocstyle] -convention = pep257 -match_dir = ^(?!\.|_externals|doc|examples).*$ -match = (?!tests/__init__\.py|fixes).*\.py -add-ignore = D100,D104,D107,D413,D105,D200,D205,D400,D401 # eventually D105,D200,D205,D400,D401 should be used -add-select = D214,D215,D404,D405,D406,D407,D408,D409,D410,D411 -ignore-decorators = ^(property|.*setter).* diff --git a/setup.py b/setup.py index a4e2022a..40d8f183 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ # we are using a setuptools namespace import setuptools # noqa, analysis:ignore -from numpy.distutils.core import setup +from setuptools import setup descr = """Experiment controller functions.""" @@ -79,7 +79,19 @@ def setup_package(script_args=None): version=FULL_VERSION, download_url=DOWNLOAD_URL, long_description=long_description, - python_requires=">=3.7", + python_requires=">=3.8", + install_requires=[ + "packaging", + "numpy", + "scipy", + "matplotlib", + "pillow", + "h5io", + "decorator", + ], + extras_require={ + "test": ["pytest", "pytest-cov", "pytest-timeout"], + }, zip_safe=False, # the package can run out of an .egg file classifiers=['Intended Audience :: Science/Research', 'Intended Audience :: Developers', diff --git a/make/get_video.ps1 b/tools/get_video.ps1 similarity index 100% rename from make/get_video.ps1 rename to tools/get_video.ps1