From da78f7fd8631426bb00f15252812a71cc78e7288 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Sat, 23 Mar 2024 17:42:37 +0000 Subject: [PATCH] Fix loader for the M5 dataset --- src/gluonts/dataset/repository/_m5.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/gluonts/dataset/repository/_m5.py b/src/gluonts/dataset/repository/_m5.py index 0e423a03e2..cea591dc5f 100644 --- a/src/gluonts/dataset/repository/_m5.py +++ b/src/gluonts/dataset/repository/_m5.py @@ -61,6 +61,7 @@ def generate_m5_dataset( "d", ], axis=1, + errors="ignore", ) cal_features["event_type_1"] = cal_features["event_type_1"].apply( lambda x: 0 if str(x) == "nan" else 1 @@ -112,9 +113,11 @@ def generate_m5_dataset( ] # Build target series - train_ids = sales_train_validation["id"] + train_ids = sales_train_validation["item_id"] train_df = sales_train_validation.drop( - ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"], axis=1 + ["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"], + axis=1, + errors="ignore", ) test_target_values = train_df.values.copy() train_target_values = [ts[:-prediction_length] for ts in train_df.values]