diff --git a/data/augmentations/__init__.py b/data/augmentations/__init__.py index 801923c..040e23b 100644 --- a/data/augmentations/__init__.py +++ b/data/augmentations/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/data/data_utils.py b/data/data_utils.py index 7127bae..8645a52 100644 --- a/data/data_utils.py +++ b/data/data_utils.py @@ -18,6 +18,7 @@ import json import os from pathlib import Path + warnings.simplefilter(action="ignore", category=FutureWarning) warnings.simplefilter(action="ignore", category=UserWarning) from pathlib import Path @@ -30,14 +31,20 @@ from gluonts.transform import InstanceSampler from pandas.tseries.frequencies import to_offset -from data.read_new_dataset import get_ett_dataset, create_train_dataset_without_last_k_timesteps, TrainDatasets, MetaData +from data.read_new_dataset import ( + get_ett_dataset, + create_train_dataset_without_last_k_timesteps, + TrainDatasets, + MetaData, +) + class CombinedDatasetIterator: def __init__(self, datasets, seed, weights): self._datasets = [iter(el) for el in datasets] self._weights = weights self._rng = random.Random(seed) - + def __next__(self): (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1) return next(dataset) @@ -105,15 +112,13 @@ def _count_timesteps( f"Too large difference between both timestamps ({left} and {right}) for _count_timesteps()." ) + from pathlib import Path from gluonts.dataset.common import ListDataset from gluonts.dataset.repository.datasets import get_dataset -def create_train_dataset_last_k_percentage( - raw_train_dataset, - freq, - k=100 -): + +def create_train_dataset_last_k_percentage(raw_train_dataset, freq, k=100): # Get training data train_data = [] for i, series in enumerate(raw_train_dataset): @@ -127,6 +132,7 @@ def create_train_dataset_last_k_percentage( return train_data + def create_train_and_val_datasets_with_dates( name, dataset_path, @@ -137,7 +143,7 @@ def create_train_and_val_datasets_with_dates( val_start_date=None, train_start_date=None, freq=None, - last_k_percentage=None + last_k_percentage=None, ): """ Train Start date is assumed to be the start of the series if not provided @@ -148,12 +154,19 @@ def create_train_and_val_datasets_with_dates( if name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"): path = os.path.join(dataset_path, "ett_datasets") raw_dataset = get_ett_dataset(name, path) - elif name in ("cpu_limit_minute", "cpu_usage_minute", \ - "function_delay_minute", "instances_minute", \ - "memory_limit_minute", "memory_usage_minute", \ - "platform_delay_minute", "requests_minute"): + elif name in ( + "cpu_limit_minute", + "cpu_usage_minute", + "function_delay_minute", + "instances_minute", + "memory_limit_minute", + "memory_usage_minute", + "platform_delay_minute", + "requests_minute", + ): path = os.path.join(dataset_path, "huawei/" + name + ".json") - with open(path, "r") as f: data = json.load(f) + with open(path, "r") as f: + data = json.load(f) metadata = MetaData(**data["metadata"]) train_data = [x for x in data["train"] if type(x["target"][0]) != str] test_data = [x for x in data["test"] if type(x["target"][0]) != str] @@ -167,8 +180,12 @@ def create_train_and_val_datasets_with_dates( metadata = MetaData(**data["metadata"]) train_test_data = [x for x in data["data"] if type(x["target"][0]) != str] full_dataset = ListDataset(train_test_data, freq=metadata.freq) - train_ds = create_train_dataset_without_last_k_timesteps(full_dataset, freq=metadata.freq, k=24) - raw_dataset = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset) + train_ds = create_train_dataset_without_last_k_timesteps( + full_dataset, freq=metadata.freq, k=24 + ) + raw_dataset = TrainDatasets( + metadata=metadata, train=train_ds, test=full_dataset + ) else: raw_dataset = get_dataset(name, path=Path(dataset_path)) @@ -257,9 +274,7 @@ def create_train_and_val_datasets_with_dates( ) -def create_test_dataset( - name, dataset_path, history_length, freq=None, data_id=None -): +def create_test_dataset(name, dataset_path, history_length, freq=None, data_id=None): """ For now, only window per series is used. make_evaluation_predictions automatically only predicts for the last "prediction_length" timesteps @@ -270,12 +285,19 @@ def create_test_dataset( if name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"): path = os.path.join(dataset_path, "ett_datasets") dataset = get_ett_dataset(name, path) - elif name in ("cpu_limit_minute", "cpu_usage_minute", \ - "function_delay_minute", "instances_minute", \ - "memory_limit_minute", "memory_usage_minute", \ - "platform_delay_minute", "requests_minute"): + elif name in ( + "cpu_limit_minute", + "cpu_usage_minute", + "function_delay_minute", + "instances_minute", + "memory_limit_minute", + "memory_usage_minute", + "platform_delay_minute", + "requests_minute", + ): path = os.path.join(dataset_path, "huawei/" + name + ".json") - with open(path, "r") as f: data = json.load(f) + with open(path, "r") as f: + data = json.load(f) metadata = MetaData(**data["metadata"]) train_data = [x for x in data["train"] if type(x["target"][0]) != str] test_data = [x for x in data["test"] if type(x["target"][0]) != str] @@ -289,7 +311,9 @@ def create_test_dataset( metadata = MetaData(**data["metadata"]) train_test_data = [x for x in data["data"] if type(x["target"][0]) != str] full_dataset = ListDataset(train_test_data, freq=metadata.freq) - train_ds = create_train_dataset_without_last_k_timesteps(full_dataset, freq=metadata.freq, k=24) + train_ds = create_train_dataset_without_last_k_timesteps( + full_dataset, freq=metadata.freq, k=24 + ) dataset = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset) else: dataset = get_dataset(name, path=Path(dataset_path)) @@ -317,4 +341,4 @@ def create_test_dataset( series_copy["data_id"] = data_id data.append(series_copy) total_points += len(data[-1]["target"]) - return ListDataset(data, freq=freq), prediction_length, total_points \ No newline at end of file + return ListDataset(data, freq=freq), prediction_length, total_points diff --git a/data/dataset_list.py b/data/dataset_list.py index 70c3c94..a975a88 100644 --- a/data/dataset_list.py +++ b/data/dataset_list.py @@ -12,4 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -ALL_DATASETS = ["australian_electricity_demand", "electricity_hourly", "london_smart_meters_without_missing", "solar_10_minutes", "wind_farms_without_missing", "pedestrian_counts", "uber_tlc_hourly", "traffic", "kdd_cup_2018_without_missing", "saugeenday", "sunspot_without_missing", "exchange_rate", "cpu_limit_minute", "cpu_usage_minute", "function_delay_minute", "instances_minute", "memory_limit_minute", "memory_usage_minute", "platform_delay_minute", "requests_minute", "ett_h1", "ett_h2", "ett_m1", "ett_m2", "beijing_pm25", "AirQualityUCI", "beijing_multisite"] \ No newline at end of file +ALL_DATASETS = [ + "australian_electricity_demand", + "electricity_hourly", + "london_smart_meters_without_missing", + "solar_10_minutes", + "wind_farms_without_missing", + "pedestrian_counts", + "uber_tlc_hourly", + "traffic", + "kdd_cup_2018_without_missing", + "saugeenday", + "sunspot_without_missing", + "exchange_rate", + "cpu_limit_minute", + "cpu_usage_minute", + "function_delay_minute", + "instances_minute", + "memory_limit_minute", + "memory_usage_minute", + "platform_delay_minute", + "requests_minute", + "ett_h1", + "ett_h2", + "ett_m1", + "ett_m2", + "beijing_pm25", + "AirQualityUCI", + "beijing_multisite", +] diff --git a/data/read_new_dataset.py b/data/read_new_dataset.py index f46665d..68dfe2d 100644 --- a/data/read_new_dataset.py +++ b/data/read_new_dataset.py @@ -13,6 +13,7 @@ # limitations under the License. import warnings + warnings.simplefilter(action="ignore", category=FutureWarning) warnings.simplefilter(action="ignore", category=UserWarning) @@ -22,52 +23,61 @@ from gluonts.dataset.repository.datasets import get_dataset import os -def create_train_dataset_without_last_k_timesteps( - raw_train_dataset, - freq, - k=0 -): + +def create_train_dataset_without_last_k_timesteps(raw_train_dataset, freq, k=0): train_data = [] for i, series in enumerate(raw_train_dataset): s_train = series.copy() - s_train["target"] = s_train["target"][:len(s_train["target"])-k] + s_train["target"] = s_train["target"][: len(s_train["target"]) - k] train_data.append(s_train) train_data = ListDataset(train_data, freq=freq) return train_data + def load_jsonl_gzip_file(file_path): - with gzip.open(file_path, 'rt') as f: + with gzip.open(file_path, "rt") as f: return [json.loads(line) for line in f] + def get_ett_dataset(dataset_name, path): dataset_path = Path(path) / dataset_name - metadata_path = dataset_path / 'metadata.json' - with open(metadata_path, 'r') as f: + metadata_path = dataset_path / "metadata.json" + with open(metadata_path, "r") as f: metadata_dict = json.load(f) metadata = MetaData(**metadata_dict) # Load train and test datasets - train_data_path = dataset_path / 'train' / 'data.json.gz' - test_data_path = dataset_path / 'test' / 'data.json.gz' + train_data_path = dataset_path / "train" / "data.json.gz" + test_data_path = dataset_path / "test" / "data.json.gz" # test dataset test_data = load_jsonl_gzip_file(test_data_path) # Create GluonTS ListDatasets test_ds = ListDataset(test_data, freq=metadata.freq) - train_ds = create_train_dataset_without_last_k_timesteps(test_ds, freq=metadata.freq, k=24) + train_ds = create_train_dataset_without_last_k_timesteps( + test_ds, freq=metadata.freq, k=24 + ) return TrainDatasets(metadata=metadata, train=train_ds, test=test_ds) + if __name__ == "__main__": dataset_name = "ett_h1" if dataset_name in ("ett_h1", "ett_h2", "ett_m1", "ett_m2"): path = "data/datasets/ett_datasets" ds = get_ett_dataset(dataset_name, path) - - if dataset_name in ("cpu_limit_minute", "cpu_usage_minute", \ - "function_delay_minute", "instances_minute", \ - "memory_limit_minute", "memory_usage_minute", \ - "platform_delay_minute", "requests_minute"): + + if dataset_name in ( + "cpu_limit_minute", + "cpu_usage_minute", + "function_delay_minute", + "instances_minute", + "memory_limit_minute", + "memory_usage_minute", + "platform_delay_minute", + "requests_minute", + ): path = "data/datasets/huawei/" + dataset_name + ".json" - with open(path, "r") as f: data = json.load(f) + with open(path, "r") as f: + data = json.load(f) metadata = MetaData(**data["metadata"]) train_data = [x for x in data["train"] if type(x["target"][0]) != str] test_data = [x for x in data["test"] if type(x["target"][0]) != str] @@ -82,5 +92,7 @@ def get_ett_dataset(dataset_name, path): metadata = MetaData(**data["metadata"]) train_test_data = [x for x in data["data"] if type(x["target"][0]) != str] full_dataset = ListDataset(train_test_data, freq=metadata.freq) - train_ds = create_train_dataset_without_last_k_timesteps(test_ds, freq=metadata.freq, k=24) - ds = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset) \ No newline at end of file + train_ds = create_train_dataset_without_last_k_timesteps( + test_ds, freq=metadata.freq, k=24 + ) + ds = TrainDatasets(metadata=metadata, train=train_ds, test=full_dataset) diff --git a/gluon_utils/gluon_ts_distributions/implicit_quantile_network.py b/gluon_utils/gluon_ts_distributions/implicit_quantile_network.py deleted file mode 100644 index 388ac38..0000000 --- a/gluon_utils/gluon_ts_distributions/implicit_quantile_network.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"). -# You may not use this file except in compliance with the License. -# A copy of the License is located at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# or in the "license" file accompanying this file. This file is distributed -# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -# express or implied. See the License for the specific language governing -# permissions and limitations under the License. - -from functools import partial -from typing import Callable, Dict, Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.distributions import Beta, Distribution, constraints - -from gluonts.core.component import validated -from gluonts.torch.distributions import DistributionOutput -from gluonts.torch.modules.lambda_layer import LambdaLayer - - -class QuantileLayer(nn.Module): - r""" - Implicit Quantile Layer from the paper ``IQN for Distributional - Reinforcement Learning`` (https://arxiv.org/abs/1806.06923) by - Dabney et al. 2018. - """ - - def __init__(self, num_output: int, cos_embedding_dim: int = 128): - super().__init__() - - self.output_layer = nn.Sequential( - nn.Linear(cos_embedding_dim, cos_embedding_dim), - nn.PReLU(), - nn.Linear(cos_embedding_dim, num_output), - ) - - self.register_buffer("integers", torch.arange(0, cos_embedding_dim)) - - def forward(self, tau: torch.Tensor) -> torch.Tensor: # tau: [B, T] - cos_emb_tau = torch.cos(tau.unsqueeze(-1) * self.integers * torch.pi) - return self.output_layer(cos_emb_tau) - - -class ImplicitQuantileModule(nn.Module): - r""" - Implicit Quantile Network from the paper ``IQN for Distributional - Reinforcement Learning`` (https://arxiv.org/abs/1806.06923) by - Dabney et al. 2018. - """ - - def __init__( - self, - in_features: int, - args_dim: Dict[str, int], - domain_map: Callable[..., Tuple[torch.Tensor]], - concentration1: float = 1.0, - concentration0: float = 1.0, - output_domain_map=None, - cos_embedding_dim: int = 64, - ): - super().__init__() - self.output_domain_map = output_domain_map - self.domain_map = domain_map - self.beta = Beta(concentration1=concentration1, concentration0=concentration0) - - self.quantile_layer = QuantileLayer( - in_features, cos_embedding_dim=cos_embedding_dim - ) - self.output_layer = nn.Sequential( - nn.Linear(in_features, in_features), nn.PReLU() - ) - - self.proj = nn.ModuleList( - [nn.Linear(in_features, dim) for dim in args_dim.values()] - ) - - def forward(self, inputs: torch.Tensor): - if self.training: - taus = self.beta.sample(sample_shape=inputs.shape[:-1]).to(inputs.device) - else: - taus = torch.rand(size=inputs.shape[:-1], device=inputs.device) - - emb_taus = self.quantile_layer(taus) - emb_inputs = inputs * (1.0 + emb_taus) - - emb_outputs = self.output_layer(emb_inputs) - outputs = [proj(emb_outputs).squeeze(-1) for proj in self.proj] - if self.output_domain_map is not None: - outputs = [self.output_domain_map(output) for output in outputs] - return (*self.domain_map(*outputs), taus) - - -class ImplicitQuantileNetwork(Distribution): - r""" - Distribution class for the Implicit Quantile from which - we can sample or calculate the quantile loss. - - Parameters - ---------- - outputs - Outputs from the Implicit Quantile Network. - taus - Tensor random numbers from the Beta or Uniform distribution for the - corresponding outputs. - """ - - arg_constraints: Dict[str, constraints.Constraint] = {} - - def __init__(self, outputs: torch.Tensor, taus: torch.Tensor, validate_args=None): - self.taus = taus - self.outputs = outputs - - super().__init__(batch_shape=outputs.shape, validate_args=validate_args) - - @torch.no_grad() - def sample(self, sample_shape=torch.Size()) -> torch.Tensor: - return self.outputs - - def quantile_loss(self, value: torch.Tensor) -> torch.Tensor: - # penalize by tau for under-predicting - # and by 1-tau for over-predicting - return (self.taus - (value < self.outputs).float()) * (value - self.outputs) - - -class ImplicitQuantileNetworkOutput(DistributionOutput): - r""" - DistributionOutput class for the IQN from the paper - ``Probabilistic Time Series Forecasting with Implicit Quantile Networks`` - (https://arxiv.org/abs/2107.03743) by Gouttes et al. 2021. - - Parameters - ---------- - output_domain - Optional domain mapping of the output. Can be "positive", "unit" - or None. - concentration1 - Alpha parameter of the Beta distribution when sampling the taus - during training. - concentration0 - Beta parameter of the Beta distribution when sampling the taus - during training. - cos_embedding_dim - The embedding dimension for the taus embedding layer of IQN. - Default is 64. - """ - - distr_cls = ImplicitQuantileNetwork - args_dim = {"quantile_function": 1} - - @validated() - def __init__( - self, - output_domain: Optional[str] = None, - concentration1: float = 1.0, - concentration0: float = 1.0, - cos_embedding_dim: int = 64, - ) -> None: - super().__init__() - - self.concentration1 = concentration1 - self.concentration0 = concentration0 - self.cos_embedding_dim = cos_embedding_dim - - if output_domain in ["positive", "unit"]: - output_domain_map_func = { - "positive": F.softplus, - "unit": partial(F.softmax, dim=-1), - } - self.output_domain_map = output_domain_map_func[output_domain] - else: - self.output_domain_map = None - - def get_args_proj(self, in_features: int) -> nn.Module: - return ImplicitQuantileModule( - in_features=in_features, - args_dim=self.args_dim, - output_domain_map=self.output_domain_map, - domain_map=LambdaLayer(self.domain_map), - concentration1=self.concentration1, - concentration0=self.concentration0, - cos_embedding_dim=self.cos_embedding_dim, - ) - - @classmethod - def domain_map(cls, *args): - return args - - def distribution(self, distr_args, loc=0, scale=None) -> ImplicitQuantileNetwork: - (outputs, taus) = distr_args - - if scale is not None: - outputs = outputs * scale - if loc is not None: - outputs = outputs + loc - return self.distr_cls(outputs=outputs, taus=taus) - - @property - def event_shape(self): - return () - - def loss( - self, - target: torch.Tensor, - distr_args: Tuple[torch.Tensor, ...], - loc: Optional[torch.Tensor] = None, - scale: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - distribution = self.distribution(distr_args, loc=loc, scale=scale) - return distribution.quantile_loss(target) - - -iqn = ImplicitQuantileNetworkOutput() diff --git a/lag_llama/gluon/__init__.py b/lag_llama/gluon/__init__.py index 801923c..040e23b 100644 --- a/lag_llama/gluon/__init__.py +++ b/lag_llama/gluon/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/lag_llama/gluon/estimator.py b/lag_llama/gluon/estimator.py index d0c5b52..f46f50f 100644 --- a/lag_llama/gluon/estimator.py +++ b/lag_llama/gluon/estimator.py @@ -14,7 +14,7 @@ from typing import Any, Dict, Iterable, Optional -import pytorch_lightning as pl +import lightning as L import torch from gluonts.core.component import validated @@ -27,10 +27,13 @@ get_lags_for_frequency, time_features_from_frequency_str, ) -from gluonts.torch.distributions import StudentTOutput, NegativeBinomialOutput +from gluonts.torch.distributions import ( + NegativeBinomialOutput, + StudentTOutput, + ImplicitQuantileNetworkOutput, +) from gluonts.torch.model.estimator import PyTorchLightningEstimator from gluonts.torch.model.predictor import PyTorchPredictor -from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood from gluonts.transform import ( AddObservedValuesIndicator, AddTimeFeatures, @@ -44,9 +47,6 @@ ValidationSplitSampler, ) -from gluon_utils.gluon_ts_distributions.implicit_quantile_network import ( - ImplicitQuantileNetworkOutput, -) from lag_llama.gluon.lightning_module import LagLlamaLightningModule PREDICTION_INPUT_NAMES = [ @@ -65,7 +65,7 @@ class LagLlamaEstimator(PyTorchLightningEstimator): This class is uses the model defined in ``ConvTSMixerModel``, and wraps it into a ``ConvTSMixerLightningModule`` for training - purposes: training is performed using PyTorch Lightning's ``pl.Trainer`` + purposes: training is performed using PyTorch Lightning's ``L.Trainer`` class. Parameters @@ -82,9 +82,6 @@ class LagLlamaEstimator(PyTorchLightningEstimator): distr_output Distribution to use to evaluate observations and sample predictions (default: StudentTOutput()). - loss - Loss to be optimized during training - (default: ``NegativeLogLikelihood()``). batch_norm Whether to apply batch normalization. batch_size @@ -93,7 +90,7 @@ class LagLlamaEstimator(PyTorchLightningEstimator): Number of batches to be processed in each training epoch (default: 50). trainer_kwargs - Additional arguments to provide to ``pl.Trainer`` for construction. + Additional arguments to provide to ``L.Trainer`` for construction. train_sampler Controls the sampling of windows during training. validation_sampler @@ -143,7 +140,6 @@ def __init__( window_warp_scales: list = [0.5, 2.0], # Continuning model arguments distr_output: str = "studentT", - loss: DistributionLoss = NegativeLogLikelihood(), num_parallel_samples: int = 100, batch_size: int = 32, num_batches_per_epoch: int = 50, @@ -152,7 +148,7 @@ def __init__( validation_sampler: Optional[InstanceSampler] = None, time_feat: bool = False, dropout: float = 0.0, - lags_seq: list = ["Q", "M", "W", "D", "H", "T", "S"], + lags_seq: list = ["QE", "ME", "W", "D", "h", "min", "s"], data_id_to_name_map: dict = {}, use_cosine_annealing_lr: bool = False, cosine_annealing_lr_args: dict = {}, @@ -160,7 +156,9 @@ def __init__( ckpt_path: Optional[str] = None, nonnegative_pred_samples: bool = False, use_single_pass_sampling: bool = False, - device: torch.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + device: torch.device = torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu"), ) -> None: default_trainer_kwargs = {"max_epochs": 100} if trainer_kwargs is not None: @@ -200,7 +198,6 @@ def __init__( distr_output = ImplicitQuantileNetworkOutput() self.distr_output = distr_output self.num_parallel_samples = num_parallel_samples - self.loss = loss self.batch_size = batch_size self.num_batches_per_epoch = num_batches_per_epoch self.nonnegative_pred_samples = nonnegative_pred_samples @@ -266,7 +263,7 @@ def create_transformation(self) -> Transformation: start_field=FieldName.START, target_field=FieldName.TARGET, output_field=FieldName.FEAT_TIME, - time_features=time_features_from_frequency_str("S"), + time_features=time_features_from_frequency_str("s"), pred_length=self.prediction_length, ), AddObservedValuesIndicator( @@ -287,7 +284,7 @@ def create_transformation(self) -> Transformation: ] ) - def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningModule: + def create_lightning_module(self, use_kv_cache: bool = False) -> L.LightningModule: model_kwargs = { "input_size": self.input_size, "context_length": self.context_length, @@ -308,7 +305,6 @@ def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningMod checkpoint_path=self.ckpt_path, map_location=self.device, strict=False, - loss=self.loss, lr=self.lr, weight_decay=self.weight_decay, context_length=self.context_length, @@ -346,7 +342,6 @@ def create_lightning_module(self, use_kv_cache: bool = False) -> pl.LightningMod ) else: return LagLlamaLightningModule( - loss=self.loss, lr=self.lr, weight_decay=self.weight_decay, context_length=self.context_length, diff --git a/lag_llama/gluon/lightning_module.py b/lag_llama/gluon/lightning_module.py index c7a9428..dc7b150 100644 --- a/lag_llama/gluon/lightning_module.py +++ b/lag_llama/gluon/lightning_module.py @@ -15,18 +15,8 @@ import random import numpy as np - -from lightning import LightningModule import torch import torch.nn.functional as F - -from gluonts.core.component import validated -from gluonts.itertools import prod -from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood -from gluonts.torch.util import repeat_along_dim, take_last - -from data.augmentations.freq_mask import freq_mask -from data.augmentations.freq_mix import freq_mix from data.augmentations.augmentations import ( ApplyAugmentations, Jitter, @@ -38,9 +28,15 @@ WindowSlice, WindowWarp, ) -from gluon_utils.gluon_ts_distributions.implicit_quantile_network import ( - ImplicitQuantileNetworkOutput, -) +from data.augmentations.freq_mask import freq_mask +from data.augmentations.freq_mix import freq_mix +from gluonts.core.component import validated +from gluonts.itertools import prod + +# from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood +from gluonts.torch.util import repeat_along_dim, take_last +from lightning import LightningModule + from lag_llama.model.module import LagLlamaModel @@ -71,7 +67,6 @@ def __init__( model_kwargs: dict, context_length: int, prediction_length: int, - loss: DistributionLoss = NegativeLogLikelihood(), lr: float = 1e-3, weight_decay: float = 1e-8, aug_prob: float = 0.1, @@ -103,13 +98,13 @@ def __init__( nonnegative_pred_samples: bool = False, use_kv_cache: bool = True, use_single_pass_sampling: bool = False, + **kwargs, ): super().__init__() self.save_hyperparameters() self.context_length = self.hparams.context_length self.prediction_length = self.hparams.prediction_length self.model = LagLlamaModel(**self.hparams.model_kwargs) - self.loss = self.hparams.loss self.lr = self.hparams.lr self.weight_decay = self.hparams.weight_decay self.aug_prob = self.hparams.aug_prob @@ -237,33 +232,34 @@ def forward(self, *args, **kwargs): params, loc, scale = self.model( *args, past_time_feat=past_time_feat if self.time_feat else None, - future_time_feat=future_time_feat[..., : t + 1, :] if self.time_feat else None, + future_time_feat=future_time_feat[..., : t + 1, :] + if self.time_feat + else None, past_target=past_target, past_observed_values=past_observed_values, use_kv_cache=self.use_kv_cache, ) - sliced_params = [ - p[:, -1:] for p in params - ] # Take the last timestep predicted. Each tensor is of shape (#bsz, 1) + # Take the last timestep predicted. Each tensor is of shape (#bsz, 1) + sliced_params = [p[:, -1:] for p in params] # Singular distribution is used for getting the greedy prediction (mean) distr = self.model.distr_output.distribution(sliced_params, loc, scale) - greedy_prediction = distr.mean # (#bsz, 1) + greedy_prediction = distr.mean # (#bsz, 1) + # Take the last timestep predicted and repeat for number of samples. Each tensor is of shape (#bsz*#parallel_samples, 1) repeated_sliced_params = [ - p[:, -1:].repeat_interleave( - self.model.num_parallel_samples, 0 - ) for p in params - ] # Take the last timestep predicted and repeat for number of samples. Each tensor is of shape (#bsz*#parallel_samples, 1) - repeated_loc = loc.repeat_interleave( - self.model.num_parallel_samples, 0 - ) + p[:, -1:].repeat_interleave(self.model.num_parallel_samples, 0) + for p in params + ] + repeated_loc = loc.repeat_interleave(self.model.num_parallel_samples, 0) repeated_scale = scale.repeat_interleave( - self.model.num_parallel_samples, 0 - ) + self.model.num_parallel_samples, 0 + ) # Repeated distribution is used for getting the parallel samples # (distr.sample([self.model.num_parallel_samples]) seems to give terrible results) - repeated_distr = self.model.distr_output.distribution(repeated_sliced_params, repeated_loc, repeated_scale) + repeated_distr = self.model.distr_output.distribution( + repeated_sliced_params, repeated_loc, repeated_scale + ) sample = repeated_distr.sample() # (#bsz*#parallel_samples, 1) if self.nonnegative_pred_samples: sample = F.relu(sample) @@ -275,11 +271,19 @@ def forward(self, *args, **kwargs): ) else: # Original probabilistic forecasting: Duplicate input, `num_parallel_samples` forward passes per step, sample each distribution once, add samples to context. - repeated_past_target = past_target.repeat_interleave(self.model.num_parallel_samples, 0) - repeated_past_observed_values = past_observed_values.repeat_interleave(self.model.num_parallel_samples, 0) + repeated_past_target = past_target.repeat_interleave( + self.model.num_parallel_samples, 0 + ) + repeated_past_observed_values = past_observed_values.repeat_interleave( + self.model.num_parallel_samples, 0 + ) if self.time_feat: - repeated_past_time_feat = past_time_feat.repeat_interleave(self.model.num_parallel_samples, 0) - repeated_future_time_feat = future_time_feat.repeat_interleave(self.model.num_parallel_samples, 0) + repeated_past_time_feat = past_time_feat.repeat_interleave( + self.model.num_parallel_samples, 0 + ) + repeated_future_time_feat = future_time_feat.repeat_interleave( + self.model.num_parallel_samples, 0 + ) for t in range(self.prediction_length): if self.time_feat: @@ -321,19 +325,23 @@ def forward(self, *args, **kwargs): + self.model.distr_output.event_shape, ) - # train - def _compute_loss(self, batch, do_not_average=False, return_observed_values=False): - past_target = batch[ - "past_target" - ] # (bsz, model.context_length+max(model.lags_seq)) - past_observed_values = batch[ - "past_observed_values" - ] # (bsz, model.context_length+max(model.lags_seq)) with 0s or 1s indicating available (1s) or missing (0s) - future_target = batch["future_target"] # (bsz, model.prediction_length) - future_observed_values = batch[ - "future_observed_values" - ] # (bsz, model.prediction_length) with 0s or 1s indicating available (1s) or missing (0s) + def _compute_loss( + self, + batch, + return_observed_values=False, + aggregate_by=torch.mean, + orignal_scale=True, + ): + # (bsz, model.context_length+max(model.lags_seq)) + past_target = batch["past_target"] + + # (bsz, model.context_length+max(model.lags_seq)) with 0s or 1s indicating available (1s) or missing (0s) + past_observed_values = batch["past_observed_values"] + # (bsz, model.prediction_length) + future_target = batch["future_target"] + # (bsz, model.prediction_length) with 0s or 1s indicating available (1s) or missing (0s) + future_observed_values = batch["future_observed_values"] if self.time_feat: past_time_feat = batch["past_time_feat"] future_time_feat = batch["future_time_feat"] @@ -345,64 +353,55 @@ def _compute_loss(self, batch, do_not_average=False, return_observed_values=Fals extra_shape = future_target.shape[:extra_dims] # shape remains the same repeats = prod(extra_shape) # usually 1 - past_target = repeat_along_dim( - past_target, 0, repeats - ) # (bsz, model.context_length+max(model.lags_seq)) - past_observed_values = repeat_along_dim( - past_observed_values, 0, repeats - ) # (bsz, model.context_length+max(model.lags_seq)) + # (bsz, model.context_length+max(model.lags_seq)) + past_target = repeat_along_dim(past_target, 0, repeats) + + # (bsz, model.context_length+max(model.lags_seq)) + past_observed_values = repeat_along_dim(past_observed_values, 0, repeats) + + # (bsz, model.prediction_length) future_target_reshaped = future_target.reshape( -1, *future_target.shape[extra_dims + 1 :], - ) # (bsz, model.prediction_length) + ) + # (bsz, model.prediction_length) future_observed_reshaped = future_observed_values.reshape( -1, *future_observed_values.shape[extra_dims + 1 :], - ) # (bsz, model.prediction_length) + ) + # distr_args is a tuple with two tensors of shape (bsz, context_length+pred_len-1) distr_args, loc, scale = self.model( past_target=past_target, past_observed_values=past_observed_values, past_time_feat=past_time_feat, future_time_feat=future_time_feat, future_target=future_target_reshaped, - ) # distr_args is a tuple with two tensors of shape (bsz, context_length+pred_len-1) - context_target = take_last( - past_target, dim=-1, num=self.context_length - 1 - ) # (bsz, context_length-1) # Basically removes the first value since it cannot be predicted - target = torch.cat( - (context_target, future_target_reshaped), - dim=1, - ) # (bsz, context_length-1+pred_len) # values that can be predicted + ) + # (bsz, context_length-1) # Basically removes the first value since it cannot be predicted + context_target = take_last(past_target, dim=-1, num=self.context_length - 1) + # (bsz, context_length-1+pred_len) # values that can be predicted + target = torch.cat((context_target, future_target_reshaped), dim=1) + # same as context_target, but for observed_values tensor context_observed = take_last( past_observed_values, dim=-1, num=self.context_length - 1 - ) # same as context_target, but for observed_values tensor - observed_values = torch.cat( - (context_observed, future_observed_reshaped), dim=1 - ) # same as target but for observed_values tensor - - if type(self.model.distr_output) == ImplicitQuantileNetworkOutput: - if not do_not_average: - loss = ( - self.model.distr_output.loss(target, distr_args, loc, scale) - * observed_values - ).sum() / observed_values.sum().clamp_min(1.0) - else: - loss = ( - self.model.distr_output.loss(target, distr_args, loc, scale) - * observed_values - ) + ) + # same as target but for observed_values tensor + observed_values = torch.cat((context_observed, future_observed_reshaped), dim=1) + + if orignal_scale: + loss_values = self.model.distr_output.loss(target, distr_args, loc, scale) else: - distr = self.model.distr_output.distribution( - distr_args, loc=loc, scale=scale - ) # an object representing a distribution with the specified parameters. We need this to compute the NLL loss. - if not do_not_average: - loss = ( - self.loss(distr, target) * observed_values - ).sum() / observed_values.sum().clamp_min(1.0) - else: - loss = self.loss(distr, target) * observed_values + loss_values = self.model.distr_output.loss( + (target - loc) / scale, distr_args + ) + loss_values = loss_values * observed_values.clamp_min(1) + + loss = aggregate_by( + loss_values, + dim=tuple(range(extra_dims + 1, len(future_target.shape))), + ) if not return_observed_values: return loss @@ -433,17 +432,16 @@ def training_step(self, batch, batch_idx: int): # type: ignore batch["past_target"], batch["future_target"] ) - train_loss_per_sample, observed_values = self._compute_loss( - batch, do_not_average=True, return_observed_values=True - ) + train_loss_avg = self._compute_loss(batch, return_observed_values=False) - train_loss_avg = train_loss_per_sample.sum() / observed_values.sum().clamp_min( - 1.0 - ) self.log( - "train_loss", train_loss_avg, on_epoch=True, on_step=False, prog_bar=False + "train_loss", + train_loss_avg.mean(), + on_epoch=True, + on_step=False, + prog_bar=False, ) - return train_loss_avg + return train_loss_avg.mean() def on_train_epoch_end(self): # Log all losses @@ -477,13 +475,18 @@ def validation_step(self, batch, batch_idx: int): # type: ignore """ Execute validation step. """ - val_loss_per_sample, observed_values = self._compute_loss( - batch, do_not_average=True, return_observed_values=True + val_loss_avg = self._compute_loss( + batch, return_observed_values=False, orignal_scale=False ) - val_loss_avg = val_loss_per_sample.sum() / observed_values.sum().clamp_min(1.0) - self.log("val_loss", val_loss_avg, on_epoch=True, on_step=False, prog_bar=False) - return val_loss_avg + self.log( + "val_loss", + val_loss_avg.mean(), + on_epoch=True, + on_step=False, + prog_bar=False, + ) + return val_loss_avg.mean() def on_validation_epoch_end(self): # Log all losses diff --git a/lag_llama/model/__init__.py b/lag_llama/model/__init__.py index 801923c..040e23b 100644 --- a/lag_llama/model/__init__.py +++ b/lag_llama/model/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/lag_llama/model/module.py b/lag_llama/model/module.py index d1ec02c..35dee34 100644 --- a/lag_llama/model/module.py +++ b/lag_llama/model/module.py @@ -20,12 +20,11 @@ from torch import nn from torch.nn import functional as F +from gluon_utils.scalers.robust_scaler import RobustScaler from gluonts.torch.distributions import DistributionOutput from gluonts.torch.scaler import MeanScaler, NOPScaler, StdScaler from gluonts.torch.util import lagged_sequence_values, unsqueeze_expand -from gluon_utils.scalers.robust_scaler import RobustScaler - @dataclass class LTSMConfig: @@ -319,11 +318,15 @@ def forward(self, x: torch.Tensor, use_kv_cache: bool) -> torch.Tensor: # When kv_cache is in use and we're working with only the last token (T = 1 instead of full sequence length `true_seq_len``) # Use the full sequence length for positional embeddings (true_seq_len) # q is the query vector for the last token, so it's position is the last index (-1) - cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=true_seq_len) + cos, sin = self.rotary_emb( + device=v.device, dtype=v.dtype, seq_len=true_seq_len + ) q, _ = apply_rotary_pos_emb(q, k, cos, sin, position_ids=[-1]) - + # k is the key matrix after concatenation with cache, so no position_ids - cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=true_seq_len) + cos, sin = self.rotary_emb( + device=v.device, dtype=v.dtype, seq_len=true_seq_len + ) _, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids=None) else: cos, sin = self.rotary_emb(device=v.device, dtype=v.dtype, seq_len=T) @@ -495,6 +498,7 @@ def prepare_input( # In the below code, instead of max(self.lags_seq), it was previously -self.context_length if future_target is not None: + # Shape is (bsz, context_length+(pred_len-1)) input = torch.cat( ( scaled_past_target[..., max(self.lags_seq) :], # Just the context @@ -502,7 +506,7 @@ def prepare_input( / scale, # Not sure about the -1 here. Maybe so since the last value isn't used in the model for prediction of any new values. also if the prediction length is 1, this doesn't really affect anything ), dim=-1, - ) # Shape is (bsz, context_length+(pred_len-1)) + ) else: input = scaled_past_target[..., max(self.lags_seq) :] if (past_time_feat is not None) and (future_time_feat is not None): @@ -518,20 +522,18 @@ def prepare_input( else past_time_feat[..., max(self.lags_seq) :, :] ) - prior_input = ( - past_target[..., : max(self.lags_seq)] - loc - ) / scale # This the history used to construct lags. # bsz, max(self.lags_seq) + # This the history used to construct lags. # bsz, max(self.lags_seq) + prior_input = (past_target[..., : max(self.lags_seq)] - loc) / scale - lags = lagged_sequence_values( - self.lags_seq, prior_input, input, dim=-1 - ) # Lags are added as an extra dim. Shape is (bsz, context_length+(pred_len-1), len(self.lags_seq)) + # Lags are added as an extra dim. Shape is (bsz, context_length+(pred_len-1), len(self.lags_seq)) + lags = lagged_sequence_values(self.lags_seq, prior_input, input, dim=-1) - static_feat = torch.cat( - (loc.abs().log1p(), scale.log()), dim=-1 - ) # (bsz, 2) (loc and scale are concatenated) + # (bsz, 2) (loc and scale are concatenated) + static_feat = torch.cat((loc.abs().log1p(), scale.log()), dim=-1) + # (bsz, context_length+(pred_len-1), 2) expanded_static_feat = unsqueeze_expand( static_feat, dim=-2, size=lags.shape[-2] - ) # (bsz, context_length+(pred_len-1), 2) + ) # expanded_static_feat: (bsz, context_length+(pred_len-1), len(self.lags_seq) + 2); (bsz, 1); (bsz, 1) if past_time_feat is not None: @@ -566,20 +568,18 @@ def forward( transformer_input = transformer_input[:, -1:] # forward the LLaMA model itself - x = self.transformer.wte( - transformer_input - ) # token embeddings of shape (b, t, n_embd_per_head*n_head) # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head) + # token embeddings of shape (b, t, n_embd_per_head*n_head) # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head) + x = self.transformer.wte(transformer_input) for block in self.transformer.h: x = block(x, use_kv_cache) - x = self.transformer.ln_f( - x - ) # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head) + # (bsz, context_length+(pred_len-1), n_embd_per_head*n_head) + x = self.transformer.ln_f(x) if use_kv_cache: self.y_cache = True - params = self.param_proj( - x - ) # (bsz, context_length+(pred_len-1)) ; (bsz, context_length+(pred_len-1)) + + # (bsz, context_length+(pred_len-1)) ; (bsz, context_length+(pred_len-1)) + params = self.param_proj(x) return params, loc, scale def reset_cache(self) -> None: diff --git a/requirements.txt b/requirements.txt index f43dcaf..48ef659 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -gluonts[torch]<=0.14.4 +gluonts[torch] numpy>=1.23.5 torch>=2.0.0 wandb diff --git a/run.py b/run.py index 4a24eec..e71f261 100644 --- a/run.py +++ b/run.py @@ -581,7 +581,7 @@ def train(args): parser.add_argument("--n_embd_per_head", type=int, default=64) parser.add_argument("--n_head", type=int, default=4) parser.add_argument("--dim_feedforward", type=int, default=256) - parser.add_argument("--lags_seq", type=str, nargs="+", default=["Q", "M", "W", "D", "H", "T", "S"]) + parser.add_argument("--lags_seq", type=str, nargs="+", default=["QE", "ME", "W", "D", "h", "min", "s"]) # Data normalization parser.add_argument(