From f543d40460b6b4d78518c325db91740de65007cf Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:32:30 -0400 Subject: [PATCH] minor updates, WIP --- tsfm_public/toolkit/recursive_predictor.py | 24 ++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tsfm_public/toolkit/recursive_predictor.py b/tsfm_public/toolkit/recursive_predictor.py index 91cc384c..e367bcdc 100644 --- a/tsfm_public/toolkit/recursive_predictor.py +++ b/tsfm_public/toolkit/recursive_predictor.py @@ -1,3 +1,7 @@ +# Copyright contributors to the TSFM project +# +"""Recursive prediction model wrapper""" + import math from dataclasses import dataclass from typing import Optional @@ -80,13 +84,18 @@ def forward( self, past_values: torch.Tensor, future_values: Optional[torch.Tensor] = None, - return_dict: Optional[bool] = None, + past_observed_mask: Optional[torch.Tensor] = None, + future_observed_mask: Optional[torch.Tensor] = None, + output_hidden_states: Optional[bool] = False, return_loss: bool = True, + return_dict: Optional[bool] = None, freq_token: Optional[torch.Tensor] = None, - **kwargs, + static_categorical_values: Optional[torch.Tensor] = None, ) -> RecursivePredictorOutput: """ - Predict future points given an input sequence. + Predict future points given an input sequence, using a recursive strategy. + + TODO: document the requirements placed on the wrapped model, especially with respect to exogenous variables. Args: past_values (torch.Tensor): Input sequence of shape (batch_size, sequence_length, num_channels). 
@@ -98,13 +107,12 @@ def forward( return_dict = return_dict if return_dict is not None else self.use_return_dict total_runs = math.ceil(self.requested_forecast_length / self.model_forecast_length) - # device = past_values.device - device = past_values.device # Get device of input_sequence - # self.model.to(device) # Move model to the same device as input_sequence - past_values = past_values.to(device) + # NOTE: device placement of model and inputs should be handled by the Trainer + # device = self.model.device # Get device of model + # past_values = past_values.to(device) - # device = next(self.model.parameters()).device + # TODO: double-check whether torch.no_grad() is needed here # with torch.no_grad(): sequence_length = past_values.size(1) predicted_sequence = past_values.clone() # Initialize predicted sequence with input sequence