From a0cef033cbea6e44e9b6786e67cc7b8c2e46aeeb Mon Sep 17 00:00:00 2001 From: Linlang <30293408+SunsetWolf@users.noreply.github.com> Date: Tue, 17 Dec 2024 11:30:06 +0800 Subject: [PATCH] update python version (#1868) * update python version * fix: Correct selector handling and add time filtering in storage.py * fix: convert index and columns to list in repr methods * feat: Add Makefile for managing project prerequisites * feat: Add Cython extensions for rolling and expanding operations * resolve install error * fix lint error * fix lint error * fix lint error * fix lint error * fix lint error * update build package * update makefile * update ci yaml * fix docs build error * fix ubuntu install error * fix docs build error * fix install error * fix install error * fix install error * fix install error * fix pylint error * fix pylint error * fix pylint error * fix pylint error * fix pylint error E1123 * fix pylint error R0917 * fix pytest error * fix pytest error * fix pytest error * update code * update code * fix ci error * fix pylint error * fix black error * fix pytest error * fix CI error * fix CI error * add python version to CI * add python version to CI * add python version to CI * fix pylint error * fix pytest general nn error * fix CI error * optimize code * add coments * Extended macos version * remove build package --------- Co-authored-by: Young --- .github/workflows/python-publish.yml | 28 +-- .github/workflows/test_qlib_from_source.yml | 94 ++------ .../workflows/test_qlib_from_source_slow.yml | 20 +- .gitignore | 3 +- MANIFEST.in | 7 +- Makefile | 195 ++++++++++++++++ README.md | 2 +- examples/benchmarks/TRA/example.py | 5 +- examples/data_demo/data_cache_demo.py | 5 +- examples/data_demo/data_mem_resuse_demo.py | 6 +- examples/run_all_model.py | 5 +- pyproject.toml | 92 +++++++- qlib/__init__.py | 8 +- qlib/backtest/high_performance_ds.py | 4 +- qlib/backtest/report.py | 10 +- qlib/contrib/data/handler.py | 8 +- qlib/contrib/model/catboost_model.py | 
2 +- qlib/contrib/model/double_ensemble.py | 2 +- qlib/contrib/model/pytorch_adarnn.py | 8 +- qlib/contrib/model/pytorch_add.py | 2 +- qlib/contrib/model/pytorch_alstm.py | 2 +- qlib/contrib/model/pytorch_alstm_ts.py | 2 +- qlib/contrib/model/pytorch_gats.py | 2 +- qlib/contrib/model/pytorch_gats_ts.py | 2 +- qlib/contrib/model/pytorch_general_nn.py | 7 +- qlib/contrib/model/pytorch_gru.py | 2 +- qlib/contrib/model/pytorch_gru_ts.py | 2 +- qlib/contrib/model/pytorch_hist.py | 2 +- qlib/contrib/model/pytorch_igmtf.py | 2 +- qlib/contrib/model/pytorch_krnn.py | 2 +- qlib/contrib/model/pytorch_localformer.py | 2 +- qlib/contrib/model/pytorch_localformer_ts.py | 2 +- qlib/contrib/model/pytorch_lstm.py | 2 +- qlib/contrib/model/pytorch_lstm_ts.py | 2 +- qlib/contrib/model/pytorch_sandwich.py | 4 +- qlib/contrib/model/pytorch_sfm.py | 2 +- qlib/contrib/model/pytorch_tcn.py | 2 +- qlib/contrib/model/pytorch_tcn_ts.py | 2 +- qlib/contrib/model/pytorch_tcts.py | 2 +- qlib/contrib/model/pytorch_transformer.py | 2 +- qlib/contrib/model/pytorch_transformer_ts.py | 2 +- qlib/contrib/model/xgboost.py | 4 +- qlib/contrib/online/manager.py | 5 +- qlib/contrib/online/utils.py | 5 +- qlib/contrib/report/graph.py | 6 +- qlib/contrib/rolling/base.py | 5 +- qlib/contrib/tuner/config.py | 5 +- qlib/data/dataset/storage.py | 20 +- qlib/data/filter.py | 1 + qlib/rl/contrib/naive_config_parser.py | 6 +- qlib/rl/contrib/train_onpolicy.py | 5 +- qlib/utils/__init__.py | 8 +- qlib/utils/index_data.py | 14 +- qlib/workflow/cli.py | 8 +- qlib/workflow/expm.py | 9 +- qlib/workflow/recorder.py | 6 +- setup.cfg | 3 + setup.py | 208 ------------------ tests/data_mid_layer_tests/test_dataloader.py | 4 +- tests/dependency_tests/test_mlflow.py | 6 +- tests/misc/test_index_data.py | 2 +- tests/model/test_general_nn.py | 4 + tests/test_pit.py | 1 - 63 files changed, 462 insertions(+), 428 deletions(-) create mode 100644 Makefile create mode 100644 setup.cfg delete mode 100644 setup.py diff --git 
a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index b13e5bd47c..ef0aa98e48 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -12,43 +12,23 @@ jobs: runs-on: ${{ matrix.os }} strategy: matrix: - os: [windows-latest, macos-13] + os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-latest] # FIXME: macos-latest will raise error now. # not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129 - python-version: [3.7, 3.8] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v2 - # This is because on macos systems you can install pyqlib using - # `pip install pyqlib` installs, it does not recognize the - # `pyqlib--cp38-cp38-macosx_11_0_x86_64.whl` and `pyqlib--cp38-cp37m-macosx_11_0_x86_64.whl`. - # So we limit the version of python, in order to generate a version of qlib that is usable for macos: `pyqlib--cp38-cp37m - # `pyqlib--cp38-cp38-macosx_10_15_x86_64.whl` and `pyqlib--cp38-cp37m-macosx_10_15_x86_64.whl`. - # Python 3.7.16, 3.8.16 can build macosx_10_15. 
But Python 3.7.17, 3.8.17 can build macosx_11_0 - name: Set up Python ${{ matrix.python-version }} - if: matrix.os == 'macos-11' && matrix.python-version == '3.7' - uses: actions/setup-python@v2 - with: - python-version: "3.7.16" - - name: Set up Python ${{ matrix.python-version }} - if: matrix.os == 'macos-11' && matrix.python-version == '3.8' - uses: actions/setup-python@v2 - with: - python-version: "3.8.16" - - name: Set up Python ${{ matrix.python-version }} - if: matrix.os != 'macos-11' uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | - python -m pip install --upgrade pip - pip install setuptools wheel twine + make dev - name: Build wheel on ${{ matrix.os }} run: | - pip install numpy - pip install cython - python setup.py bdist_wheel + make build - name: Build and publish env: TWINE_USERNAME: __token__ diff --git a/.github/workflows/test_qlib_from_source.yml b/.github/workflows/test_qlib_from_source.yml index db878be837..4dfa1f1cb2 100644 --- a/.github/workflows/test_qlib_from_source.yml +++ b/.github/workflows/test_qlib_from_source.yml @@ -19,25 +19,15 @@ jobs: # If you want to use python 3.7 in github action, then the latest macos system version is macos-13, # after macos-13 python 3.7 is no longer supported. # so we limit the macos version to macos-13. - os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13] + os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-latest] # not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129 - python-version: [3.7, 3.8] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - name: Test qlib from source uses: actions/checkout@v3 - # Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error". - # So we make the version number of python 3.7 for MacOS more specific. 
- # refs: https://github.com/actions/setup-python/issues/682 - name: Set up Python ${{ matrix.python-version }} - if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-13' && matrix.python-version == '3.7') - uses: actions/setup-python@v4 - with: - python-version: "3.7.16" - - - name: Set up Python ${{ matrix.python-version }} - if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-13' || matrix.python-version != '3.7') uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -47,7 +37,7 @@ jobs: python -m pip install --upgrade pip - name: Installing pytorch for macos - if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-latest' }} + if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-latest' }} run: | python -m pip install torch torchvision torchaudio @@ -63,87 +53,33 @@ jobs: - name: Set up Python tools run: | - python -m pip install --upgrade cython - python -m pip install -e .[dev] + make dev - name: Lint with Black - # Python 3.7 will use a black with low level. So we use python with higher version for black check - if: (matrix.python-version != '3.7') run: | - pip install -U black # follow the latest version of black, previous Qlib dependency will downgrade black - black . -l 120 --check --diff + make black - name: Make html with sphinx # Since read the docs builds on ubuntu 22.04, we only need to test that the build passes on ubuntu 22.04. if: ${{ matrix.os == 'ubuntu-22.04' }} run: | - cd docs - sphinx-build -W --keep-going -b html . _build - cd .. - - # Check Qlib with pylint - # TODO: These problems we will solve in the future. 
Important among them are: W0221, W0223, W0237, E1102 - # C0103: invalid-name - # C0209: consider-using-f-string - # R0402: consider-using-from-import - # R1705: no-else-return - # R1710: inconsistent-return-statements - # R1725: super-with-arguments - # R1735: use-dict-literal - # W0102: dangerous-default-value - # W0212: protected-access - # W0221: arguments-differ - # W0223: abstract-method - # W0231: super-init-not-called - # W0237: arguments-renamed - # W0612: unused-variable - # W0621: redefined-outer-name - # W0622: redefined-builtin - # FIXME: specify exception type - # W0703: broad-except - # W1309: f-string-without-interpolation - # E1102: not-callable - # E1136: unsubscriptable-object - # References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962 - # We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). + make docs-gen + - name: Check Qlib with pylint run: | - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" - pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' scripts --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" - - # The following flake8 error 
codes were ignored: - # E501 line too long - # Description: We have used black to limit the length of each line to 120. - # F541 f-string is missing placeholders - # Description: The same thing is done when using pylint for detection. - # E266 too many leading '#' for block comment - # Description: To make the code more readable, a lot of "#" is used. - # This error code appears centrally in: - # qlib/backtest/executor.py - # qlib/data/ops.py - # qlib/utils/__init__.py - # E402 module level import not at top of file - # Description: There are times when module level import is not available at the top of the file. - # W503 line break before binary operator - # Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long. - # E731 do not assign a lambda expression, use a def - # Description: Restricts the use of lambda expressions, but at some point lambda expressions are required. - # E203 whitespace before ':' - # Description: If there is whitespace before ":", it cannot pass the black check. + make pylint + - name: Check Qlib with flake8 run: | - flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib + make flake8 - # https://github.com/python/mypy/issues/10600 - name: Check Qlib with mypy run: | - mypy qlib --install-types --non-interactive || true - mypy qlib --verbose + make mypy - name: Check Qlib ipynb with nbqa run: | - nbqa black . -l 120 --check --diff - nbqa pylint . 
--disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}$' + make nbqa - name: Test data downloads run: | @@ -151,7 +87,7 @@ jobs: python scripts/get_data.py download_data --file_name rl_data.zip --target_dir tests/.data/rl - name: Install Lightgbm for MacOS - if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-latest' }} + if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-latest' }} run: | /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)" HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm @@ -161,11 +97,9 @@ jobs: brew unlink libomp brew install libomp.rb - # Run after data downloads - name: Check Qlib ipynb with nbconvert run: | - # add more ipynb files in future - jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb + make nbconvert - name: Test workflow by config (install from source) run: | diff --git a/.github/workflows/test_qlib_from_source_slow.yml b/.github/workflows/test_qlib_from_source_slow.yml index 350d64c047..d302fe3072 100644 --- a/.github/workflows/test_qlib_from_source_slow.yml +++ b/.github/workflows/test_qlib_from_source_slow.yml @@ -19,41 +19,29 @@ jobs: # If you want to use python 3.7 in github action, then the latest macos system version is macos-13, # after macos-13 python 3.7 is no longer supported. # so we limit the macos version to macos-13. 
- os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13] + os: [windows-latest, ubuntu-20.04, ubuntu-22.04, macos-13, macos-14, macos-latest] # not supporting 3.6 due to annotations is not supported https://stackoverflow.com/a/52890129 - python-version: [3.7, 3.8] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - name: Test qlib from source slow uses: actions/checkout@v3 - # Since version 3.7 of python for MacOS is installed in CI, version 3.7.17, this version causes "_bz not found error". - # So we make the version number of python 3.7 for MacOS more specific. - # refs: https://github.com/actions/setup-python/issues/682 - name: Set up Python ${{ matrix.python-version }} - if: (matrix.os == 'macos-latest' && matrix.python-version == '3.7') || (matrix.os == 'macos-13' && matrix.python-version == '3.7') - uses: actions/setup-python@v4 - with: - python-version: "3.7.16" - - - name: Set up Python ${{ matrix.python-version }} - if: (matrix.os != 'macos-latest' || matrix.python-version != '3.7') && (matrix.os != 'macos-13' || matrix.python-version != '3.7') uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Set up Python tools run: | - python -m pip install --upgrade pip - pip install --upgrade cython numpy - pip install -e .[dev] + make dev - name: Downloads dependencies data run: | python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data --interval 1d --region cn - name: Install Lightgbm for MacOS - if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-latest' }} + if: ${{ matrix.os == 'macos-13' || matrix.os == 'macos-14' || matrix.os == 'macos-latest' }} run: | /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Microsoft/qlib/main/.github/brew_install.sh)" HOMEBREW_NO_AUTO_UPDATE=1 brew install lightgbm diff --git a/.gitignore b/.gitignore index 29ea1cd5e3..10e3cd2ff1 100644 --- a/.gitignore +++ b/.gitignore @@ -48,4 +48,5 @@ tags *.swp ./pretrain -.idea/ \ 
No newline at end of file +.idea/ +.aider* diff --git a/MANIFEST.in b/MANIFEST.in index 8dd91c79d2..2b8421a53d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,6 @@ -include qlib/VERSION.txt +exclude tests/* +include qlib/* +include qlib/*/* +include qlib/*/*/* +include qlib/*/*/*/* +include qlib/*/*/*/*/* diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..27824a6d3d --- /dev/null +++ b/Makefile @@ -0,0 +1,195 @@ +.PHONY: clean deepclean prerequisite dependencies lightgbm rl develop lint docs package test analysis all install dev black pylint flake8 mypy nbqa nbconvert lint build upload docs-gen +#You can modify it according to your terminal +SHELL := /bin/bash + +######################################################################################## +# Variables +######################################################################################## + +# Documentation target directory, will be adapted to specific folder for readthedocs. +PUBLIC_DIR := $(shell [ "$$READTHEDOCS" = "True" ] && echo "$$READTHEDOCS_OUTPUT/html" || echo "public") + +SO_DIR := qlib/data/_libs +SO_FILES := $(wildcard $(SO_DIR)/*.so) + +######################################################################################## +# Development Environment Management +######################################################################################## +# Remove common intermediate files. +clean: + -rm -rf \ + $(PUBLIC_DIR) \ + qlib/data/_libs/*.cpp \ + qlib/data/_libs/*.so \ + mlruns \ + public \ + build \ + .coverage \ + .mypy_cache \ + .pytest_cache \ + .ruff_cache \ + Pipfile* \ + coverage.xml \ + dist \ + release-notes.md + + find . -name '*.egg-info' -print0 | xargs -0 rm -rf + find . -name '*.pyc' -print0 | xargs -0 rm -f + find . -name '*.swp' -print0 | xargs -0 rm -f + find . -name '.DS_Store' -print0 | xargs -0 rm -f + find . -name '__pycache__' -print0 | xargs -0 rm -rf + +# Remove pre-commit hook, virtual environment alongside intermediate files.
+deepclean: clean + if command -v pre-commit > /dev/null 2>&1; then pre-commit uninstall --hook-type pre-push; fi + if command -v pipenv >/dev/null 2>&1 && pipenv --venv >/dev/null 2>&1; then pipenv --rm; fi + +# Prerequisite section +# What this code does is compile two Cython modules, rolling and expanding, using setuptools and Cython, +# and builds them as binary extension modules that can be imported directly into Python. +# Since pyproject.toml can't do that, we compile it here. +prerequisite: + @if [ -n "$(SO_FILES)" ]; then \ + echo "Shared library files exist, skipping build."; \ + else \ + echo "No shared library files found, building..."; \ + pip install --upgrade setuptools wheel; \ + python -m pip install cython numpy; \ + python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \ + fi + +# Install the package in editable mode. +dependencies: + python -m pip install -e .
+ +lightgbm: + python -m pip install lightgbm --prefer-binary + +rl: + python -m pip install -e .[rl] + +develop: + python -m pip install -e .[dev] + +lint: + python -m pip install -e .[lint] + +docs: + python -m pip install -e .[docs] + +package: + python -m pip install -e .[package] + +test: + python -m pip install -e .[test] + +analysis: + python -m pip install -e .[analysis] + +all: + python -m pip install -e .[dev,lint,docs,package,test,analysis,rl] + +install: prerequisite dependencies + +dev: prerequisite all + +######################################################################################## +# Lint and pre-commit +######################################################################################## + +# Check lint with black. +black: + black . -l 120 --check --diff + +# Check code folder with pylint. +# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102 +# C0103: invalid-name +# C0209: consider-using-f-string +# R0402: consider-using-from-import +# R1705: no-else-return +# R1710: inconsistent-return-statements +# R1725: super-with-arguments +# R1735: use-dict-literal +# W0102: dangerous-default-value +# W0212: protected-access +# W0221: arguments-differ +# W0223: abstract-method +# W0231: super-init-not-called +# W0237: arguments-renamed +# W0612: unused-variable +# W0621: redefined-outer-name +# W0622: redefined-builtin +# FIXME: specify exception type +# W0703: broad-except +# W1309: f-string-without-interpolation +# E1102: not-callable +# E1136: unsubscriptable-object +# W4904: deprecated-class +# R0917: too-many-positional-arguments +# E1123: unexpected-keyword-arg +# References for disable error: https://pylint.pycqa.org/en/latest/user_guide/messages/messages_overview.html +# We use sys.setrecursionlimit(2000) to make the recursion depth larger to ensure that pylint works properly (the default recursion depth is 1000). 
+# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962 +pylint: + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,W4904,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1730,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' qlib --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R0917,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,E1123,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0246,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}' scripts --init-hook="import astroid; astroid.context.InferenceContext.max_inferred = 500; import sys; sys.setrecursionlimit(2000)" + +# Check code with flake8. +# The following flake8 error codes were ignored: +# E501 line too long +# Description: We have used black to limit the length of each line to 120. +# F541 f-string is missing placeholders +# Description: The same thing is done when using pylint for detection. +# E266 too many leading '#' for block comment +# Description: To make the code more readable, a lot of "#" is used. +# This error code appears centrally in: +# qlib/backtest/executor.py +# qlib/data/ops.py +# qlib/utils/__init__.py +# E402 module level import not at top of file +# Description: There are times when module level import is not available at the top of the file. +# W503 line break before binary operator +# Description: Since black formats the length of each line of code, it has to perform a line break when a line of arithmetic is too long. 
+# E731 do not assign a lambda expression, use a def +# Description: Restricts the use of lambda expressions, but at some point lambda expressions are required. +# E203 whitespace before ':' +# Description: If there is whitespace before ":", it cannot pass the black check. +flake8: + flake8 --ignore=E501,F541,E266,E402,W503,E731,E203 --per-file-ignores="__init__.py:F401,F403" qlib + +# Check code with mypy. +# https://github.com/python/mypy/issues/10600 +mypy: + mypy qlib --install-types --non-interactive + mypy qlib --verbose + +# Check ipynb with nbqa. +nbqa: + nbqa black . -l 120 --check --diff + nbqa pylint . --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136,W0719,W0104,W0404,C0412,W0611,C0410 --const-rgx='[a-z_][a-z0-9_]{2,30}' + +# Check ipynb with nbconvert.(Run after data downloads) +# TODO: Add more ipynb files in future +nbconvert: + jupyter nbconvert --to notebook --execute examples/workflow_by_code.ipynb + +lint: black pylint flake8 mypy nbqa + +######################################################################################## +# Package +######################################################################################## + +# Build the package. +build: + python -m build + +# Upload the package. 
+upload: + python -m twine upload dist/* + +######################################################################################## +# Documentation +######################################################################################## + +docs-gen: + python -m sphinx.cmd.build -W docs $(PUBLIC_DIR) \ No newline at end of file diff --git a/README.md b/README.md index e1aff0cbe0..1621800963 100644 --- a/README.md +++ b/README.md @@ -358,7 +358,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu ``` Here are detailed documents for `qrun` and [workflow](https://qlib.readthedocs.io/en/latest/component/workflow.html). -2. Graphical Reports Analysis: Run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports +2. Graphical Reports Analysis: First, run `python -m pip install .[analysis]` to install the required dependencies. Then run `examples/workflow_by_code.ipynb` with `jupyter notebook` to get graphical reports. - Forecasting signal (model prediction) analysis - Cumulative Return of groups ![Cumulative Return](https://github.com/microsoft/qlib/blob/main/docs/_static/img/analysis/analysis_model_cumulative_return.png) diff --git a/examples/benchmarks/TRA/example.py b/examples/benchmarks/TRA/example.py index 0d52c87750..f7e16ddee4 100644 --- a/examples/benchmarks/TRA/example.py +++ b/examples/benchmarks/TRA/example.py @@ -1,14 +1,15 @@ import argparse import qlib -import ruamel.yaml as yaml +from ruamel.yaml import YAML from qlib.utils import init_instance_by_config def main(seed, config_file="configs/config_alstm.yaml"): # set random seed with open(config_file) as f: - config = yaml.safe_load(f) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(f) # seed_suffix = "/seed1000" if "init" in config_file else f"/seed{seed}" seed_suffix = "" diff --git a/examples/data_demo/data_cache_demo.py b/examples/data_demo/data_cache_demo.py index 6898c1e829..55adb8be68 100644 --- 
a/examples/data_demo/data_cache_demo.py +++ b/examples/data_demo/data_cache_demo.py @@ -9,8 +9,8 @@ from pathlib import Path import pickle from pprint import pprint +from ruamel.yaml import YAML import subprocess -import yaml from qlib.log import TimeInspector from qlib import init @@ -30,7 +30,8 @@ subprocess.run(f"qrun {config_path}", shell=True) # 2) dump handler - task_config = yaml.safe_load(config_path.open()) + yaml = YAML(typ="safe", pure=True) + task_config = yaml.load(config_path.open()) hd_conf = task_config["task"]["dataset"]["kwargs"]["handler"] pprint(hd_conf) hd: DataHandlerLP = init_instance_by_config(hd_conf) diff --git a/examples/data_demo/data_mem_resuse_demo.py b/examples/data_demo/data_mem_resuse_demo.py index 9cc44e86cb..cec5133063 100644 --- a/examples/data_demo/data_mem_resuse_demo.py +++ b/examples/data_demo/data_mem_resuse_demo.py @@ -9,10 +9,9 @@ from pathlib import Path import pickle from pprint import pprint +from ruamel.yaml import YAML import subprocess -import yaml - from qlib import init from qlib.data.dataset.handler import DataHandlerLP from qlib.log import TimeInspector @@ -29,7 +28,8 @@ exp_name = "data_mem_reuse_demo" config_path = DIRNAME.parent / "benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml" - task_config = yaml.safe_load(config_path.open()) + yaml = YAML(typ="safe", pure=True) + task_config = yaml.load(config_path.open()) # 1) without using processed data in memory with TimeInspector.logt("The original time without reusing processed data in memory:"): diff --git a/examples/run_all_model.py b/examples/run_all_model.py index dda3b98f62..70571556b1 100644 --- a/examples/run_all_model.py +++ b/examples/run_all_model.py @@ -6,7 +6,6 @@ import fire import time import glob -import yaml import shutil import signal import inspect @@ -15,6 +14,7 @@ import statistics import subprocess from datetime import datetime +from ruamel.yaml import YAML from pathlib import Path from operator import xor from pprint import pprint @@ 
-188,7 +188,8 @@ def gen_and_save_md_table(metrics, dataset): # read yaml, remove seed kwargs of model, and then save file in the temp_dir def gen_yaml_file_without_seed_kwargs(yaml_path, temp_dir): with open(yaml_path, "r") as fp: - config = yaml.safe_load(fp) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(fp) try: del config["task"]["model"]["kwargs"]["seed"] except KeyError: diff --git a/pyproject.toml b/pyproject.toml index 6350d092c7..547625d53e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,92 @@ [build-system] -requires = ["setuptools", "numpy", "Cython"] +requires = ["setuptools", "cython", "numpy>=1.24.0"] +build-backend = "setuptools.build_meta" + +[project] +classifiers = [ + "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS", + "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +name = "pyqlib" +dynamic = ["version"] +description = "A Quantitative-research Platform" +requires-python = ">=3.8.0" + +dependencies = [ + "pyyaml", + "numpy", + "pandas", + "mlflow", + "filelock>=3.16.0", + "redis", + "dill", + "fire", + "ruamel.yaml>=0.17.38", + "python-redis-lock", + "tqdm", + "pymongo", + "loguru", + "lightgbm", + "gym", + "cvxpy", + "joblib", + "matplotlib", + "jupyter", + "nbconvert", +] + +[project.optional-dependencies] +dev = [ + "pytest", + "statsmodels", +] +# On macos-13 system, when using python version greater than or equal to 3.10, +# pytorch can't fully support Numpy version above 2.0, so, when you want to install torch, +# it will limit the version of Numpy less than 2.0. 
+rl = [ + "tianshou<=0.4.10", + "torch", + "numpy<2.0.0", +] +lint = [ + "black", + "pylint", + "mypy<1.5.0", + "flake8", + "nbqa", +] +docs = [ + "sphinx", + "sphinx_rtd_theme", + "readthedocs_sphinx_ext", +] +package = [ + "twine", + "build", +] +# test_pit dependency packages +test = [ + "yahooquery", + "baostock", +] +analysis = [ + "plotly", +] + +[tool.setuptools] +packages = [ + "qlib", +] + +[project.scripts] +qrun = "qlib.workflow.cli:run" diff --git a/qlib/__init__.py b/qlib/__init__.py index fca74e4567..db45762b86 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -6,7 +6,7 @@ __version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version import os from typing import Union -import yaml +from ruamel.yaml import YAML import logging import platform import subprocess @@ -176,7 +176,8 @@ def init_from_yaml_conf(conf_path, **kwargs): config = {} else: with open(conf_path) as f: - config = yaml.safe_load(f) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(f) config.update(kwargs) default_conf = config.pop("default_conf", "client") init(default_conf, **config) @@ -272,7 +273,8 @@ def auto_init(**kwargs): logger = get_module_logger("Initialization") conf_pp = pp / "config.yaml" with conf_pp.open() as f: - conf = yaml.safe_load(f) + yaml = YAML(typ="safe", pure=True) + conf = yaml.load(f) conf_type = conf.get("conf_type", "origin") if conf_type == "origin": diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index dc467bd59b..67acc7adde 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -278,7 +278,7 @@ def empty(self) -> bool: raise NotImplementedError(f"Please implement the `empty` method") def add(self, other: BaseSingleMetric, fill_value: float = None) -> BaseSingleMetric: - """Replace np.NaN with fill_value in two metrics and add them.""" + """Replace np.nan with fill_value in two metrics and add them.""" raise NotImplementedError(f"Please 
implement the `add` method") @@ -412,7 +412,7 @@ def sum_all_indicators( metrics : Union[str, List[str]] all metrics needs to be sumed. fill_value : float, optional - fill np.NaN with value. By default None. + fill np.nan with value. By default None. """ raise NotImplementedError(f"Please implement the 'sum_all_indicators' method") diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index e7c6041efd..89f595df75 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -325,9 +325,9 @@ def _update_order_trade_info(self, trade_info: List[Tuple[Order, float, float, f def _update_order_fulfill_rate(self) -> None: def func(deal_amount, amount): - # deal_amount is np.NaN or None when there is no inner decision. So full fill rate is 0. + # deal_amount is np.nan or None when there is no inner decision. So full fill rate is 0. tmp_deal_amount = deal_amount.reindex(amount.index, 0) - tmp_deal_amount = tmp_deal_amount.replace({np.NaN: 0}) + tmp_deal_amount = tmp_deal_amount.replace({np.nan: 0}) return tmp_deal_amount / amount self.order_indicator.transfer(func, "ffr") @@ -354,8 +354,8 @@ def trade_amount_func(deal_amount, trade_price): ) def func(trade_price, deal_amount): - # trade_price is np.NaN instead of inf when deal_amount is zero. - tmp_deal_amount = deal_amount.replace({0: np.NaN}) + # trade_price is np.nan instead of inf when deal_amount is zero. 
+ tmp_deal_amount = deal_amount.replace({0: np.nan}) return trade_price / tmp_deal_amount self.order_indicator.transfer(func, "trade_price") @@ -425,7 +425,7 @@ def _get_base_vol_pri( assert isinstance(price_s, idd.SingleData) price_s = price_s.loc[(price_s > 1e-08).data.astype(bool)] # NOTE ~(price_s < 1e-08) is different from price_s >= 1e-8 - # ~(np.NaN < 1e-8) -> ~(False) -> True + # ~(np.nan < 1e-8) -> ~(False) -> True assert isinstance(price_s, idd.SingleData) if agg == "vwap": diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index 7c63e5a639..2fe5258daa 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -58,7 +58,7 @@ def __init__( fit_end_time=None, filter_pipe=None, inst_processors=None, - **kwargs + **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) @@ -83,7 +83,7 @@ def __init__( data_loader=data_loader, learn_processors=learn_processors, infer_processors=infer_processors, - **kwargs + **kwargs, ) def get_label_config(self): @@ -109,7 +109,7 @@ def __init__( process_type=DataHandlerLP.PTYPE_A, filter_pipe=None, inst_processors=None, - **kwargs + **kwargs, ): infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time) learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time) @@ -134,7 +134,7 @@ def __init__( infer_processors=infer_processors, learn_processors=learn_processors, process_type=process_type, - **kwargs + **kwargs, ) def get_feature_config(self): diff --git a/qlib/contrib/model/catboost_model.py b/qlib/contrib/model/catboost_model.py index ac30028f99..4fc1c6f893 100644 --- a/qlib/contrib/model/catboost_model.py +++ b/qlib/contrib/model/catboost_model.py @@ -33,7 +33,7 @@ def fit( verbose_eval=20, evals_result=dict(), reweighter=None, - **kwargs + **kwargs, ): df_train, df_valid = dataset.prepare( 
["train", "valid"], diff --git a/qlib/contrib/model/double_ensemble.py b/qlib/contrib/model/double_ensemble.py index f0b2188d06..de737b56da 100644 --- a/qlib/contrib/model/double_ensemble.py +++ b/qlib/contrib/model/double_ensemble.py @@ -31,7 +31,7 @@ def __init__( sub_weights=None, epochs=100, early_stopping_rounds=None, - **kwargs + **kwargs, ): self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm" self.num_models = num_models # the number of sub-models diff --git a/qlib/contrib/model/pytorch_adarnn.py b/qlib/contrib/model/pytorch_adarnn.py index ca5e8ba865..6988837efb 100644 --- a/qlib/contrib/model/pytorch_adarnn.py +++ b/qlib/contrib/model/pytorch_adarnn.py @@ -56,7 +56,7 @@ def __init__( n_splits=2, GPU=0, seed=None, - **_ + **_, ): # Set logger. self.logger = get_module_logger("ADARNN") @@ -154,10 +154,7 @@ def train_AdaRNN(self, train_loader_list, epoch, dist_old=None, weight_mat=None) self.model.train() criterion = nn.MSELoss() dist_mat = torch.zeros(self.num_layers, self.len_seq).to(self.device) - len_loader = np.inf - for loader in train_loader_list: - if len(loader) < len_loader: - len_loader = len(loader) + out_weight_list = None for data_all in zip(*train_loader_list): # for data_all in zip(*train_loader_list): self.train_optimizer.zero_grad() @@ -571,6 +568,7 @@ def compute(self, X, Y): Returns: [tensor] -- transfer loss """ + loss = None if self.loss_type in ("mmd_lin", "mmd"): mmdloss = MMD_loss(kernel_type="linear") loss = mmdloss(X, Y) diff --git a/qlib/contrib/model/pytorch_add.py b/qlib/contrib/model/pytorch_add.py index e929fe97f8..e97621157a 100644 --- a/qlib/contrib/model/pytorch_add.py +++ b/qlib/contrib/model/pytorch_add.py @@ -63,7 +63,7 @@ def __init__( mu=0.05, GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. 
self.logger = get_module_logger("ADD") diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index 2fe7cce3b0..d1c619ebf4 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -52,7 +52,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("ALSTM") diff --git a/qlib/contrib/model/pytorch_alstm_ts.py b/qlib/contrib/model/pytorch_alstm_ts.py index 3fb7cb9e19..95b5cf95d8 100644 --- a/qlib/contrib/model/pytorch_alstm_ts.py +++ b/qlib/contrib/model/pytorch_alstm_ts.py @@ -56,7 +56,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("ALSTM") diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 63ebd480a4..2a39e4b0ff 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -56,7 +56,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("GATs") diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index b1239f78e1..3bcb73c551 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -73,7 +73,7 @@ def __init__( GPU=0, n_jobs=10, seed=None, - **kwargs + **kwargs, ): # Set logger. 
self.logger = get_module_logger("GATs") diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 696a20254f..1e660fa080 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -319,7 +319,12 @@ def fit( if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset: Union[DatasetH, TSDatasetH]): + def predict( + self, + dataset: Union[DatasetH, TSDatasetH], + batch_size=None, + n_jobs=None, + ): if not self.fitted: raise ValueError("model is not fitted yet!") diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index 3306115507..06aa6810b8 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -52,7 +52,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("GRU") diff --git a/qlib/contrib/model/pytorch_gru_ts.py b/qlib/contrib/model/pytorch_gru_ts.py index 2e5076ea67..65da5ac4b4 100755 --- a/qlib/contrib/model/pytorch_gru_ts.py +++ b/qlib/contrib/model/pytorch_gru_ts.py @@ -54,7 +54,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("GRU") diff --git a/qlib/contrib/model/pytorch_hist.py b/qlib/contrib/model/pytorch_hist.py index 33df8e4875..e4220d0556 100644 --- a/qlib/contrib/model/pytorch_hist.py +++ b/qlib/contrib/model/pytorch_hist.py @@ -59,7 +59,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("HIST") diff --git a/qlib/contrib/model/pytorch_igmtf.py b/qlib/contrib/model/pytorch_igmtf.py index 46a25c00f4..3bc5ac78d9 100644 --- a/qlib/contrib/model/pytorch_igmtf.py +++ b/qlib/contrib/model/pytorch_igmtf.py @@ -55,7 +55,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. 
self.logger = get_module_logger("IGMTF") diff --git a/qlib/contrib/model/pytorch_krnn.py b/qlib/contrib/model/pytorch_krnn.py index 7c252672d1..f69d1d23b1 100644 --- a/qlib/contrib/model/pytorch_krnn.py +++ b/qlib/contrib/model/pytorch_krnn.py @@ -255,7 +255,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("KRNN") diff --git a/qlib/contrib/model/pytorch_localformer.py b/qlib/contrib/model/pytorch_localformer.py index 830bc59f03..42851dd6a2 100644 --- a/qlib/contrib/model/pytorch_localformer.py +++ b/qlib/contrib/model/pytorch_localformer.py @@ -44,7 +44,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # set hyper-parameters. self.d_model = d_model diff --git a/qlib/contrib/model/pytorch_localformer_ts.py b/qlib/contrib/model/pytorch_localformer_ts.py index b05c2d311a..ae60a39968 100644 --- a/qlib/contrib/model/pytorch_localformer_ts.py +++ b/qlib/contrib/model/pytorch_localformer_ts.py @@ -42,7 +42,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # set hyper-parameters. self.d_model = d_model diff --git a/qlib/contrib/model/pytorch_lstm.py b/qlib/contrib/model/pytorch_lstm.py index 168be6ca56..3ba09097ac 100755 --- a/qlib/contrib/model/pytorch_lstm.py +++ b/qlib/contrib/model/pytorch_lstm.py @@ -51,7 +51,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("LSTM") diff --git a/qlib/contrib/model/pytorch_lstm_ts.py b/qlib/contrib/model/pytorch_lstm_ts.py index 8ecafc2d5d..a0fc34d583 100755 --- a/qlib/contrib/model/pytorch_lstm_ts.py +++ b/qlib/contrib/model/pytorch_lstm_ts.py @@ -53,7 +53,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. 
self.logger = get_module_logger("LSTM") diff --git a/qlib/contrib/model/pytorch_sandwich.py b/qlib/contrib/model/pytorch_sandwich.py index 020c736fd3..344368143f 100644 --- a/qlib/contrib/model/pytorch_sandwich.py +++ b/qlib/contrib/model/pytorch_sandwich.py @@ -35,7 +35,7 @@ def __init__( rnn_layers, dropout, device, - **params + **params, ): """Build a Sandwich model @@ -129,7 +129,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("Sandwich") diff --git a/qlib/contrib/model/pytorch_sfm.py b/qlib/contrib/model/pytorch_sfm.py index e79f475d69..c971f1a58c 100644 --- a/qlib/contrib/model/pytorch_sfm.py +++ b/qlib/contrib/model/pytorch_sfm.py @@ -212,7 +212,7 @@ def __init__( optimizer="gd", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("SFM") diff --git a/qlib/contrib/model/pytorch_tcn.py b/qlib/contrib/model/pytorch_tcn.py index 38e289342d..f6e7e953a0 100755 --- a/qlib/contrib/model/pytorch_tcn.py +++ b/qlib/contrib/model/pytorch_tcn.py @@ -56,7 +56,7 @@ def __init__( optimizer="adam", GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("TCN") diff --git a/qlib/contrib/model/pytorch_tcn_ts.py b/qlib/contrib/model/pytorch_tcn_ts.py index 605da62c49..a6cc38885c 100755 --- a/qlib/contrib/model/pytorch_tcn_ts.py +++ b/qlib/contrib/model/pytorch_tcn_ts.py @@ -54,7 +54,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # Set logger. self.logger = get_module_logger("TCN") diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index 651bd03d23..d8736627c2 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -58,7 +58,7 @@ def __init__( mode="soft", seed=None, lowest_valid_performance=0.993, - **kwargs + **kwargs, ): # Set logger. 
self.logger = get_module_logger("TCTS") diff --git a/qlib/contrib/model/pytorch_transformer.py b/qlib/contrib/model/pytorch_transformer.py index f4b7a06eb6..d05b9f4cad 100644 --- a/qlib/contrib/model/pytorch_transformer.py +++ b/qlib/contrib/model/pytorch_transformer.py @@ -43,7 +43,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # set hyper-parameters. self.d_model = d_model diff --git a/qlib/contrib/model/pytorch_transformer_ts.py b/qlib/contrib/model/pytorch_transformer_ts.py index 84b093805c..70590e03e5 100644 --- a/qlib/contrib/model/pytorch_transformer_ts.py +++ b/qlib/contrib/model/pytorch_transformer_ts.py @@ -41,7 +41,7 @@ def __init__( n_jobs=10, GPU=0, seed=None, - **kwargs + **kwargs, ): # set hyper-parameters. self.d_model = d_model diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py index 67bedafa87..634259aab1 100755 --- a/qlib/contrib/model/xgboost.py +++ b/qlib/contrib/model/xgboost.py @@ -28,7 +28,7 @@ def fit( verbose_eval=20, evals_result=dict(), reweighter=None, - **kwargs + **kwargs, ): df_train, df_valid = dataset.prepare( ["train", "valid"], @@ -63,7 +63,7 @@ def fit( early_stopping_rounds=early_stopping_rounds, verbose_eval=verbose_eval, evals_result=evals_result, - **kwargs + **kwargs, ) evals_result["train"] = list(evals_result["train"].values())[0] evals_result["valid"] = list(evals_result["valid"].values())[0] diff --git a/qlib/contrib/online/manager.py b/qlib/contrib/online/manager.py index d101bcd088..7475bb6fc5 100644 --- a/qlib/contrib/online/manager.py +++ b/qlib/contrib/online/manager.py @@ -4,10 +4,10 @@ # pylint: skip-file # flake8: noqa -import yaml import pathlib import pandas as pd import shutil +from ruamel.yaml import YAML from ...backtest.account import Account from .user import User from .utils import load_instance, save_instance @@ -110,7 +110,8 @@ def add_user(self, user_id, config_file, add_date): raise ValueError("User data for {} already exists".format(user_id)) with 
config_file.open("r") as fp: - config = yaml.safe_load(fp) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(fp) # load model model = init_instance_by_config(config["model"]) diff --git a/qlib/contrib/online/utils.py b/qlib/contrib/online/utils.py index 5f2cbcf750..dddf7f0d2a 100644 --- a/qlib/contrib/online/utils.py +++ b/qlib/contrib/online/utils.py @@ -6,8 +6,8 @@ import pathlib import pickle -import yaml import pandas as pd +from ruamel.yaml import YAML from ...data import D from ...config import C from ...log import get_module_logger @@ -91,7 +91,8 @@ def prepare(um, today, user_id, exchange_config=None): dates.append(get_next_trading_date(dates[-1], future=True)) if exchange_config: with pathlib.Path(exchange_config).open("r") as fp: - exchange_paras = yaml.safe_load(fp) + yaml = YAML(typ="safe", pure=True) + exchange_paras = yaml.load(fp) else: exchange_paras = {} trade_exchange = Exchange(trade_dates=dates, **exchange_paras) diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index f9cf517ea7..387a057a29 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -176,7 +176,7 @@ def _get_data(self): x=self._df.columns, y=self._df.index, z=self._df.values.tolist(), - **self._graph_kwargs + **self._graph_kwargs, ) ] return _data @@ -213,7 +213,7 @@ def __init__( sub_graph_layout: dict = None, sub_graph_data: list = None, subplots_kwargs: dict = None, - **kwargs + **kwargs, ): """ @@ -355,7 +355,7 @@ def _init_figure(self): df=self._df.loc[:, [column_name]], name_dict={column_name: temp_name}, graph_kwargs=_graph_kwargs, - ) + ), ) else: raise TypeError() diff --git a/qlib/contrib/rolling/base.py b/qlib/contrib/rolling/base.py index 05467a6be2..5f17c05623 100644 --- a/qlib/contrib/rolling/base.py +++ b/qlib/contrib/rolling/base.py @@ -2,11 +2,11 @@ # Licensed under the MIT License. 
from copy import deepcopy from pathlib import Path +from ruamel.yaml import YAML from typing import List, Optional, Union import fire import pandas as pd -import yaml from qlib import auto_init from qlib.log import get_module_logger @@ -117,7 +117,8 @@ def __init__( def _raw_conf(self) -> dict: with self.conf_path.open("r") as f: - return yaml.safe_load(f) + yaml = YAML(typ="safe", pure=True) + return yaml.load(f) def _replace_handler_with_cache(self, task: dict): """ diff --git a/qlib/contrib/tuner/config.py b/qlib/contrib/tuner/config.py index 7a8534a20f..4cedd3642b 100644 --- a/qlib/contrib/tuner/config.py +++ b/qlib/contrib/tuner/config.py @@ -4,9 +4,9 @@ # pylint: skip-file # flake8: noqa -import yaml import copy import os +from ruamel.yaml import YAML class TunerConfigManager: @@ -16,7 +16,8 @@ def __init__(self, config_path): self.config_path = config_path with open(config_path) as fp: - config = yaml.safe_load(fp) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(fp) self.config = copy.deepcopy(config) self.pipeline_ex_config = PipelineExperimentConfig(config.get("experiment", dict()), self) diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 49afef9128..2adf6cd62a 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -104,15 +104,24 @@ def _fetch_hash_df_by_stock(self, selector, level): """ stock_selector = slice(None) + time_selector = slice(None) # by default not filter by time. if level is None: + # For directly applying. if isinstance(selector, tuple) and self.stock_level < len(selector): + # full selector format stock_selector = selector[self.stock_level] + time_selector = selector[1 - self.stock_level] elif isinstance(selector, (list, str)) and self.stock_level == 0: + # only stock selector stock_selector = selector elif level in ("instrument", self.stock_level): if isinstance(selector, tuple): + # NOTE: How could the stock level selector be a tuple? 
stock_selector = selector[0] + raise TypeError( + "I forget why would this case appear. But I think it does not make sense. So we raise a error for that case." + ) elif isinstance(selector, (list, str)): stock_selector = selector @@ -120,7 +129,7 @@ def _fetch_hash_df_by_stock(self, selector, level): raise TypeError(f"stock selector must be type str|list, or slice(None), rather than {stock_selector}") if stock_selector == slice(None): - return self.hash_df + return self.hash_df, time_selector if isinstance(stock_selector, str): stock_selector = [stock_selector] @@ -129,7 +138,7 @@ def _fetch_hash_df_by_stock(self, selector, level): for each_stock in sorted(stock_selector): if each_stock in self.hash_df: select_dict[each_stock] = self.hash_df[each_stock] - return select_dict + return select_dict, time_selector def fetch( self, @@ -138,10 +147,13 @@ def fetch( col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, ) -> pd.DataFrame: - fetch_stock_df_list = list(self._fetch_hash_df_by_stock(selector=selector, level=level).values()) + fetch_stock_df_list, time_selector = self._fetch_hash_df_by_stock(selector=selector, level=level) + fetch_stock_df_list = list(fetch_stock_df_list.values()) for _index, stock_df in enumerate(fetch_stock_df_list): fetch_col_df = fetch_df_by_col(df=stock_df, col_set=col_set) - fetch_index_df = fetch_df_by_index(df=fetch_col_df, selector=selector, level=level, fetch_orig=fetch_orig) + fetch_index_df = fetch_df_by_index( + df=fetch_col_df, selector=time_selector, level="datetime", fetch_orig=fetch_orig + ) fetch_stock_df_list[_index] = fetch_index_df if len(fetch_stock_df_list) == 0: index_names = ("instrument", "datetime") if self.stock_level == 0 else ("datetime", "instrument") diff --git a/qlib/data/filter.py b/qlib/data/filter.py index 9e924f728a..5057e20a4b 100644 --- a/qlib/data/filter.py +++ b/qlib/data/filter.py @@ -164,6 +164,7 @@ def _toTimestamp(self, timestamp_series): timestamp = [] _lbool = None _ltime = 
None + _cur_start = None for _ts, _bool in timestamp_series.items(): # there is likely to be NAN when the filter series don't have the # bool value, so we just change the NAN into False diff --git a/qlib/rl/contrib/naive_config_parser.py b/qlib/rl/contrib/naive_config_parser.py index 2255c7414a..5608cbd1ef 100644 --- a/qlib/rl/contrib/naive_config_parser.py +++ b/qlib/rl/contrib/naive_config_parser.py @@ -7,8 +7,7 @@ import sys import tempfile from importlib import import_module - -import yaml +from ruamel.yaml import YAML DELETE_KEY = "_delete_" @@ -57,7 +56,8 @@ def parse_backtest_config(path: str) -> dict: del sys.modules[tmp_module_name] else: with open(tmp_config_file.name) as input_stream: - config = yaml.safe_load(input_stream) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(input_stream) if "_base_" in config: base_file_name = config.pop("_base_") diff --git a/qlib/rl/contrib/train_onpolicy.py b/qlib/rl/contrib/train_onpolicy.py index cd5d0e55ef..83dd924103 100644 --- a/qlib/rl/contrib/train_onpolicy.py +++ b/qlib/rl/contrib/train_onpolicy.py @@ -8,12 +8,12 @@ import sys import warnings from pathlib import Path +from ruamel.yaml import YAML from typing import cast, List, Optional import numpy as np import pandas as pd import torch -import yaml from qlib.backtest import Order from qlib.backtest.decision import OrderDir from qlib.constant import ONE_MIN @@ -263,6 +263,7 @@ def main(config: dict, run_training: bool, run_backtest: bool) -> None: args = parser.parse_args() with open(args.config_path, "r") as input_stream: - config = yaml.safe_load(input_stream) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(input_stream) main(config, run_training=not args.no_training, run_backtest=args.run_backtest) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 732638b236..2a94ebd555 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -10,7 +10,6 @@ import re import copy import json -import yaml import redis import bisect 
import struct @@ -25,6 +24,7 @@ from pathlib import Path from typing import List, Union, Optional, Callable from packaging import version +from ruamel.yaml import YAML from .file import ( get_or_create_path, save_multiple_parts_file, @@ -244,12 +244,13 @@ def parse_config(config): if not isinstance(config, str): return config # Check whether config is file + yaml = YAML(typ="safe", pure=True) if os.path.exists(config): with open(config, "r") as f: - return yaml.safe_load(f) + return yaml.load(f) # Check whether the str can be parsed try: - return yaml.safe_load(config) + return yaml.load(config) except BaseException as base_exp: raise ValueError("cannot parse config!") from base_exp @@ -799,6 +800,7 @@ def try_replace_placeholder(value): ) return value + item_keys = None while top < tail: now_item = item_queue[top] top += 1 diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 6c4525ce36..c707240d09 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -44,7 +44,7 @@ def concat(data_list: Union[SingleData], axis=0) -> MultiData: all_index_map = dict(zip(all_index, range(len(all_index)))) # concat all - tmp_data = np.full((len(all_index), len(data_list)), np.NaN) + tmp_data = np.full((len(all_index), len(data_list)), np.nan) for data_id, index_data in enumerate(data_list): assert isinstance(index_data, SingleData) now_data_map = [all_index_map[index] for index in index_data.index] @@ -64,7 +64,7 @@ def sum_by_index(data_list: Union[SingleData], new_index: list, fill_value=0) -> new_index : list the new_index of new SingleData. fill_value : float - fill the missing values or replace np.NaN. + fill the missing values or replace np.nan. 
Returns ------- @@ -444,7 +444,7 @@ def __invert__(self): return self.__class__(~self.data.astype(bool), *self.indices) def abs(self): - """get the abs of data except np.NaN.""" + """get the abs of data except np.nan.""" tmp_data = np.absolute(self.data) return self.__class__(tmp_data, *self.indices) @@ -566,8 +566,8 @@ def _align_indices(self, other): f"The indexes of self and other do not meet the requirements of the four arithmetic operations" ) - def reindex(self, index: Index, fill_value=np.NaN) -> SingleData: - """reindex data and fill the missing value with np.NaN. + def reindex(self, index: Index, fill_value=np.nan) -> SingleData: + """reindex data and fill the missing value with np.nan. Parameters ---------- @@ -615,7 +615,7 @@ def to_series(self): return pd.Series(self.data, index=self.index) def __repr__(self) -> str: - return str(pd.Series(self.data, index=self.index)) + return str(pd.Series(self.data, index=self.index.tolist())) class MultiData(IndexData): @@ -651,4 +651,4 @@ def _align_indices(self, other): ) def __repr__(self) -> str: - return str(pd.DataFrame(self.data, index=self.index, columns=self.columns)) + return str(pd.DataFrame(self.data, index=self.index.tolist(), columns=self.columns.tolist())) diff --git a/qlib/workflow/cli.py b/qlib/workflow/cli.py index cda3fdbe16..d6e401e010 100644 --- a/qlib/workflow/cli.py +++ b/qlib/workflow/cli.py @@ -7,7 +7,7 @@ import fire from jinja2 import Template, meta -import ruamel.yaml as yaml +from ruamel.yaml import YAML import qlib from qlib.config import C @@ -104,7 +104,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): """ # Render the template rendered_yaml = render_template(config_path) - config = yaml.safe_load(rendered_yaml) + yaml = YAML(typ="safe", pure=True) + config = yaml.load(rendered_yaml) base_config_path = config.get("BASE_CONFIG_PATH", None) if base_config_path: @@ -126,7 +127,8 @@ def workflow(config_path, experiment_name="workflow", uri_folder="mlruns"): 
raise FileNotFoundError(f"Can't find the BASE_CONFIG file: {base_config_path}") with open(path) as fp: - base_config = yaml.safe_load(fp) + yaml = YAML(typ="safe", pure=True) + base_config = yaml.load(fp) logger.info(f"Load BASE_CONFIG_PATH succeed: {path.resolve()}") config = update_config(base_config, config) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 94d17beaf1..5047ccfb26 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -8,6 +8,7 @@ from mlflow.entities import ViewType import os from typing import Optional, Text +from pathlib import Path from .exp import MLflowExperiment, Experiment from ..config import C @@ -233,7 +234,7 @@ def _get_or_create_exp(self, experiment_id=None, experiment_name=None) -> (objec # So we supported it in the interface wrapper pr = urlparse(self.uri) if pr.scheme == "file": - with FileLock(os.path.join(pr.netloc, pr.path, "filelock")): # pylint: disable=E0110 + with FileLock(Path(os.path.join(pr.netloc, pr.path.lstrip("/"), "filelock"))): # pylint: disable=E0110 return self.create_exp(experiment_name), True # NOTE: for other schemes like http, we double check to avoid create exp conflicts try: @@ -421,7 +422,11 @@ def delete_exp(self, experiment_id=None, experiment_name=None): def list_experiments(self): # retrieve all the existing experiments - exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) + mlflow_version = int(mlflow.__version__.split(".", maxsplit=1)[0]) + if mlflow_version >= 2: + exps = self.client.search_experiments(view_type=ViewType.ACTIVE_ONLY) + else: + exps = self.client.list_experiments(view_type=ViewType.ACTIVE_ONLY) # pylint: disable=E1101 experiments = dict() for exp in exps: experiment = MLflowExperiment(exp.experiment_id, exp.name, self.uri) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 25f465936b..5fd99c0769 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -9,6 +9,7 @@ import pickle import tempfile import 
subprocess +import platform from pathlib import Path from datetime import datetime @@ -316,7 +317,10 @@ def get_local_dir(self): This function will return the directory path of this recorder. """ if self.artifact_uri is not None: - local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".." + if platform.system() == "Windows": + local_dir_path = Path(self.artifact_uri.lstrip("file:").lstrip("/")).parent + else: + local_dir_path = Path(self.artifact_uri.lstrip("file:")).parent local_dir_path = str(local_dir_path.resolve()) if os.path.isdir(local_dir_path): return local_dir_path diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..b406f824c8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[metadata] +name = qlib +version = attr: qlib.__version__ diff --git a/setup.py b/setup.py deleted file mode 100644 index bce3b812d0..0000000000 --- a/setup.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import os -import numpy - -from setuptools import find_packages, setup, Extension - - -def read(rel_path: str) -> str: - here = os.path.abspath(os.path.dirname(__file__)) - with open(os.path.join(here, rel_path), encoding="utf-8") as fp: - return fp.read() - - -def get_version(rel_path: str) -> str: - for line in read(rel_path).splitlines(): - if line.startswith("__version__"): - delim = '"' if '"' in line else "'" - return line.split(delim)[1] - raise RuntimeError("Unable to find version string.") - - -# Package meta-data. 
-NAME = "pyqlib" -DESCRIPTION = "A Quantitative-research Platform" -REQUIRES_PYTHON = ">=3.5.0" - -VERSION = get_version("qlib/__init__.py") - -# Detect Cython -try: - import Cython - - ver = Cython.__version__ - _CYTHON_INSTALLED = ver >= "0.28" -except ImportError: - _CYTHON_INSTALLED = False - -if not _CYTHON_INSTALLED: - print("Required Cython version >= 0.28 is not detected!") - print('Please run "pip install --upgrade cython" first.') - exit(-1) - -# What packages are required for this module to be executed? -# `estimator` may depend on other packages. In order to reduce dependencies, it is not written here. -REQUIRED = [ - "numpy>=1.12.0, <1.24", - "pandas>=0.25.1", - "scipy>=1.7.3", - # scs is a dependency package, - # and the latest version of scs: scs-3.2.4.post3.tar.gz causes the documentation build to fail, - # so we have temporarily limited the version of scs. - "scs<=3.2.4", - "requests>=2.18.0", - "sacred>=0.7.4", - "python-socketio", - "redis>=3.0.1", - "python-redis-lock>=3.3.1", - "schedule>=0.6.0", - "cvxpy>=1.0.21", - "hyperopt==0.1.2", - "fire>=0.3.1", - "statsmodels", - "xlrd>=1.0.0", - "plotly>=4.12.0", - "matplotlib>=3.3", - "tables>=3.6.1", - "pyyaml>=5.3.1", - # To ensure stable operation of the experiment manager, we have limited the version of mlflow, - # and we need to verify whether version 2.0 of mlflow can serve qlib properly. - "mlflow>=1.12.1, <=1.30.0", - # mlflow 1.30.0 requires packaging<22, so we limit the packaging version, otherwise the CI will fail. - "packaging<22", - "tqdm", - "loguru", - "lightgbm>=3.3.0", - "tornado", - "joblib>=0.17.0", - # With the upgrading of ruamel.yaml to 0.18, the safe_load method was deprecated, - # which would cause qlib.workflow.cli to not work properly, - # and no good replacement has been found, so the version of ruamel.yaml has been restricted for now. 
- # Refs: https://pypi.org/project/ruamel.yaml/ - "ruamel.yaml<=0.17.36", - "pymongo==3.7.2", # For task management - "scikit-learn>=0.22", - "dill", - "dataclasses;python_version<'3.7'", - "filelock", - "jinja2", - "gym", - # Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail. - "protobuf<=3.20.1;python_version<='3.8'", - "cryptography", -] - -# Numpy include -NUMPY_INCLUDE = numpy.get_include() - -here = os.path.abspath(os.path.dirname(__file__)) - -with open(os.path.join(here, "README.md"), encoding="utf-8") as f: - long_description = f.read() - - -# Cython Extensions -extensions = [ - Extension( - "qlib.data._libs.rolling", - ["qlib/data/_libs/rolling.pyx"], - language="c++", - include_dirs=[NUMPY_INCLUDE], - ), - Extension( - "qlib.data._libs.expanding", - ["qlib/data/_libs/expanding.pyx"], - language="c++", - include_dirs=[NUMPY_INCLUDE], - ), -] - -# Where the magic happens: -setup( - name=NAME, - version=VERSION, - license="MIT Licence", - url="https://github.com/microsoft/qlib", - description=DESCRIPTION, - long_description=long_description, - long_description_content_type="text/markdown", - python_requires=REQUIRES_PYTHON, - packages=find_packages(exclude=("tests",)), - # if your package is a single module, use this instead of 'packages': - # py_modules=['qlib'], - entry_points={ - # 'console_scripts': ['mycli=mymodule:cli'], - "console_scripts": [ - "qrun=qlib.workflow.cli:run", - ], - }, - ext_modules=extensions, - install_requires=REQUIRED, - extras_require={ - "dev": [ - "coverage", - "pytest>=3", - "sphinx", - "sphinx_rtd_theme", - "pre-commit", - # CI dependencies - "wheel", - "setuptools", - "black", - # Version 3.0 of pylint had problems with the build process, so we limited the version of pylint. 
- "pylint<=2.17.6", - # Using the latest versions(0.981 and 0.982) of mypy, - # the error "multiprocessing.Value()" is detected in the file "qlib/rl/utils/data_queue.py", - # If this is fixed in a subsequent version of mypy, then we will revert to the latest version of mypy. - # References: https://github.com/python/typeshed/issues/8799 - "mypy<0.981", - "flake8", - "nbqa", - "jupyter", - "nbconvert", - # The 5.0.0 version of importlib-metadata removed the deprecated endpoint, - # which prevented flake8 from working properly, so we restricted the version of importlib-metadata. - # To help ensure the dependencies of flake8 https://github.com/python/importlib_metadata/issues/406 - "importlib-metadata<5.0.0", - "readthedocs_sphinx_ext", - "cmake", - "lxml", - "baostock", - "yahooquery", - # 2024-05-30 scs has released a new version: 3.2.4.post2, - # this version, causes qlib installation to fail, so we've limited the scs version a bit for now. - "scs<=3.2.4", - "beautifulsoup4", - # In version 0.4.11 of tianshou, the code: - # logits, hidden = self.actor(batch.obs, state=state, info=batch.info) - # was changed in PR787, - # which causes pytest errors(AttributeError: 'dict' object has no attribute 'info') in CI, - # so we restricted the version of tianshou. - # References: - # https://github.com/thu-ml/tianshou/releases - "tianshou<=0.4.10", - "gym>=0.24", # If you do not put gym at the end, gym will degrade causing pytest results to fail. 
- ], - "rl": [ - "tianshou<=0.4.10", - "torch", - ], - }, - include_package_data=True, - classifiers=[ - # Trove classifiers - # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers - # 'License :: OSI Approved :: MIT License', - "Operating System :: POSIX :: Linux", - "Operating System :: Microsoft :: Windows", - "Operating System :: MacOS", - "License :: OSI Approved :: MIT License", - "Development Status :: 3 - Alpha", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - ], -) diff --git a/tests/data_mid_layer_tests/test_dataloader.py b/tests/data_mid_layer_tests/test_dataloader.py index 4d057be4fc..53a58c62b2 100644 --- a/tests/data_mid_layer_tests/test_dataloader.py +++ b/tests/data_mid_layer_tests/test_dataloader.py @@ -16,7 +16,7 @@ class TestDataLoader(unittest.TestCase): def test_nested_data_loader(self): - qlib.init() + qlib.init(kernels=1) nd = NestedDataLoader( dataloader_l=[ { @@ -30,7 +30,7 @@ def test_nested_data_loader(self): ) # Of course you can use StaticDataLoader - dataset = nd.load() + dataset = nd.load(start_time="2020-01-01", end_time="2020-01-31") assert dataset is not None diff --git a/tests/dependency_tests/test_mlflow.py b/tests/dependency_tests/test_mlflow.py index 578376a857..4b4d0105ba 100644 --- a/tests/dependency_tests/test_mlflow.py +++ b/tests/dependency_tests/test_mlflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
import unittest +import platform import mlflow import time from pathlib import Path @@ -26,7 +27,10 @@ def test_creating_client(self): _ = mlflow.tracking.MlflowClient(tracking_uri=str(self.TMP_PATH)) end = time.time() elapsed = end - start - self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms + if platform.system() == "Linux": + self.assertLess(elapsed, 1e-2) # it can be done in less than 10ms + else: + self.assertLess(elapsed, 2e-2) print(elapsed) diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index b3045a5c7f..89fccb4d91 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -70,7 +70,7 @@ def test_sorting(self): print(sd.loc[:"c"]) def test_corner_cases(self): - sd = idd.MultiData([[1, 2], [3, np.NaN]], index=["foo", "bar"], columns=["f", "g"]) + sd = idd.MultiData([[1, 2], [3, np.nan]], index=["foo", "bar"], columns=["f", "g"]) print(sd) self.assertTrue(np.isnan(sd.loc["bar", "g"])) diff --git a/tests/model/test_general_nn.py b/tests/model/test_general_nn.py index dd695efcc5..d67ad6eaf8 100644 --- a/tests/model/test_general_nn.py +++ b/tests/model/test_general_nn.py @@ -50,6 +50,8 @@ def test_both_dataset(self): model_l = [ GeneralPTNN( n_epochs=2, + batch_size=32, + n_jobs=0, pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", pt_model_kwargs={ "d_feat": 3, @@ -60,6 +62,8 @@ def test_both_dataset(self): ), GeneralPTNN( n_epochs=2, + batch_size=32, + n_jobs=0, pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP pt_model_kwargs={ "input_dim": 3, diff --git a/tests/test_pit.py b/tests/test_pit.py index 8320e1d361..548f91baaa 100644 --- a/tests/test_pit.py +++ b/tests/test_pit.py @@ -8,7 +8,6 @@ import unittest import pytest import pandas as pd -import baostock as bs from pathlib import Path from qlib.data import D