diff --git a/tests/models/tinytimemixer/test_modeling_tinytimemixer.py b/tests/models/tinytimemixer/test_modeling_tinytimemixer.py index 96bf7a49..2b418a1a 100644 --- a/tests/models/tinytimemixer/test_modeling_tinytimemixer.py +++ b/tests/models/tinytimemixer/test_modeling_tinytimemixer.py @@ -16,6 +16,7 @@ from tsfm_public.models.tinytimemixer import ( TinyTimeMixerConfig, + TinyTimeMixerForMaskedPrediction, TinyTimeMixerForPrediction, TinyTimeMixerModel, ) @@ -83,6 +84,14 @@ def setUpClass(cls): max(cls.params["context_length"], cls.params["patch_length"]) - cls.params["patch_length"] ) // cls.params["patch_stride"] + 1 + cls.num_masked_patches = ( + max( + cls.params["context_length"] + cls.params["prediction_length"], + cls.params["patch_length"], + ) + - cls.params["patch_length"] + ) // cls.params["patch_stride"] + 1 + # batch_size = 32 batch_size = 2 cls.batch_size = batch_size @@ -116,6 +125,13 @@ def setUpClass(cls): cls.params["d_model"], ) + cls.enc_masked_output = torch.rand( + batch_size, + cls.params["num_input_channels"], + cls.num_masked_patches, + cls.params["d_model"], + ) + cls.dec_output = torch.rand( batch_size, cls.params["num_input_channels"], @@ -123,6 +139,13 @@ def setUpClass(cls): cls.params["decoder_d_model"], ) + cls.dec_masked_output = torch.rand( + batch_size, + cls.params["num_input_channels"], + cls.num_masked_patches, + cls.params["decoder_d_model"], + ) + cls.flat_enc_output = torch.rand( batch_size, cls.num_patches, @@ -187,6 +210,7 @@ def check_module( task, params=None, output_hidden_states=True, + mask_prediction=False, future_observed_mask=None, # None, int, bool past_observed_mask=None, # None, int, bool input_data=None, @@ -195,8 +219,10 @@ def check_module( input_data = self.__class__.data config = TinyTimeMixerConfig(**params) if task == "forecast": - mdl = TinyTimeMixerForPrediction(config) - + if mask_prediction is False: + mdl = TinyTimeMixerForPrediction(config) + else: + mdl = TinyTimeMixerForMaskedPrediction(config) if ( "target_channel_filtered" in params and params["target_channel_filtered"] @@ -236,10 +262,16 @@ def check_module( else: past_observed_mask = None - enc_output = self.__class__.enc_output + if mask_prediction is False: + enc_output = self.__class__.enc_output + else: + enc_output = self.__class__.enc_masked_output if config.use_decoder: - dec_output = self.__class__.dec_output + if mask_prediction is False: + dec_output = self.__class__.dec_output + else: + dec_output = self.__class__.dec_masked_output else: dec_output = enc_output @@ -294,8 +326,9 @@ def check_module( enc_output_shape[-2] += 1 dec_output_shape[-2] += 1 - self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape) - self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape) + if mask_prediction is False: + self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape) + self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape) # self.assertEqual(output.backbone_hidden_state.shape, enc_output.shape) # self.assertEqual(output.decoder_hidden_state.shape, dec_output.shape) @@ -312,9 +345,10 @@ def check_module( [True, False, "mean", "std"], [True, False], [None, [0, 2]], - ["mse", "mae", "pinball", "huber", None], + ["mse", "nll", "mae", "pinball", "huber", None], [8, 16], [True, False], + [True, False], ) ) ) @@ -328,6 +362,7 @@ def test_forecast( loss, prediction_filter_length, target_pred_length_filtered, + mask_prediction, ): params = self.__class__.params.copy() params.update( @@ -341,14 +376,14 @@ def test_forecast( target_pred_length_filtered=target_pred_length_filtered, target_channel_filtered=False, ) - - self.check_module(task="forecast", params=params) + self.check_module(task="forecast", params=params, mask_prediction=mask_prediction) def test_var0_mask(self): params = self.__class__.params.copy() self.check_module( task="forecast", params=params, + mask_prediction=False, future_observed_mask="bool", past_observed_mask="bool", input_data=self.__class__.constant_data, @@ -358,9 +393,10 @@ def test_var0_mask(self): list( itertools.product( [None, [0, 2]], - ["mse", "mae", None], + ["mse", "nll", "mae", None], [8, 16], [True, False], + [True, False], [None, "int", "bool"], [None, "int", "bool"], ) @@ -372,6 +408,7 @@ def test_observed_mask( loss, prediction_filter_length, target_pred_length_filtered, + mask_prediction, past_observed_mask, future_observed_mask, ): @@ -387,6 +424,7 @@ def test_observed_mask( self.check_module( task="forecast", params=params, + mask_prediction=mask_prediction, future_observed_mask=future_observed_mask, past_observed_mask=past_observed_mask, ) @@ -489,10 +527,14 @@ def forecast_full_module( params=None, output_hidden_states=False, return_dict=None, + mask_prediction=False, ): config = TinyTimeMixerConfig(**params) - mdl = TinyTimeMixerForPrediction(config) + if mask_prediction is False: + mdl = TinyTimeMixerForPrediction(config) + else: + mdl = TinyTimeMixerForMaskedPrediction(config) target_val = self.__class__.correct_forecast_output @@ -501,11 +543,16 @@ def forecast_full_module( if config.prediction_channel_indices is not None: target_val = self.__class__.correct_sel_forecast_output - enc_output = self.__class__.enc_output + if mask_prediction: + enc_output = self.__class__.enc_masked_output + else: + enc_output = self.__class__.enc_output if config.use_decoder: - dec_output = self.__class__.dec_output - + if mask_prediction: + dec_output = self.__class__.dec_masked_output + else: + dec_output = self.__class__.dec_output else: dec_output = enc_output @@ -548,8 +595,9 @@ def forecast_full_module( enc_output_shape[-2] += 1 dec_output_shape[-2] += 1 - self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape) - self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape) + if mask_prediction is False: + self.assertEqual(list(output.backbone_hidden_state.shape), enc_output_shape) + self.assertEqual(list(output.decoder_hidden_state.shape), dec_output_shape) # if output_hidden_states is True: # print("ooo", len(output.hidden_states))