diff --git a/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb b/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb index acc69ae6..712537c5 100644 --- a/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb +++ b/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb @@ -79,6 +79,10 @@ "import os\n", "import random\n", "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", "# Third Party\n", "from transformers import (\n", " EarlyStoppingCallback,\n", @@ -87,9 +91,6 @@ " Trainer,\n", " TrainingArguments,\n", ")\n", - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", "\n", "# First Party\n", "from tsfm_public.toolkit.dataset import ForecastDFDataset\n", @@ -992,9 +993,7 @@ ], "source": [ "print(\"Loading pretrained model\")\n", - "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n", - " \"patchtsmixer/electricity/model/pretrain/\"\n", - ")\n", + "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n", "print(\"Done\")" ] }, @@ -1296,14 +1295,12 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": { "vscode": { "languageId": "plaintext" } }, - "outputs": [], "source": [ "By doing a simple linear probing, MSE decreased from 0.3 to 0.271 achiving the SOTA results." ] @@ -1511,9 +1508,7 @@ ], "source": [ "# Reload the model\n", - "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\n", - " \"patchtsmixer/electricity/model/pretrain/\"\n", - ")\n", + "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n", "finetune_forecast_trainer = Trainer(\n", " model=finetune_forecast_model,\n", " args=finetune_forecast_args,\n", @@ -1579,7 +1574,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.12.6" } }, "nbformat": 4, diff --git a/notebooks/hfdemo/ttm_getting_started.ipynb b/notebooks/hfdemo/ttm_getting_started.ipynb index b025e252..c0703150 100644 --- a/notebooks/hfdemo/ttm_getting_started.ipynb +++ b/notebooks/hfdemo/ttm_getting_started.ipynb @@ -101,6 +101,7 @@ "source": [ "import warnings\n", "\n", + "\n", "# Suppress all warnings\n", "warnings.filterwarnings(\"ignore\")" ] @@ -128,7 +129,7 @@ "# TTM Revision (1 or 2)\n", "TTM_REVISION = 2\n", "\n", - "# Context length, Or Length of the history. \n", + "# Context length, Or Length of the history.\n", "# Currently supported values are: 512/1024/1536 for TTM-R-2, and 512/1024 for TTM-R1\n", "CONTEXT_LENGTH = 512\n", "\n", @@ -165,7 +166,7 @@ " if CONTEXT_LENGTH == 512:\n", " TTM_MODEL_REVISION = \"main\"\n", " elif CONTEXT_LENGTH == 1024:\n", - " TTM_MODEL_REVISION=\"1024_96_v1\"\n", + " TTM_MODEL_REVISION = \"1024_96_v1\"\n", " else:\n", " raise ValueError(f\"Unsupported CONTEXT_LENGTH for TTM_MODEL_PATH={TTM_MODEL_PATH}\")\n", "elif TTM_REVISION == 2:\n", @@ -175,9 +176,9 @@ " if CONTEXT_LENGTH == 512:\n", " TTM_MODEL_REVISION = \"main\"\n", " elif CONTEXT_LENGTH == 1024:\n", - " TTM_MODEL_REVISION=\"1024-96-r2\"\n", + " TTM_MODEL_REVISION = \"1024-96-r2\"\n", " elif CONTEXT_LENGTH == 1536:\n", - " TTM_MODEL_REVISION=\"1536-96-r2\"\n", + " TTM_MODEL_REVISION = \"1536-96-r2\"\n", " else:\n", " raise ValueError(f\"Unsupported CONTEXT_LENGTH for TTM_MODEL_PATH={TTM_MODEL_PATH}\")\n", "else:\n", @@ -712,7 +713,9 @@ } ], "source": [ - "fewshot_finetune_eval(dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, fewshot_percent=5, learning_rate=0.001)" + "fewshot_finetune_eval(\n", + " dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, fewshot_percent=5, learning_rate=0.001\n", + ")" ] }, { @@ -992,7 +995,12 @@ ], "source": [ "fewshot_finetune_eval(\n", - " dataset_name=TARGET_DATASET, context_length=CONTEXT_LENGTH, batch_size=64, prediction_filter_length=48, fewshot_percent=5, learning_rate=None\n", + " dataset_name=TARGET_DATASET,\n", + " context_length=CONTEXT_LENGTH,\n", + " batch_size=64,\n", + " prediction_filter_length=48,\n", + " fewshot_percent=5,\n", + " learning_rate=None,\n", ")" ] },