Commit

format
wgifford committed Oct 18, 2024
1 parent c3db412 commit 79566b3
Showing 2 changed files with 22 additions and 19 deletions.

notebooks/hfdemo/patchtsmixer_HF_blog.ipynb (21 changes: 8 additions & 13 deletions)
@@ -79,6 +79,10 @@
 "import os\n",
 "import random\n",
 "\n",
+"import numpy as np\n",
+"import pandas as pd\n",
+"import torch\n",
+"\n",
 "# Third Party\n",
 "from transformers import (\n",
 "    EarlyStoppingCallback,\n",
@@ -87,9 +91,6 @@
 "    Trainer,\n",
 "    TrainingArguments,\n",
 ")\n",
-"import numpy as np\n",
-"import pandas as pd\n",
-"import torch\n",
 "\n",
 "# First Party\n",
 "from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
@@ -992,9 +993,7 @@
 ],
 "source": [
 "print(\"Loading pretrained model\")\n",
-"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
-"    \"patchtsmixer/electricity/model/pretrain/\"\n",
-")\n",
+"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
 "print(\"Done\")"
 ]
 },
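
The reformatting collapses the from_pretrained call onto one line without changing behaviour: it still loads the checkpoint that the notebook saved to a local directory during the pretraining step. A minimal sketch of the same pattern, assuming that directory already exists on disk:

from transformers import PatchTSMixerForPrediction

# Load configuration and weights from a local checkpoint directory
# (assumed to have been written earlier in the notebook during pretraining).
model = PatchTSMixerForPrediction.from_pretrained("patchtsmixer/electricity/model/pretrain/")
print(f"{sum(p.numel() for p in model.parameters()):,} parameters loaded")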
@@ -1296,14 +1295,12 @@
 ]
 },
 {
-"cell_type": "code",
-"execution_count": null,
+"cell_type": "markdown",
 "metadata": {
 "vscode": {
 "languageId": "plaintext"
 }
 },
-"outputs": [],
 "source": [
 "By doing a simple linear probing, MSE decreased from 0.3 to 0.271 achiving the SOTA results."
 ]
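
This hunk turns a prose cell that had been stored as a code cell into a proper markdown cell; dropping "execution_count" and "outputs" at the same time is required because the nbformat schema only defines those keys for code cells. A hedged sketch of the same conversion done with the nbformat package (the cell index here is hypothetical):

import nbformat

nb = nbformat.read("notebooks/hfdemo/patchtsmixer_HF_blog.ipynb", as_version=4)

idx = 42  # hypothetical index of the prose cell that was stored as a code cell
text = nb.cells[idx]["source"]

# new_markdown_cell produces a cell without execution_count/outputs,
# which is exactly what the hand-edited hunk removes.
nb.cells[idx] = nbformat.v4.new_markdown_cell(text)

nbformat.validate(nb)
nbformat.write(nb, "notebooks/hfdemo/patchtsmixer_HF_blog.ipynb")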
@@ -1511,9 +1508,7 @@
 ],
 "source": [
 "# Reload the model\n",
-"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n",
-"    \"patchtsmixer/electricity/model/pretrain/\"\n",
-")\n",
+"finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
 "finetune_forecast_trainer = Trainer(\n",
 "    model=finetune_forecast_model,\n",
 "    args=finetune_forecast_args,\n",
Expand Down Expand Up @@ -1579,7 +1574,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.6"
}
},
"nbformat": 4,
Expand Down

notebooks/hfdemo/ttm_getting_started.ipynb (20 changes: 14 additions & 6 deletions)
@@ -101,6 +101,7 @@
 "source": [
 "import warnings\n",
 "\n",
+"\n",
 "# Suppress all warnings\n",
 "warnings.filterwarnings(\"ignore\")"
 ]
@@ -128,7 +129,7 @@
 "# TTM Revision (1 or 2)\n",
 "TTM_REVISION = 2\n",
 "\n",
-"# Context length, Or Length of the history. \n",
+"# Context length, Or Length of the history.\n",
 "# Currently supported values are: 512/1024/1536 for TTM-R-2, and 512/1024 for TTM-R1\n",
 "CONTEXT_LENGTH = 512\n",
 "\n",
@@ -165,7 +166,7 @@
 "    if CONTEXT_LENGTH == 512:\n",
 "        TTM_MODEL_REVISION = \"main\"\n",
 "    elif CONTEXT_LENGTH == 1024:\n",
-"        TTM_MODEL_REVISION=\"1024_96_v1\"\n",
+"        TTM_MODEL_REVISION = \"1024_96_v1\"\n",
 "    else:\n",
 "        raise ValueError(f\"Unsupported CONTEXT_LENGTH for TTM_MODEL_PATH={TTM_MODEL_PATH}\")\n",
 "elif TTM_REVISION == 2:\n",
@@ -175,9 +176,9 @@
 "    if CONTEXT_LENGTH == 512:\n",
 "        TTM_MODEL_REVISION = \"main\"\n",
 "    elif CONTEXT_LENGTH == 1024:\n",
-"        TTM_MODEL_REVISION=\"1024-96-r2\"\n",
+"        TTM_MODEL_REVISION = \"1024-96-r2\"\n",
 "    elif CONTEXT_LENGTH == 1536:\n",
-"        TTM_MODEL_REVISION=\"1536-96-r2\"\n",
+"        TTM_MODEL_REVISION = \"1536-96-r2\"\n",
 "    else:\n",
 "        raise ValueError(f\"Unsupported CONTEXT_LENGTH for TTM_MODEL_PATH={TTM_MODEL_PATH}\")\n",
 "else:\n",
@@ -712,7 +713,9 @@
 }
 ],
 "source": [
-"fewshot_finetune_eval(dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, fewshot_percent=5, learning_rate=0.001)"
+"fewshot_finetune_eval(\n",
+"    dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, fewshot_percent=5, learning_rate=0.001\n",
+")"
 ]
 },
 {
@@ -992,7 +995,12 @@
 ],
 "source": [
 "fewshot_finetune_eval(\n",
-"    dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, prediction_filter_length=48, fewshot_percent=5, learning_rate=None\n",
+"    dataset_name=TARGET_DATASET,\n",
+"    context_length=CONTEXT_LENGTH,\n",
+"    batch_size=64,\n",
+"    prediction_filter_length=48,\n",
+"    fewshot_percent=5,\n",
+"    learning_rate=None,\n",
 ")"
 ]
 },
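
Taken together, the two fewshot_finetune_eval hunks illustrate the two wrapping modes of a Black-style formatter, assuming that is what this "format" commit ran: with no trailing comma the arguments are packed onto as few lines as the length limit allows, while a trailing ("magic") comma after the last argument forces one argument per line. A small illustration with a hypothetical helper:

# Hypothetical function, used only to show the formatting rule.
def evaluate(dataset_name, context_length, batch_size, fewshot_percent, learning_rate):
    ...

# No trailing comma: the call stays as compact as the line-length limit allows.
evaluate(
    dataset_name="etth1", context_length=512, batch_size=64, fewshot_percent=5, learning_rate=0.001
)

# Magic trailing comma: the call is exploded to one argument per line.
evaluate(
    dataset_name="etth1",
    context_length=512,
    batch_size=64,
    fewshot_percent=5,
    learning_rate=0.001,
)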
