This issue was moved to a discussion.

[BUG] DeepARModel and TFTModel don't work on pytorch_lightning>=1.9.1 #26

Closed
1 task done
Mr-Geekman opened this issue Aug 14, 2023 · 1 comment
Labels
bug Something isn't working priority/medium Medium priority task

Issue by Mr-Geekman
Monday Feb 27, 2023 at 07:25 GMT
Originally opened as tinkoff-ai#1130


🐛 Bug Report

DeepARModel and TFTModel don't work on pytorch_lightning>=1.9.1.

Fitting fails with the following error:

AttributeError: 'tuple' object has no attribute 'items'

As far as I understand, this is related to an issue in the pytorch_forecasting library: 'tuple' object has no attribute 'items' in models.
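
A minimal sketch of a guard that flags the pytorch_lightning versions this report describes as affected. It is not part of etna or pytorch_forecasting: the helper names (`parse_release`, `lightning_is_affected`, `installed_lightning_version`, `warn_if_affected`) are hypothetical, the 1.9.1 threshold is taken from this report, and only the leading numeric release is compared, so pre-release or post-release suffixes are ignored.

```python
import re
import warnings
from importlib import metadata
from typing import Optional, Tuple

# Assumption taken from this report: fitting DeepARModel/TFTModel fails on
# pytorch_lightning>=1.9.1 with "'tuple' object has no attribute 'items'".
FIRST_AFFECTED: Tuple[int, int, int] = (1, 9, 1)


def parse_release(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, e.g. "1.9.1rc0" -> (1, 9, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"cannot parse version string: {version!r}")
    parts = tuple(int(part) for part in match.group(0).split("."))
    # Pad short versions so that "2.0" compares like "2.0.0".
    return parts + (0,) * (3 - len(parts))


def lightning_is_affected(version: str) -> bool:
    """Return True if `version` is in the range this report says is broken."""
    return parse_release(version) >= FIRST_AFFECTED


def installed_lightning_version() -> Optional[str]:
    """Return the installed pytorch_lightning version, or None if it is absent."""
    # Try both spellings of the distribution name for older Python versions.
    for dist_name in ("pytorch_lightning", "pytorch-lightning"):
        try:
            return metadata.version(dist_name)
        except metadata.PackageNotFoundError:
            continue
    return None


def warn_if_affected() -> None:
    """Emit a RuntimeWarning when the installed version is in the reported range."""
    version = installed_lightning_version()
    if version is not None and lightning_is_affected(version):
        warnings.warn(
            f"pytorch_lightning {version} is in the range reported to break "
            "DeepARModel and TFTModel fitting (>=1.9.1)",
            RuntimeWarning,
        )
```

Calling `warn_if_affected()` before `pipeline_tft.fit(ts)` would surface the version mismatch up front instead of the less descriptive AttributeError. Under the same assumption, keeping pytorch_lightning below 1.9.1 would stay outside the reported range.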

Expected behavior

Both models fit without errors.

How To Reproduce

Script to reproduce the bug with TFTModel (DeepARModel fails with the same error):

import pandas as pd
import numpy as np

from etna.datasets.tsdataset import TSDataset
from etna.pipeline import Pipeline
from etna.transforms import DateFlagsTransform
from etna.transforms import LagTransform
from etna.transforms import PytorchForecastingTransform
from pytorch_forecasting.data import GroupNormalizer
from etna.models.nn import TFTModel


original_df = pd.DataFrame(np.array([["2021-05-31", 1, 3],
                                     ["2021-06-07", 1, 6],
                                     ["2021-06-14", 1, 9],
                                     ["2021-06-21", 1, 12],
                                     ["2021-06-28", 1, 15]]),
                           columns=['timestamp', 'segment', 'target'])
original_df['timestamp'] = pd.to_datetime(original_df['timestamp'])
original_df['target'] = original_df['target'].astype(float)
df = TSDataset.to_dataset(original_df)
ts = TSDataset(df, freq="W-MON")

HORIZON = 1
transform_date = DateFlagsTransform(day_number_in_week=True, day_number_in_month=False, out_column="dateflag")
num_lags = 2
transform_lag = LagTransform(
    in_column="target",
    lags=[HORIZON + i for i in range(num_lags)],
    out_column="target_lag",
)

transform_tft = PytorchForecastingTransform(
    max_encoder_length=HORIZON,
    max_prediction_length=HORIZON,
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=["target"],
    time_varying_known_categoricals=["dateflag_day_number_in_week"],
    static_categoricals=["segment"],
    target_normalizer=GroupNormalizer(groups=["segment"]),
)
model_tft = TFTModel(max_epochs=5, learning_rate=[0.1], gpus=0, batch_size=64)

pipeline_tft = Pipeline(
    model=model_tft,
    horizon=HORIZON,
    transforms=[transform_lag, transform_date, transform_tft],
)

pipeline_tft.fit(ts)

The script fails on pipeline_tft.fit(ts) with the following error:

AttributeError: 'tuple' object has no attribute 'items'

Environment

No response

Additional context

No response

Checklist

  • Bug appears at the latest library version
Mr-Geekman added the bug (Something isn't working) and priority/medium (Medium priority task) labels Aug 14, 2023
Comment by Mr-Geekman
Monday Apr 17, 2023 at 12:06 GMT


The pytorch_forecasting package was updated recently, and it looks like the problem with pytorch_lightning may have been fixed there. However, the version requirements of other packages are very strict.

Mr-Geekman moved this to Todo in etna board Aug 15, 2023
d-a-bunin moved this from Todo to Hold in etna board Feb 22, 2024
etna-team locked and limited conversation to collaborators May 30, 2024
d-a-bunin converted this issue into discussion #371 May 30, 2024
github-project-automation bot moved this from Hold to Done in etna board May 30, 2024

Projects
Status: Done
Development

No branches or pull requests

1 participant