From 912a077afbe663cbb2c5362ac69038b3315567d2 Mon Sep 17 00:00:00 2001 From: daavoo Date: Wed, 2 Aug 2023 12:27:30 +0200 Subject: [PATCH 1/7] huggingface: log some parameters --- src/dvclive/huggingface.py | 17 +++++++++++++++++ tests/test_frameworks/test_huggingface.py | 6 ++++++ 2 files changed, 23 insertions(+) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index e026643f..0cc5a209 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -19,6 +19,23 @@ def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs): self.model_file = model_file self.live = live if live is not None else Live(**kwargs) + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + for key, value in args.to_dict().items(): + if key in ( + "num_train_epochs", + "weight_decay", + "max_grad_norm", + "warmup_ratio", + "warmup_steps", + ): + self.live.log_param(key, value) + def on_log( self, args: TrainingArguments, diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index 39fa3afc..fcd5d814 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -4,6 +4,7 @@ from dvclive import Live from dvclive.plots.metric import Metric +from dvclive.serialize import load_yaml from dvclive.utils import parse_metrics try: @@ -99,6 +100,7 @@ def args(): "foo", evaluation_strategy="epoch", num_train_epochs=2, + save_strategy="epoch", ) @@ -131,6 +133,9 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): assert len(logs[os.path.join(scalars, "epoch.tsv")]) == 3 assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2 + params = load_yaml(live.params_file) + assert params["num_train_epochs"] == 2 + def test_huggingface_model_file(tmp_dir, model, args, data, mocker): model_path = tmp_dir / "model_hf" @@ -180,3 +185,4 @@ def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker): trainer.train() log_artifact.assert_called_with(trainer.args.output_dir) + trainer.train() From 24fb584e7cbdc0f794f3e2ef560632428ec528ae Mon Sep 17 00:00:00 2001 From: daavoo Date: Fri, 4 Aug 2023 14:42:50 +0200 Subject: [PATCH 2/7] huggingface: Add `log_model`. - If `None` (default) will not log any artifact. - If `all` will call log_artifact with `output_dir` at each `on_save` call. - If `last` will save the model `on_train_end` and call `log_artifact` with type=model and copy=True. --- src/dvclive/huggingface.py | 32 +++++++++-------- tests/test_frameworks/test_huggingface.py | 44 ++++++++--------------- 2 files changed, 32 insertions(+), 44 deletions(-) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 0cc5a209..21df876e 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -1,5 +1,6 @@ # ruff: noqa: ARG002 -from typing import Optional +import os +from typing import Literal, Optional from transformers import ( TrainerCallback, @@ -14,9 +15,14 @@ class DVCLiveCallback(TrainerCallback): - def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs): + def __init__( + self, + live: Optional[Live] = None, + log_model: Optional[Literal["all", "last"]] = None, + **kwargs, + ): super().__init__() - self.model_file = model_file + self._log_model = log_model self.live = live if live is not None else Live(**kwargs) def on_train_begin( @@ -48,20 +54,15 @@ def on_log( self.live.log_metric(standardize_metric_name(key, __name__), value) self.live.next_step() - def on_epoch_end( + def on_save( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs, ): - if self.model_file: - model = kwargs["model"] - model.save_pretrained(self.model_file) - tokenizer = kwargs.get("tokenizer") - if tokenizer: - tokenizer.save_pretrained(self.model_file) - self.live.log_artifact(self.model_file) + if self._log_model == "all" and state.is_world_process_zero: + self.live.log_artifact(args.output_dir) def on_train_end( self, @@ -70,10 +71,11 @@ def on_train_end( control: TrainerControl, **kwargs, ): - if args.load_best_model_at_end: - trainer = Trainer( + if self._log_model == "last" and state.is_world_process_zero: + fake_trainer = Trainer( args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer") ) - trainer.save_model() - self.live.log_artifact(args.output_dir) + output_dir = os.path.join(args.output_dir, "last") + fake_trainer.save_model(output_dir) + self.live.log_artifact(output_dir, type="model", copy=True) self.live.end() diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index fcd5d814..ca5de12d 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -137,11 +137,9 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): assert params["num_train_epochs"] == 2 -def test_huggingface_model_file(tmp_dir, model, args, data, mocker): - model_path = tmp_dir / "model_hf" - model_save = mocker.spy(model, "save_pretrained") - - live_callback = DVCLiveCallback(model_file=model_path) +@pytest.mark.parametrize("log_model", ["all", "last", None]) +def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model): + live_callback = DVCLiveCallback(log_model=log_model) log_artifact = mocker.patch.object(live_callback.live, "log_artifact") trainer = Trainer( @@ -154,12 +152,19 @@ def test_huggingface_model_file(tmp_dir, model, args, data, mocker): trainer.add_callback(live_callback) trainer.train() - assert model_path.is_dir() + expected_call_count = { + "all": 2, + "last": 1, + None: 0, + } + assert log_artifact.call_count == expected_call_count[log_model] - assert (model_path / "pytorch_model.bin").exists() - assert (model_path / "config.json").exists() - assert model_save.call_count == 2 - log_artifact.assert_called_with(model_path) + if log_model == "last": + log_artifact.assert_called_with( + os.path.join(args.output_dir, "last"), + type="model", + copy=True, + ) def test_huggingface_pass_logger(): @@ -167,22 +172,3 @@ def test_huggingface_pass_logger(): assert DVCLiveCallback().live is not logger assert DVCLiveCallback(live=logger).live is logger - - -def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker): - live_callback = DVCLiveCallback() - log_artifact = mocker.patch.object(live_callback.live, "log_artifact") - - args.load_best_model_at_end = True - trainer = Trainer( - model, - args, - train_dataset=data[0], - eval_dataset=data[1], - compute_metrics=compute_metrics, - ) - trainer.add_callback(live_callback) - trainer.train() - - log_artifact.assert_called_with(trainer.args.output_dir) - trainer.train() From 8cac907db6a210f2a0216f8e3340f40e9d90662c Mon Sep 17 00:00:00 2001 From: daavoo Date: Fri, 4 Aug 2023 15:35:44 +0200 Subject: [PATCH 3/7] examples: Add DVCLive-HuggingFace notebook --- examples/DVCLive-HuggingFace.ipynb | 167 +++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 examples/DVCLive-HuggingFace.ipynb diff --git a/examples/DVCLive-HuggingFace.ipynb b/examples/DVCLive-HuggingFace.ipynb new file mode 100644 index 00000000..4adf3c09 --- /dev/null +++ b/examples/DVCLive-HuggingFace.ipynb @@ -0,0 +1,167 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install accelerate datasets dvclive evaluate 'transformers[torch]' --upgrade" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git init -q\n", + "!git config --local user.email \"you@example.com\"\n", + "!git config --local user.name \"Your Name\"\n", + "!dvc init -q\n", + "!git commit -m \"DVC init\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from transformers import AutoTokenizer\n", + "\n", + "dataset = load_dataset(\"imdb\")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n", + "\n", + "def tokenize_function(examples):\n", + " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n", + "\n", + "small_train_dataset = dataset[\"train\"].shuffle(seed=42).select(range(2000)).map(tokenize_function, batched=True)\n", + "small_eval_dataset = dataset[\"test\"].shuffle(seed=42).select(range(200)).map(tokenize_function, batched=True)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import evaluate\n", + "\n", + "metric = evaluate.load(\"f1\")\n", + "\n", + "def compute_metrics(eval_pred):\n", + " logits, labels = eval_pred\n", + " predictions = np.argmax(logits, axis=-1)\n", + " return metric.compute(predictions=predictions, references=labels)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Tracking experiments with DVCLive" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from dvclive.huggingface import DVCLiveCallback\n", + "from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n", + "\n", + "for epochs in (5, 10, 15):\n", + " model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n", + " for param in model.base_model.parameters():\n", + " param.requires_grad = False\n", + "\n", + " training_args = TrainingArguments(\n", + " evaluation_strategy=\"epoch\", \n", + " learning_rate=3e-4,\n", + " logging_strategy=\"epoch\",\n", + " num_train_epochs=epochs,\n", + " output_dir=\"output\", \n", + " overwrite_output_dir=True,\n", + " load_best_model_at_end=True,\n", + " report_to=\"none\",\n", + " save_strategy=\"epoch\",\n", + " weight_decay=0.01,\n", + " )\n", + "\n", + " trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=small_train_dataset,\n", + " eval_dataset=small_eval_dataset,\n", + " compute_metrics=compute_metrics,\n", + " callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=\"last\")],\n", + " )\n", + " trainer.train()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Comparing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import dvc.api\n", + "import pandas as pd\n", + "\n", + "columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n", + "\n", + "df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n", + "\n", + "df.dropna(inplace=True)\n", + "df.reset_index(drop=True, inplace=True)\n", + "df\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!dvc plots diff $(dvc exp list --names-only)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import HTML\n", + "HTML(filename='./dvc_plots/index.html')" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From fbf8865f93d7a906d354d6797b8092d781037f18 Mon Sep 17 00:00:00 2001 From: daavoo Date: Mon, 7 Aug 2023 10:17:21 +0200 Subject: [PATCH 4/7] Don't cherry-pick args --- src/dvclive/huggingface.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 21df876e..8847f888 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -32,15 +32,7 @@ def on_train_begin( control: TrainerControl, **kwargs, ): - for key, value in args.to_dict().items(): - if key in ( - "num_train_epochs", - "weight_decay", - "max_grad_norm", - "warmup_ratio", - "warmup_steps", - ): - self.live.log_param(key, value) + self.live.log_params(args.to_dict()) def on_log( self, From 172290c68704c7eca59fedefdd710159f281247c Mon Sep 17 00:00:00 2001 From: daavoo Date: Tue, 15 Aug 2023 11:20:45 +0200 Subject: [PATCH 5/7] huggingface: Conditional model name based on load_best_model_at_end --- src/dvclive/huggingface.py | 9 +++++---- tests/test_frameworks/test_huggingface.py | 8 ++++++-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 8847f888..9ccbd56c 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -1,6 +1,6 @@ # ruff: noqa: ARG002 import os -from typing import Literal, Optional +from typing import Literal, Optional, Union from transformers import ( TrainerCallback, @@ -18,7 +18,7 @@ class DVCLiveCallback(TrainerCallback): def __init__( self, live: Optional[Live] = None, - log_model: Optional[Literal["all", "last"]] = None, + log_model: Optional[Union[Literal["all"], bool]] = None, **kwargs, ): super().__init__() @@ -67,7 +67,8 @@ def on_train_end( fake_trainer = Trainer( args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer") ) - output_dir = os.path.join(args.output_dir, "last") + name = "best" if args.load_best_model_at_end else "last" + output_dir = os.path.join(args.output_dir, name) fake_trainer.save_model(output_dir) - self.live.log_artifact(output_dir, type="model", copy=True) + self.live.log_artifact(output_dir, name=name, type="model", copy=True) self.live.end() diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index ca5de12d..9687df8f 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -138,10 +138,12 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): @pytest.mark.parametrize("log_model", ["all", "last", None]) -def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model): +@pytest.mark.parametrize("best", [True, False]) +def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model, best): live_callback = DVCLiveCallback(log_model=log_model) log_artifact = mocker.patch.object(live_callback.live, "log_artifact") + args.load_best_model_at_end = best trainer = Trainer( model, args, @@ -160,8 +162,10 @@ def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model): assert log_artifact.call_count == expected_call_count[log_model] if log_model == "last": + name = "best" if best else "last" log_artifact.assert_called_with( - os.path.join(args.output_dir, "last"), + os.path.join(args.output_dir, name), + name=name, type="model", copy=True, ) From 5a0e75025e74b66e1bae914b73c3013cb7f5ba8e Mon Sep 17 00:00:00 2001 From: daavoo Date: Tue, 15 Aug 2023 11:33:09 +0200 Subject: [PATCH 6/7] huggingface: Keep model_file behavior --- src/dvclive/huggingface.py | 24 ++++++++++++++++++ tests/test_frameworks/test_huggingface.py | 30 +++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 9ccbd56c..3e715f5f 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -1,4 +1,5 @@ # ruff: noqa: ARG002 +import logging import os from typing import Literal, Optional, Union @@ -13,6 +14,8 @@ from dvclive import Live from dvclive.utils import standardize_metric_name +logger = logging.getLogger("dvclive") + class DVCLiveCallback(TrainerCallback): def __init__( @@ -23,6 +26,12 @@ def __init__( ): super().__init__() self._log_model = log_model + self.model_file = kwargs.pop("model_file", None) + if self.model_file: + logger.warning( + "model_file is deprecated and will be removed" + " in the next major version, use log_model instead" + ) self.live = live if live is not None else Live(**kwargs) def on_train_begin( @@ -56,6 +65,21 @@ def on_save( if self._log_model == "all" and state.is_world_process_zero: self.live.log_artifact(args.output_dir) + def on_epoch_end( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + if self.model_file: + model = kwargs["model"] + model.save_pretrained(self.model_file) + tokenizer = kwargs.get("tokenizer") + if tokenizer: + tokenizer.save_pretrained(self.model_file) + self.live.log_artifact(self.model_file) + def on_train_end( self, args: TrainingArguments, diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index 9687df8f..c6664e14 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -176,3 +176,33 @@ def test_huggingface_pass_logger(): assert DVCLiveCallback().live is not logger assert DVCLiveCallback(live=logger).live is logger + + +def test_huggingface_model_file(tmp_dir, model, args, data, mocker): + logger = mocker.patch("dvclive.huggingface.logger") + + model_path = tmp_dir / "model_hf" + + live_callback = DVCLiveCallback(model_file=model_path) + log_artifact = mocker.patch.object(live_callback.live, "log_artifact") + + trainer = Trainer( + model, + args, + train_dataset=data[0], + eval_dataset=data[1], + compute_metrics=compute_metrics, + ) + trainer.add_callback(live_callback) + trainer.train() + + assert model_path.is_dir() + + assert (model_path / "pytorch_model.bin").exists() + assert (model_path / "config.json").exists() + log_artifact.assert_called_with(model_path) + + logger.warning.assert_called_with( + "model_file is deprecated and will be removed" + " in the next major version, use log_model instead" + ) From 2c03182d3e0894c100c6c03ac063b396e453ed84 Mon Sep 17 00:00:00 2001 From: daavoo Date: Tue, 15 Aug 2023 12:13:30 +0200 Subject: [PATCH 7/7] Use `True` instead of `last`. --- src/dvclive/huggingface.py | 2 +- tests/test_frameworks/test_huggingface.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dvclive/huggingface.py b/src/dvclive/huggingface.py index 3e715f5f..49fa47e3 100644 --- a/src/dvclive/huggingface.py +++ b/src/dvclive/huggingface.py @@ -87,7 +87,7 @@ def on_train_end( control: TrainerControl, **kwargs, ): - if self._log_model == "last" and state.is_world_process_zero: + if self._log_model is True and state.is_world_process_zero: fake_trainer = Trainer( args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer") ) diff --git a/tests/test_frameworks/test_huggingface.py b/tests/test_frameworks/test_huggingface.py index c6664e14..bc65e44a 100644 --- a/tests/test_frameworks/test_huggingface.py +++ b/tests/test_frameworks/test_huggingface.py @@ -137,7 +137,7 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker): assert params["num_train_epochs"] == 2 -@pytest.mark.parametrize("log_model", ["all", "last", None]) +@pytest.mark.parametrize("log_model", ["all", True, None]) @pytest.mark.parametrize("best", [True, False]) def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model, best): live_callback = DVCLiveCallback(log_model=log_model) @@ -156,7 +156,7 @@ def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model, be expected_call_count = { "all": 2, - "last": 1, + True: 1, None: 0, } assert log_artifact.call_count == expected_call_count[log_model]