update tests for masking support
wgifford committed Jan 22, 2025
1 parent d657fd2 commit 70112cf
Showing 1 changed file with 64 additions and 16 deletions.
tests/models/tinytimemixer/test_modeling_tinytimemixer.py: 80 changes (64 additions, 16 deletions)
@@ -16,6 +16,7 @@
 
 from tsfm_public.models.tinytimemixer import (
     TinyTimeMixerConfig,
+    TinyTimeMixerForMaskedPrediction,
     TinyTimeMixerForPrediction,
     TinyTimeMixerModel,
 )
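The added import pulls in the masked-prediction head that the rest of this diff exercises. A minimal sketch of how a test can build either head from the same TinyTimeMixerConfig (the config values below are placeholders, not the suite's real cls.params):

    from tsfm_public.models.tinytimemixer import (
        TinyTimeMixerConfig,
        TinyTimeMixerForMaskedPrediction,
        TinyTimeMixerForPrediction,
    )

    # Placeholder hyperparameters; the real ones live in cls.params.
    config = TinyTimeMixerConfig(
        context_length=64,
        prediction_length=16,
        patch_length=8,
        patch_stride=8,
        num_input_channels=3,
    )

    mask_prediction = True  # the new flag the tests sweep below
    model = (
        TinyTimeMixerForMaskedPrediction(config)
        if mask_prediction
        else TinyTimeMixerForPrediction(config)
    )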
@@ -83,6 +84,14 @@ def setUpClass(cls):
             max(cls.params["context_length"], cls.params["patch_length"]) - cls.params["patch_length"]
         ) // cls.params["patch_stride"] + 1
 
+        cls.num_masked_patches = (
+            max(
+                cls.params["context_length"] + cls.params["prediction_length"],
+                cls.params["patch_length"],
+            )
+            - cls.params["patch_length"]
+        ) // cls.params["patch_stride"] + 1
+
         # batch_size = 32
         batch_size = 2
         cls.batch_size = batch_size
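For reference, num_masked_patches uses the same patching arithmetic as num_patches, only computed over the context window plus the forecast horizon. A standalone sketch with placeholder numbers (the suite's real values sit in cls.params, outside this hunk):

    context_length, prediction_length = 64, 16  # placeholders
    patch_length, patch_stride = 8, 8

    # Patches covering the context window only.
    num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
    # -> (64 - 8) // 8 + 1 = 8

    # Patches covering context plus forecast horizon, used as the expected
    # hidden-state length when mask_prediction is on.
    num_masked_patches = (
        max(context_length + prediction_length, patch_length) - patch_length
    ) // patch_stride + 1
    # -> (80 - 8) // 8 + 1 = 10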
@@ -116,13 +125,27 @@ def setUpClass(cls):
             cls.params["d_model"],
         )
 
+        cls.enc_masked_output = torch.rand(
+            batch_size,
+            cls.params["num_input_channels"],
+            cls.num_masked_patches,
+            cls.params["d_model"],
+        )
+
         cls.dec_output = torch.rand(
             batch_size,
             cls.params["num_input_channels"],
             cls.num_patches,
             cls.params["decoder_d_model"],
         )
 
+        cls.dec_masked_output = torch.rand(
+            batch_size,
+            cls.params["num_input_channels"],
+            cls.num_masked_patches,
+            cls.params["decoder_d_model"],
+        )
+
         cls.flat_enc_output = torch.rand(
             batch_size,
             cls.num_patches,
@@ -187,6 +210,7 @@ def check_module(
         task,
         params=None,
         output_hidden_states=True,
+        mask_prediction=False,
         future_observed_mask=None,  # None, int, bool
         past_observed_mask=None,  # None, int, bool
         input_data=None,
@@ -195,8 +219,10 @@
             input_data = self.__class__.data
         config = TinyTimeMixerConfig(**params)
         if task == "forecast":
-            mdl = TinyTimeMixerForPrediction(config)
-
+            if mask_prediction is False:
+                mdl = TinyTimeMixerForPrediction(config)
+            else:
+                mdl = TinyTimeMixerForMaskedPrediction(config)
             if (
                 "target_channel_filtered" in params
                 and params["target_channel_filtered"]
@@ -236,10 +262,16 @@ def check_module(
         else:
             past_observed_mask = None
 
-        enc_output = self.__class__.enc_output
+        if mask_prediction is False:
+            enc_output = self.__class__.enc_output
+        else:
+            enc_output = self.__class__.enc_masked_output
 
         if config.use_decoder:
-            dec_output = self.__class__.dec_output
+            if mask_prediction is False:
+                dec_output = self.__class__.dec_output
+            else:
+                dec_output = self.__class__.dec_masked_output
         else:
             dec_output = enc_output
 
@@ -294,8 +326,9 @@ def check_module(
             enc_output_shape[-2] += 1
             dec_output_shape[-2] += 1
 
-        self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape)
-        self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape)
+        if mask_prediction is False:
+            self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape)
+            self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape)
 
         # self.assertEqual(output.backbone_hidden_state.shape, enc_output.shape)
         # self.assertEqual(output.decoder_hidden_state.shape, dec_output.shape)
@@ -312,9 +345,10 @@ def check_module(
                 [True, False, "mean", "std"],
                 [True, False],
                 [None, [0, 2]],
-                ["mse", "mae", "pinball", "huber", None],
+                ["mse", "nll", "mae", "pinball", "huber", None],
                 [8, 16],
                 [True, False],
+                [True, False],
             )
         )
     )
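This decorator expands test_forecast over the Cartesian product of the option lists; the added "nll" loss and the extra [True, False] (presumably feeding the new mask_prediction argument) enlarge that grid. A quick count of the combinations in the updated decorator:

    import itertools

    grid = list(
        itertools.product(
            [True, False, "mean", "std"],
            [True, False],
            [None, [0, 2]],
            ["mse", "nll", "mae", "pinball", "huber", None],
            [8, 16],
            [True, False],
            [True, False],
        )
    )
    print(len(grid))  # 4 * 2 * 2 * 6 * 2 * 2 * 2 = 768 parameterized cases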
@@ -328,6 +362,7 @@ def test_forecast(
         loss,
         prediction_filter_length,
         target_pred_length_filtered,
+        mask_prediction,
     ):
         params = self.__class__.params.copy()
         params.update(
@@ -341,14 +376,14 @@
             target_pred_length_filtered=target_pred_length_filtered,
             target_channel_filtered=False,
         )
-
-        self.check_module(task="forecast", params=params)
+        self.check_module(task="forecast", params=params, mask_prediction=mask_prediction)
 
     def test_var0_mask(self):
         params = self.__class__.params.copy()
         self.check_module(
             task="forecast",
             params=params,
+            mask_prediction=False,
             future_observed_mask="bool",
             past_observed_mask="bool",
             input_data=self.__class__.constant_data,
@@ -358,9 +393,10 @@ def test_var0_mask(self):
         list(
             itertools.product(
                 [None, [0, 2]],
-                ["mse", "mae", None],
+                ["mse", "nll", "mae", None],
                 [8, 16],
                 [True, False],
+                [True, False],
                 [None, "int", "bool"],
                 [None, "int", "bool"],
             )
@@ -372,6 +408,7 @@ def test_observed_mask(
         loss,
         prediction_filter_length,
         target_pred_length_filtered,
+        mask_prediction,
         past_observed_mask,
         future_observed_mask,
     ):
@@ -387,6 +424,7 @@
         self.check_module(
             task="forecast",
             params=params,
+            mask_prediction=mask_prediction,
             future_observed_mask=future_observed_mask,
             past_observed_mask=past_observed_mask,
         )
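test_observed_mask now also sweeps mask_prediction while exercising the None / "int" / "bool" variants of the observed masks. As a rough illustration of what such masks can look like, here is an assumed construction (the exact shape convention and the 1/True = observed semantics are assumptions, not taken from this diff):

    import torch

    batch_size, context_length, num_input_channels = 2, 64, 3  # placeholder sizes

    # Boolean variant: True where a value was actually observed.
    past_observed_mask_bool = torch.ones(
        batch_size, context_length, num_input_channels, dtype=torch.bool
    )
    past_observed_mask_bool[:, -2:, 0] = False  # pretend the last two steps of channel 0 are missing

    # Integer variant: the same information encoded as 0/1.
    past_observed_mask_int = past_observed_mask_bool.to(torch.int64)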
@@ -489,10 +527,14 @@ def forecast_full_module(
         params=None,
         output_hidden_states=False,
         return_dict=None,
+        mask_prediction=False,
     ):
         config = TinyTimeMixerConfig(**params)
 
-        mdl = TinyTimeMixerForPrediction(config)
+        if mask_prediction is False:
+            mdl = TinyTimeMixerForPrediction(config)
+        else:
+            mdl = TinyTimeMixerForMaskedPrediction(config)
 
         target_val = self.__class__.correct_forecast_output
 
@@ -501,11 +543,16 @@
         if config.prediction_channel_indices is not None:
             target_val = self.__class__.correct_sel_forecast_output
 
-        enc_output = self.__class__.enc_output
+        if mask_prediction:
+            enc_output = self.__class__.enc_masked_output
+        else:
+            enc_output = self.__class__.enc_output
 
         if config.use_decoder:
-            dec_output = self.__class__.dec_output
-
+            if mask_prediction:
+                dec_output = self.__class__.dec_masked_output
+            else:
+                dec_output = self.__class__.dec_output
         else:
             dec_output = enc_output
 
@@ -548,8 +595,9 @@ def forecast_full_module(
             enc_output_shape[-2] += 1
             dec_output_shape[-2] += 1
 
-        self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape)
-        self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape)
+        if mask_prediction is False:
+            self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape)
+            self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape)
 
         # if output_hidden_states is True:
         #     print("ooo", len(output.hidden_states))