How to load a channel_mix TTM model #228

Open
LightingMc opened this issue Dec 23, 2024 · 5 comments

@LightingMc

I saw that the currently available pre-trained models don't have a channel mixer in there. How do I introduce a channel_mixer into them?

@LightingMc
Author

By that I mean the inter-channel mixer.

[image: TTM architecture diagram highlighting the inter-channel mixer block]
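(For context, a minimal sketch of what such an inter-channel mixer computes, assuming an input of shape (batch, channels, patches, features); this is illustrative, not the library's implementation:)

```python
import torch
import torch.nn as nn

class ChannelMixerSketch(nn.Module):
    """Illustrative only: an MLP that mixes information across channels."""

    def __init__(self, num_channels: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_channels, num_channels),
            nn.GELU(),
            nn.Linear(num_channels, num_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, patches, features); mix along the channel dim
        x = x.permute(0, 3, 2, 1)     # move channels to the last dimension
        x = self.mlp(x)
        return x.permute(0, 3, 2, 1)  # restore the original layout
```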

For example, I loaded:

```python
import torch
from tsfm_public.toolkit.get_model import get_model

TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
CONTEXT_LENGTH = 1024
PREDICTION_LENGTH = 720

model = get_model(
    TTM_MODEL_PATH,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    num_input_channels=1000,
).backbone

for name, _ in model.named_parameters():
    print(name)
```
The result only contained `feature_mixer` and `patch_mixer` parameters, no channel mixer. I tried this with all of the backbones; all gave the same result.

Result:

```text
encoder.patcher.weight
encoder.patcher.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.patch_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.feature_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.patch_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.feature_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.patch_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.feature_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.patch_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.feature_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.patch_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.feature_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.patch_mixer.gating_block.attn_layer.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.norm.norm.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.norm.norm.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.mlp.fc1.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.mlp.fc1.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.mlp.fc2.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.mlp.fc2.bias
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.gating_block.attn_layer.weight
encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.feature_mixer.gating_block.attn_layer.bias
```

@vg11072001

vg11072001 commented Dec 23, 2024

@LightingMc You need to set `mode='mix_channel'`; by default the config sets `mode` to `'common_channel'`.

```python
mode = "mix_channel"

model = get_model(
    TTM_MODEL_PATH,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    num_input_channels=1000,
    mode=mode,  # pass "mix_channel" to enable the inter-channel mixer
).backbone
```
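You can verify that the inter-channel block is now present by listing parameter names again, as in the original post; a quick check might look like this (the name filter matches the channel_feature_mixer entries that appear in mix_channel mode):

```python
# Expect parameter names containing "channel_feature_mixer" in addition
# to the patch_mixer and feature_mixer entries seen before.
channel_params = [n for n, _ in model.named_parameters() if "channel_feature_mixer" in n]
print(f"channel mixer parameters found: {len(channel_params)}")
```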

I hope this solves your issue!

@LightingMc
Author

Thanks a lot. It worked. This is exactly what I was looking for.

I got the warning below, which I was expecting. So I understand that in a multivariate case, the user has to train these layers for their given number of channels. But will these new channel_feature_mixer layers be easier to train from scratch, since the rest of the model is already trained? I mean, will the dataset size and training time requirements be the same as training an entire TTM model from scratch, or will they be less?

```text
Some weights of TinyTimeMixerForPrediction were not initialized from the model checkpoint at ibm-granite/granite-timeseries-ttm-r2 and are newly initialized: ['backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.gating_block.attn_layer.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.gating_block.attn_layer.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.mlp.fc1.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.mlp.fc1.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.mlp.fc2.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.mlp.fc2.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.norm.norm.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.0.channel_feature_mixer.norm.norm.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.gating_block.attn_layer.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.gating_block.attn_layer.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.mlp.fc1.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.mlp.fc1.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.mlp.fc2.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.mlp.fc2.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.norm.norm.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.0.mixer_layers.1.channel_feature_mixer.norm.norm.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.gating_block.attn_layer.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.gating_block.attn_layer.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.mlp.fc1.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.mlp.fc1.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.mlp.fc2.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.mlp.fc2.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.norm.norm.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.0.channel_feature_mixer.norm.norm.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.gating_block.attn_layer.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.gating_block.attn_layer.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.mlp.fc1.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.mlp.fc1.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.mlp.fc2.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.mlp.fc2.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.norm.norm.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.1.mixer_layers.1.channel_feature_mixer.norm.norm.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.gating_block.attn_layer.bias', 
'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.gating_block.attn_layer.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.mlp.fc1.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.mlp.fc1.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.mlp.fc2.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.mlp.fc2.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.norm.norm.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.0.channel_feature_mixer.norm.norm.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.gating_block.attn_layer.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.gating_block.attn_layer.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.mlp.fc1.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.mlp.fc1.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.mlp.fc2.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.mlp.fc2.weight', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.norm.norm.bias', 'backbone.encoder.mlp_mixer_encoder.mixers.2.mixer_layers.1.channel_feature_mixer.norm.norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

```

@vg11072001

vg11072001 commented Dec 23, 2024

Yeah, well, according to their paper:

  • One of their goals is for the model to be fast to pretrain as well. So if you have 3 to 5 NVIDIA A100 GPUs and a dataset on the order of a billion samples (a good amount), then, given the config's parameter counts, you can pretrain from scratch relatively quickly.
  • They also mention that if you have a small dataset, you can just fine-tune when dealing with exogenous variables and still get good results, thanks to the channel_feature_mixer; see the sketch below.
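A minimal sketch of such a fine-tuning setup: freeze the pre-trained weights and train only the newly initialized channel mixer layers (the name filter comes from the warning above; the optimizer choice and learning rate are illustrative assumptions):

```python
import torch

# Train only the newly initialized channel_feature_mixer sub-modules;
# every pre-trained parameter stays frozen.
for name, param in model.named_parameters():
    param.requires_grad = "channel_feature_mixer" in name

trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4)  # illustrative values
```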

@wgifford
Collaborator

wgifford commented Jan 6, 2025

A detailed example of introducing channel mixing in the decoder can be found in the notebook linked below. In this example, we still leverage the pre-trained model for the encoder (backbone) but then apply mixing in the decoder (via the decoder_mode="mix_channel" option when loading the model).

https://github.com/ibm-granite-community/granite-timeseries-cookbook/blob/main/recipes/Time_Series/Bike_Sharing_Finetuning_with_Exogenous.ipynb
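For reference, a minimal sketch of loading a model with decoder-side mixing only, assuming the same `get_model` interface as in the snippets above (the notebook covers the full exogenous-variable setup):

```python
from tsfm_public.toolkit.get_model import get_model

# The backbone keeps its default per-channel mode, so its pre-trained
# weights load unchanged; only the decoder mixes across channels.
model = get_model(
    "ibm-granite/granite-timeseries-ttm-r2",
    context_length=1024,
    prediction_length=720,
    num_input_channels=1000,
    decoder_mode="mix_channel",
)
```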
