From fd1ba4c4f220a07c5db0c6664049a7853fc1afea Mon Sep 17 00:00:00 2001 From: SunsetWolf Date: Sat, 14 Dec 2024 23:37:04 +0800 Subject: [PATCH] fix pytest general nn error --- Makefile | 2 +- pyproject.toml | 2 +- qlib/contrib/model/pytorch_general_nn.py | 31 +++++++++++++++++++----- tests/model/test_general_nn.py | 4 +-- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/Makefile b/Makefile index fd133df095..7f8c7a1c7d 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ prerequisite: echo "No shared library files found, building..."; \ pip install --upgrade setuptools wheel; \ python -m pip install cython; \ - python -m pip install "numpy>=1.24.0"; \ + python -m pip install "numpy<2.0.0"; \ python -c "from setuptools import setup, Extension; from Cython.Build import cythonize; import numpy; extensions = [Extension('qlib.data._libs.rolling', ['qlib/data/_libs/rolling.pyx'], language='c++', include_dirs=[numpy.get_include()]), Extension('qlib.data._libs.expanding', ['qlib/data/_libs/expanding.pyx'], language='c++', include_dirs=[numpy.get_include()])]; setup(ext_modules=cythonize(extensions, language_level='3'), script_args=['build_ext', '--inplace'])"; \ fi diff --git a/pyproject.toml b/pyproject.toml index 8f515429e4..84cce7637e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools", "cython", "numpy>=1.24.0"] +requires = ["setuptools", "cython", "numpy<2.0.0"] build-backend = "setuptools.build_meta" [project] diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py index 696a20254f..f7f5b51743 100644 --- a/qlib/contrib/model/pytorch_general_nn.py +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -233,7 +233,15 @@ def fit( evals_result=dict(), save_path=None, reweighter=None, + batch_size=None, + n_jobs=None, ): + if batch_size is None: + batch_size = self.batch_size + + if n_jobs is None: + n_jobs = self.n_jobs + ists = isinstance(dataset, TSDatasetH) # is
this time series dataset dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) @@ -261,16 +269,16 @@ def fit( train_loader = DataLoader( ConcatDataset(dl_train, wl_train), - batch_size=self.batch_size, + batch_size=batch_size, shuffle=True, - num_workers=self.n_jobs, + num_workers=n_jobs, drop_last=True, ) valid_loader = DataLoader( ConcatDataset(dl_valid, wl_valid), - batch_size=self.batch_size, + batch_size=batch_size, shuffle=False, - num_workers=self.n_jobs, + num_workers=n_jobs, drop_last=True, ) del dl_train, dl_valid, wl_train, wl_valid @@ -319,7 +327,18 @@ def fit( if self.use_gpu: torch.cuda.empty_cache() - def predict(self, dataset: Union[DatasetH, TSDatasetH]): + def predict( + self, + dataset: Union[DatasetH, TSDatasetH], + batch_size=None, + n_jobs=None, + ): + if batch_size is None: + batch_size = self.batch_size + + if n_jobs is None: + n_jobs = self.n_jobs + if not self.fitted: raise ValueError("model is not fitted yet!") @@ -333,7 +352,7 @@ def predict(self, dataset: Union[DatasetH, TSDatasetH]): index = dl_test.index dl_test = dl_test.values - test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + test_loader = DataLoader(dl_test, batch_size=batch_size, num_workers=n_jobs) self.dnn_model.eval() preds = [] diff --git a/tests/model/test_general_nn.py b/tests/model/test_general_nn.py index dd695efcc5..c7691506bb 100644 --- a/tests/model/test_general_nn.py +++ b/tests/model/test_general_nn.py @@ -68,8 +68,8 @@ def test_both_dataset(self): ] for ds, model in list(zip((tsds, tbds), model_l)): - model.fit(ds) # It works - model.predict(ds) # It works + model.fit(ds, batch_size=32, n_jobs=0) # It works + model.predict(ds, batch_size=32, n_jobs=0) # It works if __name__ == "__main__":