Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HuggingFace improvements #649

Merged
merged 7 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 167 additions & 0 deletions examples/DVCLive-HuggingFace.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install accelerate datasets dvclive evaluate 'transformers[torch]' --upgrade"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!git init -q\n",
"!git config --local user.email \"[email protected]\"\n",
"!git config --local user.name \"Your Name\"\n",
"!dvc init -q\n",
"!git commit -m \"DVC init\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer\n",
"\n",
"dataset = load_dataset(\"imdb\")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert-base-cased\")\n",
"\n",
"def tokenize_function(examples):\n",
" return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n",
"\n",
"small_train_dataset = dataset[\"train\"].shuffle(seed=42).select(range(2000)).map(tokenize_function, batched=True)\n",
"small_eval_dataset = dataset[\"test\"].shuffle(seed=42).select(range(200)).map(tokenize_function, batched=True)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import evaluate\n",
"\n",
"metric = evaluate.load(\"f1\")\n",
"\n",
"def compute_metrics(eval_pred):\n",
" logits, labels = eval_pred\n",
" predictions = np.argmax(logits, axis=-1)\n",
" return metric.compute(predictions=predictions, references=labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tracking experiments with DVCLive"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from dvclive.huggingface import DVCLiveCallback\n",
"from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer\n",
"\n",
"for epochs in (5, 10, 15):\n",
" model = AutoModelForSequenceClassification.from_pretrained(\"distilbert-base-cased\", num_labels=2)\n",
" for param in model.base_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
" training_args = TrainingArguments(\n",
" evaluation_strategy=\"epoch\", \n",
" learning_rate=3e-4,\n",
" logging_strategy=\"epoch\",\n",
" num_train_epochs=epochs,\n",
" output_dir=\"output\", \n",
" overwrite_output_dir=True,\n",
" load_best_model_at_end=True,\n",
" report_to=\"none\",\n",
" save_strategy=\"epoch\",\n",
" weight_decay=0.01,\n",
" )\n",
"\n",
" trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=small_train_dataset,\n",
" eval_dataset=small_eval_dataset,\n",
" compute_metrics=compute_metrics,\n",
" callbacks=[DVCLiveCallback(report=\"notebook\", save_dvc_exp=True, log_model=\"last\")],\n",
" )\n",
" trainer.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import dvc.api\n",
"import pandas as pd\n",
"\n",
"columns = [\"Experiment\", \"epoch\", \"eval.f1\"]\n",
"\n",
"df = pd.DataFrame(dvc.api.exp_show(), columns=columns)\n",
"\n",
"df.dropna(inplace=True)\n",
"df.reset_index(drop=True, inplace=True)\n",
"df\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!dvc plots diff $(dvc exp list --names-only)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import HTML\n",
"HTML(filename='./dvc_plots/index.html')"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
50 changes: 43 additions & 7 deletions src/dvclive/huggingface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# ruff: noqa: ARG002
from typing import Optional
import logging
import os
from typing import Literal, Optional, Union

from transformers import (
TrainerCallback,
Expand All @@ -12,13 +14,35 @@
from dvclive import Live
from dvclive.utils import standardize_metric_name

logger = logging.getLogger("dvclive")


class DVCLiveCallback(TrainerCallback):
def __init__(self, model_file=None, live: Optional[Live] = None, **kwargs):
def __init__(
self,
live: Optional[Live] = None,
log_model: Optional[Union[Literal["all"], bool]] = None,
**kwargs,
):
super().__init__()
self.model_file = model_file
self._log_model = log_model
self.model_file = kwargs.pop("model_file", None)
if self.model_file:
logger.warning(
"model_file is deprecated and will be removed"
" in the next major version, use log_model instead"
)
self.live = live if live is not None else Live(**kwargs)

def on_train_begin(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
self.live.log_params(args.to_dict())

def on_log(
self,
args: TrainingArguments,
Expand All @@ -31,6 +55,16 @@ def on_log(
self.live.log_metric(standardize_metric_name(key, __name__), value)
self.live.next_step()

def on_save(
self,
args: TrainingArguments,
state: TrainerState,
control: TrainerControl,
**kwargs,
):
if self._log_model == "all" and state.is_world_process_zero:
self.live.log_artifact(args.output_dir)

def on_epoch_end(
self,
args: TrainingArguments,
Expand All @@ -53,10 +87,12 @@ def on_train_end(
control: TrainerControl,
**kwargs,
):
if args.load_best_model_at_end:
trainer = Trainer(
if self._log_model is True and state.is_world_process_zero:
fake_trainer = Trainer(
args=args, model=kwargs.get("model"), tokenizer=kwargs.get("tokenizer")
)
trainer.save_model()
self.live.log_artifact(args.output_dir)
name = "best" if args.load_best_model_at_end else "last"
output_dir = os.path.join(args.output_dir, name)
fake_trainer.save_model(output_dir)
self.live.log_artifact(output_dir, name=name, type="model", copy=True)
self.live.end()
52 changes: 39 additions & 13 deletions tests/test_frameworks/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from dvclive import Live
from dvclive.plots.metric import Metric
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics

try:
Expand Down Expand Up @@ -99,6 +100,7 @@ def args():
"foo",
evaluation_strategy="epoch",
num_train_epochs=2,
save_strategy="epoch",
)


Expand Down Expand Up @@ -131,14 +133,17 @@ def test_huggingface_integration(tmp_dir, model, args, data, mocker):
assert len(logs[os.path.join(scalars, "epoch.tsv")]) == 3
assert len(logs[os.path.join(scalars, "eval", "loss.tsv")]) == 2

params = load_yaml(live.params_file)
assert params["num_train_epochs"] == 2

def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
model_path = tmp_dir / "model_hf"
model_save = mocker.spy(model, "save_pretrained")

live_callback = DVCLiveCallback(model_file=model_path)
@pytest.mark.parametrize("log_model", ["all", True, None])
@pytest.mark.parametrize("best", [True, False])
def test_huggingface_log_model(tmp_dir, model, args, data, mocker, log_model, best):
live_callback = DVCLiveCallback(log_model=log_model)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")

args.load_best_model_at_end = best
trainer = Trainer(
model,
args,
Expand All @@ -149,12 +154,21 @@ def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
trainer.add_callback(live_callback)
trainer.train()

assert model_path.is_dir()
expected_call_count = {
"all": 2,
True: 1,
None: 0,
}
assert log_artifact.call_count == expected_call_count[log_model]

assert (model_path / "pytorch_model.bin").exists()
assert (model_path / "config.json").exists()
assert model_save.call_count == 2
log_artifact.assert_called_with(model_path)
if log_model == "last":
name = "best" if best else "last"
log_artifact.assert_called_with(
os.path.join(args.output_dir, name),
name=name,
type="model",
copy=True,
)


def test_huggingface_pass_logger():
Expand All @@ -164,11 +178,14 @@ def test_huggingface_pass_logger():
assert DVCLiveCallback(live=logger).live is logger


def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker):
live_callback = DVCLiveCallback()
def test_huggingface_model_file(tmp_dir, model, args, data, mocker):
logger = mocker.patch("dvclive.huggingface.logger")

model_path = tmp_dir / "model_hf"

live_callback = DVCLiveCallback(model_file=model_path)
log_artifact = mocker.patch.object(live_callback.live, "log_artifact")

args.load_best_model_at_end = True
trainer = Trainer(
model,
args,
Expand All @@ -179,4 +196,13 @@ def test_huggingface_log_artifact(tmp_dir, model, args, data, mocker):
trainer.add_callback(live_callback)
trainer.train()

log_artifact.assert_called_with(trainer.args.output_dir)
assert model_path.is_dir()

assert (model_path / "pytorch_model.bin").exists()
assert (model_path / "config.json").exists()
log_artifact.assert_called_with(model_path)

logger.warning.assert_called_with(
"model_file is deprecated and will be removed"
" in the next major version, use log_model instead"
)