Skip to content

Commit

Permalink
minor updates, WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Oct 9, 2024
1 parent 95fb570 commit f543d40
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions tsfm_public/toolkit/recursive_predictor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# Copyright contributors to the TSFM project
#
"""Recursive prediction model wrapper"""

import math
from dataclasses import dataclass
from typing import Optional
Expand Down Expand Up @@ -80,13 +84,18 @@ def forward(
self,
past_values: torch.Tensor,
future_values: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
past_observed_mask: Optional[torch.Tensor] = None,
future_observed_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
freq_token: Optional[torch.Tensor] = None,
**kwargs,
static_categorical_values: Optional[torch.Tensor] = None,
) -> RecursivePredictorOutput:
"""
Predict future points given an input sequence.
Predict future points given an input sequence, using a recursive strategy.
add requirements of model, especially with respect to exogenous
Args:
past_values (torch.Tensor): Input sequence of shape (batch_size, sequence_length, num_channels).
Expand All @@ -98,13 +107,12 @@ def forward(
return_dict = return_dict if return_dict is not None else self.use_return_dict

total_runs = math.ceil(self.requested_forecast_length / self.model_forecast_length)
# device = past_values.device
device = past_values.device # Get device of input_sequence

# self.model.to(device) # Move model to the same device as input_sequence
past_values = past_values.to(device)
# this should be handled by Trainer
# device = self.model.device # Get device of model
# past_values = past_values.to(device)

# device = next(self.model.parameters()).device
# double check need for no_grad()
# with torch.no_grad():
sequence_length = past_values.size(1)
predicted_sequence = past_values.clone() # Initialize predicted sequence with input sequence
Expand Down

0 comments on commit f543d40

Please sign in to comment.