Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Irregular time series support #1973

Open
wants to merge 34 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
88a88ce
initial TimeFeature
kashif May 11, 2022
a15f80b
add to deepar
kashif May 11, 2022
617f809
fix typo
kashif May 12, 2022
e586f3e
make_evaluation_predictions
kashif May 12, 2022
8380f78
to_pandas
kashif May 12, 2022
c40485c
fixed typoi
kashif May 12, 2022
37c4d7d
more typos
kashif May 12, 2022
b9c571e
set default index_field name
kashif May 12, 2022
96ec68e
add index to forecast generator
kashif May 12, 2022
aa3e85b
fix index name
kashif May 12, 2022
97f137b
missing arguments
kashif May 12, 2022
5c13d90
flake8
kashif May 12, 2022
2b3a14e
for irregular index do not clamp the timestamp
kashif May 16, 2022
1a4c5fb
concat index to prediction_length size
kashif May 16, 2022
c4bd002
added irregular time series test
kashif May 16, 2022
536cd4e
flake8
kashif May 16, 2022
8da6a8a
fix prediction_length call
kashif May 16, 2022
64e99e4
only assign if its not none
kashif May 16, 2022
45821a3
undo target size
kashif May 16, 2022
caa6442
Merge branch 'master' into irregular
kashif May 17, 2022
b5381bb
fix typo
kashif May 17, 2022
6ec69f4
added encode for DatetimeIndex
kashif May 19, 2022
4a3d344
array of indexes
kashif May 19, 2022
29c3c4c
Merge branch 'master' into irregular
kashif May 24, 2022
29ba7ae
Merge branch 'master' into irregular
kashif May 25, 2022
fb4bb0f
Merge branch 'master' into irregular
kashif May 31, 2022
5257b66
black
kashif May 31, 2022
23ac7c7
fix missing import
kashif May 31, 2022
528ac03
returns either PeriodIndex or DatetimeIndex
kashif May 31, 2022
39aee65
fix type hints
kashif May 31, 2022
1d319bc
use timedeltas for irregular age calculation
kashif May 31, 2022
5b0ccca
use components
kashif May 31, 2022
afcaf9f
fix typo
kashif May 31, 2022
8793f52
Merge branch 'dev' into irregular
kashif Aug 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/gluonts/core/serde/pd.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ def encode_pd_timestamp(v: pd.Timestamp) -> Any:
}


@encode.register(pd.DatetimeIndex)
def encode_pd_datetime_index(v: pd.DatetimeIndex) -> Any:
    """
    Specializes :func:`encode` for invocations where ``v`` is an instance of
    the :class:`~pandas.DatetimeIndex` class.
    """
    # Serialize each timestamp as its string form; the constructor of
    # ``pandas.DatetimeIndex`` accepts such a list of strings directly.
    timestamps = [str(timestamp) for timestamp in v]
    # Preserve the frequency (as its string alias) only when one is set.
    freq = v.freqstr if v.freq else None
    return {
        "__kind__": Kind.Instance,
        "class": "pandas.DatetimeIndex",
        "args": [encode(timestamps)],
        "kwargs": {"freq": freq},
    }


@encode.register(pd.Period)
def encode_pd_period(v: pd.Period) -> Any:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/gluonts/dataset/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,8 @@ def __call__(self, data: DataEntry) -> DataEntry:
data[self.name] = pd.Timestamp(data[self.name])
else:
data[self.name] = _as_period(data[self.name], self.freq)
if FieldName.INDEX in data:
data[FieldName.INDEX] = pd.DatetimeIndex(data[FieldName.INDEX])
except (TypeError, ValueError) as e:
raise GluonTSDataError(
f'Error "{e}" occurred, when reading field "{self.name}"'
Expand Down
2 changes: 2 additions & 0 deletions src/gluonts/dataset/field_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,5 @@ class FieldName:
FORECAST_START = "forecast_start"

TARGET_DIM_INDICATOR = "target_dimension_indicator"

INDEX = "index"
10 changes: 6 additions & 4 deletions src/gluonts/dataset/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ def to_pandas(entry: DataEntry, freq: Optional[str] = None) -> pd.Series:
pandas.Series
Pandas time series object.
"""
return pd.Series(
entry[FieldName.TARGET],
index=period_index(entry, freq=freq),
)
if FieldName.INDEX in entry:
index = entry[FieldName.INDEX]
else:
index = period_index(entry, freq=freq)

return pd.Series(entry[FieldName.TARGET], index=index)
15 changes: 14 additions & 1 deletion src/gluonts/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ def plot(
)

@property
def index(self) -> pd.PeriodIndex:
def index(self) -> Union[pd.DatetimeIndex, pd.PeriodIndex]:
if self._index is None:
self._index = pd.period_range(
self.start_date,
Expand Down Expand Up @@ -408,6 +408,8 @@ class SampleForecast(Forecast):
info
additional information that the forecaster may provide e.g. estimated
parameters, number of iterations ran etc.
index
optional datetime index of the forecast for irregular time series.
"""

@validated()
Expand All @@ -417,6 +419,7 @@ def __init__(
start_date: pd.Period,
item_id: Optional[str] = None,
info: Optional[Dict] = None,
index: Optional[Union[pd.DatetimeIndex, pd.PeriodIndex]] = None,
) -> None:
assert isinstance(
samples, np.ndarray
Expand All @@ -431,6 +434,8 @@ def __init__(
self._dim = None
self.item_id = item_id
self.info = info
if index is not None:
self._index = index[-self.prediction_length :]

assert isinstance(
start_date, pd.Period
Expand Down Expand Up @@ -494,6 +499,7 @@ def copy_dim(self, dim: int) -> "SampleForecast":
start_date=self.start_date,
item_id=self.item_id,
info=self.info,
index=self.index,
)

def copy_aggregate(self, agg_fun: Callable) -> "SampleForecast":
Expand All @@ -507,6 +513,7 @@ def copy_aggregate(self, agg_fun: Callable) -> "SampleForecast":
start_date=self.start_date,
item_id=self.item_id,
info=self.info,
index=self.index,
)

def dim(self) -> int:
Expand Down Expand Up @@ -543,6 +550,7 @@ def to_quantile_forecast(self, quantiles: List[str]) -> "QuantileForecast":
forecast_keys=quantiles,
item_id=self.item_id,
info=self.info,
index=self.index,
)


Expand All @@ -563,6 +571,8 @@ class QuantileForecast(Forecast):
info
additional information that the forecaster may provide e.g. estimated
parameters, number of iterations ran etc.
index
optional datetime index of the forecast for irregular time series.
"""

def __init__(
Expand All @@ -572,6 +582,7 @@ def __init__(
forecast_keys: List[str],
item_id: Optional[str] = None,
info: Optional[Dict] = None,
index: Optional[Union[pd.DatetimeIndex, pd.PeriodIndex]] = None,
) -> None:
self.forecast_array = forecast_arrays
assert isinstance(
Expand All @@ -598,6 +609,8 @@ def __init__(
k: self.forecast_array[i] for i, k in enumerate(self.forecast_keys)
}
self._nan_out = np.array([np.nan] * self.prediction_length)
if index is not None:
self._index = index[-self.prediction_length :]

def quantile(self, inference_quantile: Union[float, str]) -> np.ndarray:
sorted_forecast_dict = dict(sorted(self._forecast_dict.items()))
Expand Down
9 changes: 9 additions & 0 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def __call__(
else None,
info=batch["info"][i] if "info" in batch else None,
forecast_keys=self.quantiles,
index=batch[FieldName.INDEX][i]
if FieldName.INDEX in batch
else None,
)
assert i + 1 == len(batch[FieldName.FORECAST_START])

Expand Down Expand Up @@ -179,6 +182,9 @@ def __call__(
if FieldName.ITEM_ID in batch
else None,
info=batch["info"][i] if "info" in batch else None,
index=batch[FieldName.INDEX][i]
if FieldName.INDEX in batch
else None,
)
assert i + 1 == len(batch[FieldName.FORECAST_START])

Expand Down Expand Up @@ -219,5 +225,8 @@ def __call__(
if FieldName.ITEM_ID in batch
else None,
info=batch["info"][i] if "info" in batch else None,
index=batch[FieldName.INDEX][i]
if FieldName.INDEX in batch
else None,
)
assert i + 1 == len(batch[FieldName.FORECAST_START])
5 changes: 5 additions & 0 deletions src/gluonts/mx/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(
start_date: pd.Period,
item_id: Optional[str] = None,
info: Optional[Dict] = None,
index: Optional[pd.DatetimeIndex] = None,
) -> None:
self.distribution = distribution
self.shape = (
Expand All @@ -65,6 +66,8 @@ def __init__(
self.prediction_length = self.shape[0]
self.item_id = item_id
self.info = info
if index is not None:
self._index = index[-self.prediction_length :]

assert isinstance(
start_date, pd.Period
Expand Down Expand Up @@ -102,6 +105,7 @@ def to_sample_forecast(self, num_samples: int = 200) -> SampleForecast:
start_date=self.start_date,
item_id=self.item_id,
info=self.info,
index=self.index,
)

def to_quantile_forecast(self, quantiles: List[Union[float, str]]):
Expand All @@ -111,4 +115,5 @@ def to_quantile_forecast(self, quantiles: List[Union[float, str]]):
start_date=self.start_date,
item_id=self.item_id,
info=self.info,
index=self.index,
)
4 changes: 4 additions & 0 deletions src/gluonts/torch/model/forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
start_date: pd.Period,
item_id: Optional[str] = None,
info: Optional[Dict] = None,
index: Optional[Union[pd.DatetimeIndex, pd.PeriodIndex]] = None,
) -> None:
self.distribution = distribution
self.shape = distribution.batch_shape + distribution.event_shape
Expand All @@ -65,6 +66,8 @@ def __init__(
self.start_date = start_date

self._mean = None
if index is not None:
self._index = index[-self.prediction_length :]

@property
def mean(self) -> np.ndarray:
Expand Down Expand Up @@ -100,4 +103,5 @@ def to_sample_forecast(self, num_samples: int = 200) -> SampleForecast:
start_date=self.start_date,
item_id=self.item_id,
info=self.info,
index=self.index,
)
51 changes: 39 additions & 12 deletions src/gluonts/transform/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ class AddTimeFeatures(MapTransformation):
list of time features to use.
pred_length
Prediction length
index_field:
Field with the array containing the datetime index for irregular data.
"""

@validated()
Expand All @@ -342,26 +344,31 @@ def __init__(
output_field: str,
time_features: List[TimeFeature],
pred_length: int,
index_field: Optional[str] = FieldName.INDEX,
dtype: Type = np.float32,
) -> None:
self.date_features = time_features
self.pred_length = pred_length
self.start_field = start_field
self.target_field = target_field
self.output_field = output_field
self.index_field = index_field
self.dtype = dtype

def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
if not self.date_features:
data[self.output_field] = None
return data

start = data[self.start_field]
length = target_transformation_length(
data[self.target_field], self.pred_length, is_train=is_train
)
if self.index_field in data:
index = data[self.index_field]
else:
start = data[self.start_field]
length = target_transformation_length(
data[self.target_field], self.pred_length, is_train=is_train
)

index = pd.period_range(start, periods=length, freq=start.freq)
index = pd.period_range(start, periods=length, freq=start.freq)

data[self.output_field] = np.vstack(
[feat(index) for feat in self.date_features]
Expand Down Expand Up @@ -411,15 +418,35 @@ def __init__(
self.dtype = dtype

def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
length = target_transformation_length(
data[self.target_field], self.pred_length, is_train=is_train
)

if self.log_scale:
age = np.log10(2.0 + np.arange(length, dtype=self.dtype))
if FieldName.INDEX in data:
length = len(data[FieldName.INDEX])
components = pd.TimedeltaIndex(
data[FieldName.INDEX] - data[FieldName.INDEX][0]
).components
base_freq = data[FieldName.START].freq
if base_freq == "ns":
age = components.nanoseconds.values.astype(self.dtype)
elif base_freq == "us":
age = components.microseconds.values.astype(self.dtype)
elif base_freq == "ms":
age = components.milliseconds.values.astype(self.dtype)
elif base_freq == "S":
age = components.seconds.values.astype(self.dtype)
elif base_freq == "min" or base_freq == "T":
age = components.minutes.values.astype(self.dtype)
elif base_freq == "H":
age = components.hours.values.astype(self.dtype)
else:
age = components.days.values.astype(self.dtype)
else:
length = target_transformation_length(
data[self.target_field], self.pred_length, is_train=is_train
)
age = np.arange(length, dtype=self.dtype)

if self.log_scale:
age = np.log10(2.0 + age)

data[self.feature_name] = age.reshape((1, length))

return data
Expand Down Expand Up @@ -500,7 +527,7 @@ def __init__(
)

def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
assert self.base_freq == data["start"].freq
assert self.base_freq == data[FieldName.START].freq

# convert to pandas Series for easier indexing and aggregation
if is_train:
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/transform/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Iterator, List, Optional, Tuple

import numpy as np
import pandas as pd
from pandas.tseries.offsets import BaseOffset

from gluonts.core.component import validated
Expand Down Expand Up @@ -136,7 +137,7 @@ def _split_instance(self, entry: DataEntry, idx: int) -> DataEntry:
if self.output_NTC:
past_piece = past_piece.transpose()
future_piece = future_piece.transpose()

entry[self._past(ts_field)] = past_piece
entry[self._future(ts_field)] = future_piece
del entry[ts_field]
Expand Down
Loading
Loading