From c9ed050ef034fe6519c14b59f3d207abcb693282 Mon Sep 17 00:00:00 2001 From: cyncyw <47289405+taozhiwang@users.noreply.github.com> Date: Thu, 11 Jul 2024 17:59:18 +0800 Subject: [PATCH] Ptnn4both datatypes and alignment tests (#1827) * Init model for both dataset * Remove some deprecated code * Add model template; * We must align with previous results * We choose another mode as the initial version * Almost success to run GRU * Successfully run training * Passed general_nn test * gru test * Alignment test passed * comment * fix readme & minor errors * general nn updates & benchmarks * Update examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml --------- Co-authored-by: Young Co-authored-by: you-n-g --- examples/benchmarks/GeneralPtNN/README.md | 19 + .../GeneralPtNN/workflow_config_gru.yaml | 100 +++++ .../GeneralPtNN/workflow_config_gru2mlp.yaml | 93 +++++ .../GeneralPtNN/workflow_config_mlp.yaml | 98 +++++ qlib/contrib/model/pytorch_general_nn.py | 353 ++++++++++++++++++ qlib/contrib/model/pytorch_gru.py | 1 - tests/model/test_general_nn.py | 76 ++++ 7 files changed, 739 insertions(+), 1 deletion(-) create mode 100644 examples/benchmarks/GeneralPtNN/README.md create mode 100755 examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml create mode 100644 examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml create mode 100644 examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml create mode 100644 qlib/contrib/model/pytorch_general_nn.py create mode 100644 tests/model/test_general_nn.py diff --git a/examples/benchmarks/GeneralPtNN/README.md b/examples/benchmarks/GeneralPtNN/README.md new file mode 100644 index 0000000000..9778322204 --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/README.md @@ -0,0 +1,19 @@ + + +# Introduction + +What is GeneralPtNN +- Fix previous design that fail to support both Time-series and tabular data +- Now you can just replace the Pytorch model structure to run a NN model. + +We provide an example to demonstrate the effectiveness of the current design. +- `workflow_config_gru.yaml` align with previous results [GRU(Kyunghyun Cho, et al.)](../README.md#Alpha158-dataset) + - `workflow_config_gru2mlp.yaml` to demonstrate we can convert config from time-series to tabular data with minimal changes + - You only have to change the net & dataset class to make the conversion. +- `workflow_config_mlp.yaml` achieved similar functionality with [MLP](../README.md#Alpha158-dataset) + +# TODO + +- We will align existing models to current design. + +- The result of `workflow_config_mlp.yaml` is different with the result of [MLP](../README.md#Alpha158-dataset) since GeneralPtNN has a different stopping method compared to previous implementations. Specificly, GeneralPtNN controls training according to epoches, whereas previous methods controlled by max_steps. diff --git a/examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml b/examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml new file mode 100755 index 0000000000..74900fc3fd --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/workflow_config_gru.yaml @@ -0,0 +1,100 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW" + ] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: DropnaLabel + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: GeneralPTNN + module_path: qlib.contrib.model.pytorch_general_nn + kwargs: + n_epochs: 200 + lr: 2e-4 + early_stop: 10 + batch_size: 800 + metric: loss + loss: mse + n_jobs: 20 + GPU: 0 + pt_model_uri: "qlib.contrib.model.pytorch_gru_ts.GRUModel" + pt_model_kwargs: { + "d_feat": 20, + "hidden_size": 64, + "num_layers": 2, + "dropout": 0., + } + dataset: + class: TSDatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + step_len: 20 + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml b/examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml new file mode 100644 index 0000000000..3c2e4fabb1 --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/workflow_config_gru2mlp.yaml @@ -0,0 +1,93 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: + - class: FilterCol + kwargs: + fields_group: feature + col_list: ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", + "ROC60", "RESI10", "VSTD5", "RSQR60", "CORR60", "WVMA60", "STD5", + "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW" + ] + - class: RobustZScoreNorm + kwargs: + fields_group: feature + clip_outlier: true + - class: Fillna + kwargs: + fields_group: feature + learn_processors: + - class: DropnaLabel + - class: CSRankNorm + kwargs: + fields_group: label + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: GeneralPTNN + module_path: qlib.contrib.model.pytorch_general_nn + kwargs: + lr: 1e-3 + n_epochs: 1 + batch_size: 800 + loss: mse + optimizer: adam + pt_model_uri: "qlib.contrib.model.pytorch_nn.Net" + pt_model_kwargs: + input_dim: 20 + layers: [20,] + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml b/examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml new file mode 100644 index 0000000000..d8567679c7 --- /dev/null +++ b/examples/benchmarks/GeneralPtNN/workflow_config_mlp.yaml @@ -0,0 +1,98 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/cn_data" + region: cn +market: &market csi300 +benchmark: &benchmark SH000300 +data_handler_config: &data_handler_config + start_time: 2008-01-01 + end_time: 2020-08-01 + fit_start_time: 2008-01-01 + fit_end_time: 2014-12-31 + instruments: *market + infer_processors: [ + { + "class" : "DropCol", + "kwargs":{"col_list": ["VWAP0"]} + }, + { + "class" : "CSZFillna", + "kwargs":{"fields_group": "feature"} + } + ] + learn_processors: [ + { + "class" : "DropCol", + "kwargs":{"col_list": ["VWAP0"]} + }, + { + "class" : "DropnaProcessor", + "kwargs":{"fields_group": "feature"} + }, + "DropnaLabel", + { + "class": "CSZScoreNorm", + "kwargs": {"fields_group": "label"} + } + ] + process_type: "independent" + +port_analysis_config: &port_analysis_config + strategy: + class: TopkDropoutStrategy + module_path: qlib.contrib.strategy + kwargs: + signal: + topk: 50 + n_drop: 5 + backtest: + start_time: 2017-01-01 + end_time: 2020-08-01 + account: 100000000 + benchmark: *benchmark + exchange_kwargs: + limit_threshold: 0.095 + deal_price: close + open_cost: 0.0005 + close_cost: 0.0015 + min_cost: 5 +task: + model: + class: GeneralPTNN + module_path: qlib.contrib.model.pytorch_general_nn + kwargs: + # FIXME: wrong parameters. + lr: 2e-3 + batch_size: 8192 + loss: mse + weight_decay: 0.0002 + optimizer: adam + pt_model_uri: "qlib.contrib.model.pytorch_nn.Net" + pt_model_kwargs: + input_dim: 157 + dataset: + class: DatasetH + module_path: qlib.data.dataset + kwargs: + handler: + class: Alpha158 + module_path: qlib.contrib.data.handler + kwargs: *data_handler_config + segments: + train: [2008-01-01, 2014-12-31] + valid: [2015-01-01, 2016-12-31] + test: [2017-01-01, 2020-08-01] + record: + - class: SignalRecord + module_path: qlib.workflow.record_temp + kwargs: + model: + dataset: + - class: SigAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + ana_long_short: False + ann_scaler: 252 + - class: PortAnaRecord + module_path: qlib.workflow.record_temp + kwargs: + config: *port_analysis_config diff --git a/qlib/contrib/model/pytorch_general_nn.py b/qlib/contrib/model/pytorch_general_nn.py new file mode 100644 index 0000000000..696a20254f --- /dev/null +++ b/qlib/contrib/model/pytorch_general_nn.py @@ -0,0 +1,353 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import division +from __future__ import print_function + +from torch.utils.data import DataLoader + + +import numpy as np +import pandas as pd +from typing import Union +import copy + +import torch +import torch.optim as optim + +from qlib.data.dataset.weight import Reweighter + +from .pytorch_utils import count_parameters +from ...model.base import Model +from ...data.dataset import DatasetH, TSDatasetH +from ...data.dataset.handler import DataHandlerLP +from ...utils import ( + init_instance_by_config, + get_or_create_path, +) +from ...log import get_module_logger + +from ...model.utils import ConcatDataset + + +class GeneralPTNN(Model): + """ + Motivation: + We want to provide a Qlib General Pytorch Model Adaptor + You can reuse it for all kinds of Pytorch models. + It should include the training and predict process + + Parameters + ---------- + d_feat : int + input dimension for each time step + metric: str + the evaluation metric used in early stop + optimizer : str + optimizer name + GPU : str + the GPU ID(s) used for training + """ + + def __init__( + self, + n_epochs=200, + lr=0.001, + metric="", + batch_size=2000, + early_stop=20, + loss="mse", + weight_decay=0.0, + optimizer="adam", + n_jobs=10, + GPU=0, + seed=None, + pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", + pt_model_kwargs={ + "d_feat": 6, + "hidden_size": 64, + "num_layers": 2, + "dropout": 0.0, + }, + ): + # Set logger. + self.logger = get_module_logger("GeneralPTNN") + self.logger.info("GeneralPTNN pytorch version...") + + # set hyper-parameters. + self.n_epochs = n_epochs + self.lr = lr + self.metric = metric + self.batch_size = batch_size + self.early_stop = early_stop + self.optimizer = optimizer.lower() + self.loss = loss + self.weight_decay = weight_decay + self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu") + self.n_jobs = n_jobs + self.seed = seed + + self.pt_model_uri, self.pt_model_kwargs = pt_model_uri, pt_model_kwargs + self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs}) + + self.logger.info( + "GeneralPTNN parameters setting:" + "\nn_epochs : {}" + "\nlr : {}" + "\nmetric : {}" + "\nbatch_size : {}" + "\nearly_stop : {}" + "\noptimizer : {}" + "\nloss_type : {}" + "\ndevice : {}" + "\nn_jobs : {}" + "\nuse_GPU : {}" + "\nweight_decay : {}" + "\nseed : {}" + "\npt_model_uri: {}" + "\npt_model_kwargs: {}".format( + n_epochs, + lr, + metric, + batch_size, + early_stop, + optimizer.lower(), + loss, + self.device, + n_jobs, + self.use_gpu, + weight_decay, + seed, + pt_model_uri, + pt_model_kwargs, + ) + ) + + if self.seed is not None: + np.random.seed(self.seed) + torch.manual_seed(self.seed) + + self.logger.info("model:\n{:}".format(self.dnn_model)) + self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model))) + + if optimizer.lower() == "adam": + self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay) + elif optimizer.lower() == "gd": + self.train_optimizer = optim.SGD(self.dnn_model.parameters(), lr=self.lr, weight_decay=weight_decay) + else: + raise NotImplementedError("optimizer {} is not supported!".format(optimizer)) + + self.fitted = False + self.dnn_model.to(self.device) + + @property + def use_gpu(self): + return self.device != torch.device("cpu") + + def mse(self, pred, label, weight): + loss = weight * (pred - label) ** 2 + return torch.mean(loss) + + def loss_fn(self, pred, label, weight=None): + mask = ~torch.isnan(label) + + if weight is None: + weight = torch.ones_like(label) + + if self.loss == "mse": + return self.mse(pred[mask], label[mask], weight[mask]) + + raise ValueError("unknown loss `%s`" % self.loss) + + def metric_fn(self, pred, label): + mask = torch.isfinite(label) + + if self.metric in ("", "loss"): + return -self.loss_fn(pred[mask], label[mask]) + + raise ValueError("unknown metric `%s`" % self.metric) + + def _get_fl(self, data: torch.Tensor): + """ + get feature and label from data + - Handle the different data shape of time series and tabular data + + Parameters + ---------- + data : torch.Tensor + input data which maybe 3 dimension or 2 dimension + - 3dim: [batch_size, time_step, feature_dim] + - 2dim: [batch_size, feature_dim] + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + """ + if data.dim() == 3: + # it is a time series dataset + feature = data[:, :, 0:-1].to(self.device) + label = data[:, -1, -1].to(self.device) + elif data.dim() == 2: + # it is a tabular dataset + feature = data[:, 0:-1].to(self.device) + label = data[:, -1].to(self.device) + else: + raise ValueError("Unsupported data shape.") + return feature, label + + def train_epoch(self, data_loader): + self.dnn_model.train() + + for data, weight in data_loader: + feature, label = self._get_fl(data) + + pred = self.dnn_model(feature.float()) + loss = self.loss_fn(pred, label, weight.to(self.device)) + + self.train_optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_value_(self.dnn_model.parameters(), 3.0) + self.train_optimizer.step() + + def test_epoch(self, data_loader): + self.dnn_model.eval() + + scores = [] + losses = [] + + for data, weight in data_loader: + feature, label = self._get_fl(data) + + with torch.no_grad(): + pred = self.dnn_model(feature.float()) + loss = self.loss_fn(pred, label, weight.to(self.device)) + losses.append(loss.item()) + + score = self.metric_fn(pred, label) + scores.append(score.item()) + + return np.mean(losses), np.mean(scores) + + def fit( + self, + dataset: Union[DatasetH, TSDatasetH], + evals_result=dict(), + save_path=None, + reweighter=None, + ): + ists = isinstance(dataset, TSDatasetH) # is this time series dataset + + dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + dl_valid = dataset.prepare("valid", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L) + if dl_train.empty or dl_valid.empty: + raise ValueError("Empty data from dataset, please check your dataset config.") + + if reweighter is None: + wl_train = np.ones(len(dl_train)) + wl_valid = np.ones(len(dl_valid)) + elif isinstance(reweighter, Reweighter): + wl_train = reweighter.reweight(dl_train) + wl_valid = reweighter.reweight(dl_valid) + else: + raise ValueError("Unsupported reweighter type.") + + # Preprocess for data. To align to Dataset Interface for DataLoader + if ists: + dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader + dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader + else: + # If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader + dl_train = dl_train.values + dl_valid = dl_valid.values + + train_loader = DataLoader( + ConcatDataset(dl_train, wl_train), + batch_size=self.batch_size, + shuffle=True, + num_workers=self.n_jobs, + drop_last=True, + ) + valid_loader = DataLoader( + ConcatDataset(dl_valid, wl_valid), + batch_size=self.batch_size, + shuffle=False, + num_workers=self.n_jobs, + drop_last=True, + ) + del dl_train, dl_valid, wl_train, wl_valid + + save_path = get_or_create_path(save_path) + + stop_steps = 0 + train_loss = 0 + best_score = -np.inf + best_epoch = 0 + evals_result["train"] = [] + evals_result["valid"] = [] + + # train + self.logger.info("training...") + self.fitted = True + + for step in range(self.n_epochs): + self.logger.info("Epoch%d:", step) + self.logger.info("training...") + self.train_epoch(train_loader) + self.logger.info("evaluating...") + train_loss, train_score = self.test_epoch(train_loader) + val_loss, val_score = self.test_epoch(valid_loader) + self.logger.info("train %.6f, valid %.6f" % (train_score, val_score)) + evals_result["train"].append(train_score) + evals_result["valid"].append(val_score) + + if step == 0: + best_param = copy.deepcopy(self.dnn_model.state_dict()) + if val_score > best_score: + best_score = val_score + stop_steps = 0 + best_epoch = step + best_param = copy.deepcopy(self.dnn_model.state_dict()) + else: + stop_steps += 1 + if stop_steps >= self.early_stop: + self.logger.info("early stop") + break + + self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch)) + self.dnn_model.load_state_dict(best_param) + torch.save(best_param, save_path) + + if self.use_gpu: + torch.cuda.empty_cache() + + def predict(self, dataset: Union[DatasetH, TSDatasetH]): + if not self.fitted: + raise ValueError("model is not fitted yet!") + + dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + + if isinstance(dataset, TSDatasetH): + dl_test.config(fillna_type="ffill+bfill") # process nan brought by dataloader + index = dl_test.get_index() + else: + # If it is a tabular, we convert the dataframe to numpy to be indexable by DataLoader + index = dl_test.index + dl_test = dl_test.values + + test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs) + self.dnn_model.eval() + preds = [] + + for data in test_loader: + feature, _ = self._get_fl(data) + feature = feature.to(self.device) + + with torch.no_grad(): + pred = self.dnn_model(feature.float()).detach().cpu().numpy() + + preds.append(pred) + + preds_concat = np.concatenate(preds) + if preds_concat.ndim != 1: + preds_concat = preds_concat.ravel() + + return pd.Series(preds_concat, index=index) diff --git a/qlib/contrib/model/pytorch_gru.py b/qlib/contrib/model/pytorch_gru.py index e0f883f094..3306115507 100755 --- a/qlib/contrib/model/pytorch_gru.py +++ b/qlib/contrib/model/pytorch_gru.py @@ -317,7 +317,6 @@ def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"): class GRUModel(nn.Module): - def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0): super().__init__() diff --git a/tests/model/test_general_nn.py b/tests/model/test_general_nn.py new file mode 100644 index 0000000000..dd695efcc5 --- /dev/null +++ b/tests/model/test_general_nn.py @@ -0,0 +1,76 @@ +import unittest +from qlib.tests import TestAutoData + + +class TestNN(TestAutoData): + def test_both_dataset(self): + try: + from qlib.contrib.model.pytorch_general_nn import GeneralPTNN + from qlib.data.dataset import DatasetH, TSDatasetH + from qlib.data.dataset.handler import DataHandlerLP + except ImportError: + print("Import error.") + return + + data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "instruments": "csi300", + "data_loader": { + "class": "QlibDataLoader", # Assuming QlibDataLoader is a string reference to the class + "kwargs": { + "config": { + "feature": [["$high", "$close", "$low"], ["H", "C", "L"]], + "label": [["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]], + }, + "freq": "day", + }, + }, + # TODO: processors + "learn_processors": [ + { + "class": "DropnaLabel", + }, + {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}}, + ], + } + segments = { + "train": ["2008-01-01", "2014-12-31"], + "valid": ["2015-01-01", "2016-12-31"], + "test": ["2017-01-01", "2020-08-01"], + } + data_handler = DataHandlerLP(**data_handler_config) + + # time-series dataset + tsds = TSDatasetH(handler=data_handler, segments=segments) + + # tabular dataset + tbds = DatasetH(handler=data_handler, segments=segments) + + model_l = [ + GeneralPTNN( + n_epochs=2, + pt_model_uri="qlib.contrib.model.pytorch_gru_ts.GRUModel", + pt_model_kwargs={ + "d_feat": 3, + "hidden_size": 8, + "num_layers": 1, + "dropout": 0.0, + }, + ), + GeneralPTNN( + n_epochs=2, + pt_model_uri="qlib.contrib.model.pytorch_nn.Net", # it is a MLP + pt_model_kwargs={ + "input_dim": 3, + }, + ), + ] + + for ds, model in list(zip((tsds, tbds), model_l)): + model.fit(ds) # It works + model.predict(ds) # It works + + +if __name__ == "__main__": + unittest.main()