diff --git a/examples/DVCLive-HuggingFace.ipynb b/examples/DVCLive-HuggingFace.ipynb
new file mode 100644
index 00000000..4adf3c09
--- /dev/null
+++ b/examples/DVCLive-HuggingFace.ipynb
@@ -0,0 +1,167 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install accelerate datasets dvclive evaluate 'transformers[torch]' --upgrade"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!git init -q\n",
+    "!git config --local user.email \"you@example.com\"\n",
+    "!git config --local user.name \"Your Name\"\n",
+    "!dvc init -q\n",
+    "!git commit -m \"DVC init\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import load_dataset\n",
+    "from transformers import AutoTokenizer\n",
+    "\n",
+    "dataset = load_dataset(\"imdb\")\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n",
+    "\n",
+    "def tokenize_function(examples):\n",
+    "    return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
+    "\n",
+    "small_train_dataset = dataset[\"train\"].shuffle(seed=42).select(range(2000)).map(tokenize_function, batched=True)\n",
+    "small_eval_dataset = dataset[\"test\"].shuffle(seed=42).select(range(200)).map(tokenize_function, batched=True)\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "import evaluate\n",
+    "\n",
+    "metric = evaluate.load(\"f1\")\n",
+    "\n",
+    "def compute_metrics(eval_pred):\n",
+    "    logits, labels = eval_pred\n",
+    "    predictions = np.argmax(logits, axis=-1)\n",
+    "    return metric.compute(predictions=predictions, references=labels)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Tracking experiments with DVCLive"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dvclive.huggingface import DVCLiveCallback\n",
+    "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
+    "\n",
+    "for epochs in (5, 10, 15):\n",
+    "    model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
+    "    for param in model.base_model.parameters():\n",
+    "        param.requires_grad = False\n",
+    "\n",
+    "    training_args = TrainingArguments(\n",
+    "        evaluation_strategy=\"epoch\",\n",
+    "        learning_rate=3e-4,\n",
+    "        logging_strategy=\"epoch\",\n",
+    "        num_train_epochs=epochs,\n",
+    "        output_dir=\"output\",\n",
+    "        overwrite_output_dir=True,\n",
+    "        load_best_model_at_end=True,\n",
+    "        report_to=\"none\",\n",
+    "        save_strategy=\"epoch\",\n",
+    "        weight_decay=0.01,\n",
+    "    )\n",
+    "\n",
+    "    trainer = Trainer(\n",
+    "        model=model,\n",
+    "        args=training_args,\n",
+    "        train_dataset=small_train_dataset,\n",
+    "        eval_dataset=small_eval_dataset,\n",
+    "        compute_metrics=compute_metrics,\n",
+    "        callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=True)],\n",
+    "    )\n",
+    "    trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Comparing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import dvc.api\n",
+    "import pandas as pd\n",
+    "\n",
+    "columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n",
\"eval.f1\"]\n", + "\n", + "df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n", + "\n", + "df.dropna(inplace=True)\n", + "df.reset_index(drop=True, inplace=True)\n", + "df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!dvc plots diff $(dvc exp list --names-only)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import HTML\n", + "HTML(filename='./dvc_plots/index.html')" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index e026643f..49fa47e3 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -1,5 +1,7 @@ # ruff: noqa: ARG002 -from typing import Optional +import logging +import os +from typing import Literal, Optional, Union from transformers import ( TrainerCallback, @@ -12,13 +14,35 @@ from dvclive import Live from dvclive.utils import standardize_metric_name +logger = logging.getLogger("dvclive") + class DVCLiveCallback(TrainerCallback): - def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs): + def __init__( + self, + live: Optional[Live] = None, + log_model: Optional[Union[Literal["all"], bool]] = None, + **kwargs, + ): super().__init__() - self.model_file = model_file + self._log_model = log_model + self.model_file = kwargs.pop("model_file", None) + if self.model_file: + logger.warning( + "model_file is deprecated and will be removed" + " in the next major version, use log_model instead" + ) self.live = live if live is not None else Live(**kwargs) + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + self.live.log_params(args.to_dict()) + def on_log( self, args: TrainingArguments, @@ -31,6 +55,16 @@ def on_log( self.live.log_metric(standardize_metric_name(key, __name__), value) self.live.next_step() + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self._log_model == "all" and state.is_world_process_zero: + self.live.log_artifact(args.output_dir) + def on_epoch_end( self, args: TrainingArguments, @@ -53,10 +87,12 @@ def on_train_end( control: TrainerControl, **kwargs, ): - if args.load_best_model_at_end: - trainer = Trainer( + if self._log_model is True and state.is_world_process_zero: + fake_trainer = Trainer( args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer") ) - trainer.save_model() - self.live.log_artifact(args.output_dir) + name = "best" if args.load_best_model_at_end else "last" + output_dir = os.path.join(args.output_dir, name) + fake_trainer.save_model(output_dir) + self.live.log_artifact(output_dir, name=name, type="model", copy=True) self.live.end() diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index 39fa3afc..bc65e44a 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -4,6 +4,7 @@ from dvclive import Live from dvclive.plots.metric import Metric +from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics try: @@ -99,6 +100,7 @@ def args(): "foo", evaluation_strategy="epoch", num_train_epochs=2, + save_strategy="epoch", ) @@ -131,14 +133,17 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): assert 
     assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2
+    params = load_yaml(live.params_file)
+    assert params["num_train_epochs"] == 2
 
 
-def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
-    model_path = tmp_dir / "model_hf"
-    model_save = mocker.spy(model, "save_pretrained")
-    live_callback = DVCLiveCallback(model_file=model_path)
+@pytest.mark.parametrize("log_model", ["all", True, None])
+@pytest.mark.parametrize("best", [True, False])
+def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model, best):
+    live_callback = DVCLiveCallback(log_model=log_model)
     log_artifact = mocker.patch.object(live_callback.live, "log_artifact")
+    args.load_best_model_at_end = best
 
     trainer = Trainer(
         model,
         args,
@@ -149,12 +154,21 @@
     trainer.add_callback(live_callback)
     trainer.train()
 
-    assert model_path.is_dir()
+    expected_call_count = {
+        "all": 2,
+        True: 1,
+        None: 0,
+    }
+    assert log_artifact.call_count == expected_call_count[log_model]
 
-    assert (model_path / "pytorch_model.bin").exists()
-    assert (model_path / "config.json").exists()
-    assert model_save.call_count == 2
-    log_artifact.assert_called_with(model_path)
+    if log_model is True:
+        name = "best" if best else "last"
+        log_artifact.assert_called_with(
+            os.path.join(args.output_dir, name),
+            name=name,
+            type="model",
+            copy=True,
+        )
 
 
 def test_huggingface_pass_logger():
@@ -164,11 +178,14 @@
     assert DVCLiveCallback(live=logger).live is logger
 
 
-def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker):
-    live_callback = DVCLiveCallback()
+def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
+    logger = mocker.patch("dvclive.huggingface.logger")
+
+    model_path = tmp_dir / "model_hf"
+
+    live_callback = DVCLiveCallback(model_file=model_path)
     log_artifact = mocker.patch.object(live_callback.live, "log_artifact")
-    args.load_best_model_at_end = True
 
     trainer = Trainer(
         model,
         args,
@@ -179,4 +196,13 @@
     trainer.add_callback(live_callback)
     trainer.train()
 
-    log_artifact.assert_called_with(trainer.args.output_dir)
+    assert model_path.is_dir()
+
+    assert (model_path / "pytorch_model.bin").exists()
+    assert (model_path / "config.json").exists()
+    log_artifact.assert_called_with(model_path)
+
+    logger.warning.assert_called_with(
+        "model_file is deprecated and will be removed"
+        " in the next major version, use log_model instead"
+    )
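
For reference, a minimal usage sketch of the callback API introduced by this patch, assuming an initialized Git/DVC repository and the dvclive/transformers versions targeted above; the values shown (save_dvc_exp=True, the "model_hf" path) are illustrative, not part of the diff:

from dvclive.huggingface import DVCLiveCallback

# log_model=None (default): no model is logged.
# log_model=True: at the end of training the final model is saved under
#   <output_dir>/best or <output_dir>/last (depending on load_best_model_at_end)
#   and logged via live.log_artifact(..., type="model", copy=True).
# log_model="all": the Trainer output_dir is logged on every checkpoint save.
callback = DVCLiveCallback(log_model=True, save_dvc_exp=True)

# The old model_file= argument is still accepted, but it now emits a
# deprecation warning asking callers to migrate to log_model.
legacy_callback = DVCLiveCallback(model_file="model_hf")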