From d4e2294ae91e868995c0717c77dee63d0a5c1e04 Mon Sep 17 00:00:00 2001 From: wenxindongwork <161090399+wenxindongwork@users.noreply.github.com> Date: Thu, 22 Aug 2024 08:37:48 -0700 Subject: [PATCH] Update colab examples (#86) * update colab examples The `from optimum.tpu` version imports models that are specifically optimized for inference. * udpate exmaples * update fsdp_v2. get_fsdp_training_args * load model in bf16 * make style --- examples/language-modeling/gemma_tuning.ipynb | 17 ++++++++++++++--- examples/language-modeling/llama_tuning.md | 5 ++--- optimum/tpu/fsdp_v2.py | 8 ++++++-- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/examples/language-modeling/gemma_tuning.ipynb b/examples/language-modeling/gemma_tuning.ipynb index 1e810613..fc7f4717 100644 --- a/examples/language-modeling/gemma_tuning.ipynb +++ b/examples/language-modeling/gemma_tuning.ipynb @@ -118,6 +118,8 @@ "outputs": [], "source": [ "from optimum.tpu import fsdp_v2\n", + "\n", + "\n", "fsdp_v2.use_fsdp_v2()" ] }, @@ -141,6 +143,8 @@ "outputs": [], "source": [ "from datasets import load_dataset\n", + "\n", + "\n", "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")" ] }, @@ -199,6 +203,8 @@ "outputs": [], "source": [ "from transformers import AutoTokenizer\n", + "\n", + "\n", "model_id = \"google/gemma-2b\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(model_id)\n", @@ -249,8 +255,11 @@ "metadata": {}, "outputs": [], "source": [ - "from optimum.tpu import AutoModelForCausalLM\n", - "model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False)" + "import torch\n", + "from transformers import AutoModelForCausalLM\n", + "\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)" ] }, { @@ -270,6 +279,7 @@ "source": [ "from peft import LoraConfig\n", "\n", + "\n", "# Set up PEFT LoRA for fine-tuning.\n", "lora_config = LoraConfig(\n", " r=8,\n", @@ -293,8 +303,9 @@ "metadata": {}, "outputs": [], "source": [ - "from trl import SFTTrainer\n", "from transformers import TrainingArguments\n", + "from trl import SFTTrainer\n", + "\n", "\n", "# Set up the FSDP arguments\n", "fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)\n", diff --git a/examples/language-modeling/llama_tuning.md b/examples/language-modeling/llama_tuning.md index 00b629be..9d38130a 100644 --- a/examples/language-modeling/llama_tuning.md +++ b/examples/language-modeling/llama_tuning.md @@ -47,14 +47,13 @@ Then, the tokenizer and model need to be loaded. We will choose [`meta-llama/Met ```python import torch -from transformers import AutoTokenizer -from optimum.tpu import AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM model_id = "meta-llama/Meta-Llama-3-8B" tokenizer = AutoTokenizer.from_pretrained(model_id) # Add custom token for padding Llama tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token}) -model = AutoModelForCausalLM.from_pretrained(model_id) +model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16) ``` To tune the model with the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, you can load it and obtain the `quote` column: diff --git a/optimum/tpu/fsdp_v2.py b/optimum/tpu/fsdp_v2.py index 5d1d61cb..8a138793 100644 --- a/optimum/tpu/fsdp_v2.py +++ b/optimum/tpu/fsdp_v2.py @@ -82,9 +82,11 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict: model_type = model.config.model_type matched_model = False if model_type == "gemma": + from transformers import GemmaForCausalLM as HFGemmaForCausalLLM + from .modeling_gemma import GemmaForCausalLM - if isinstance(model, GemmaForCausalLM): + if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM): logger = logging.get_logger(__name__) from torch_xla import __version__ as xla_version if xla_version == "2.3.0": @@ -94,9 +96,11 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict: cls_to_wrap = "GemmaDecoderLayer" matched_model = True elif model_type == "llama": + from transformers import LlamaForCausalLM as HFLlamaForCausalLLM + from .modeling_llama import LlamaForCausalLM - if isinstance(model, LlamaForCausalLM): + if isinstance(model, LlamaForCausalLM) or isinstance(model, HFLlamaForCausalLLM): cls_to_wrap = "LlamaDecoderLayer" matched_model = True