Update colab examples (#86)
* update colab examples

The `from optimum.tpu` version imports models that are specifically optimized for inference.

* update examples

* update fsdp_v2.get_fsdp_training_args

* load model in bf16

* make style
wenxindongwork authored Aug 22, 2024
1 parent 426d7be commit d4e2294
Showing 3 changed files with 22 additions and 8 deletions.
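
As the commit message notes, the `optimum.tpu` `AutoModelForCausalLM` is geared toward inference, so the fine-tuning examples now load the plain `transformers` model in bfloat16 instead. A minimal before/after sketch of the change, mirroring the Gemma notebook diff below:

```python
# Before (inference-optimized class):
# from optimum.tpu import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False)

# After (standard transformers class, loaded in bfloat16 for fine-tuning):
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b", use_cache=False, torch_dtype=torch.bfloat16
)
```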
17 changes: 14 additions & 3 deletions examples/language-modeling/gemma_tuning.ipynb
@@ -118,6 +118,8 @@
 "outputs": [],
 "source": [
 "from optimum.tpu import fsdp_v2\n",
+"\n",
+"\n",
 "fsdp_v2.use_fsdp_v2()"
 ]
 },
@@ -141,6 +143,8 @@
 "outputs": [],
 "source": [
 "from datasets import load_dataset\n",
+"\n",
+"\n",
 "dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
 ]
 },
@@ -199,6 +203,8 @@
 "outputs": [],
 "source": [
 "from transformers import AutoTokenizer\n",
+"\n",
+"\n",
 "model_id = \"google/gemma-2b\"\n",
 "\n",
 "tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
@@ -249,8 +255,11 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from optimum.tpu import AutoModelForCausalLM\n",
-"model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False)"
+"import torch\n",
+"from transformers import AutoModelForCausalLM\n",
+"\n",
+"\n",
+"model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)"
 ]
 },
 {
@@ -270,6 +279,7 @@
 "source": [
 "from peft import LoraConfig\n",
 "\n",
+"\n",
 "# Set up PEFT LoRA for fine-tuning.\n",
 "lora_config = LoraConfig(\n",
 " r=8,\n",
@@ -293,8 +303,9 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"from trl import SFTTrainer\n",
 "from transformers import TrainingArguments\n",
+"from trl import SFTTrainer\n",
 "\n",
+"\n",
 "# Set up the FSDP arguments\n",
 "fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)\n",
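
The trainer cell itself is truncated above. As a rough sketch of how the updated notebook cells fit together after this commit: the LoRA settings, hyperparameters, and prompt formatting below are illustrative placeholders rather than the notebook's exact values, and unpacking `fsdp_training_args` into `TrainingArguments` is an assumption about the hidden cell.

```python
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

from optimum.tpu import fsdp_v2

fsdp_v2.use_fsdp_v2()  # enable FSDP v2 before building the model

model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

def formatting_func(examples):
    # Placeholder prompt format for Dolly; the notebook defines its own.
    return [f"{q}\n{a}" for q, a in zip(examples["instruction"], examples["response"])]

lora_config = LoraConfig(r=8, task_type="CAUSAL_LM")  # placeholder LoRA settings

# get_fsdp_training_args inspects the model and returns a dict of FSDP v2
# options (e.g. which decoder layer class to wrap), unpacked into TrainingArguments.
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=lora_config,
    formatting_func=formatting_func,
    args=TrainingArguments(
        output_dir="./gemma-dolly-lora",  # placeholder
        per_device_train_batch_size=8,    # placeholder
        num_train_epochs=1,               # placeholder
        optim="adafactor",                # placeholder
        **fsdp_training_args,
    ),
)
trainer.train()
```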
5 changes: 2 additions & 3 deletions examples/language-modeling/llama_tuning.md
@@ -47,14 +47,13 @@ Then, the tokenizer and model need to be loaded. We will choose [`meta-llama/Met
 
 ```python
 import torch
-from transformers import AutoTokenizer
-from optimum.tpu import AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 model_id = "meta-llama/Meta-Llama-3-8B"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 # Add custom token for padding Llama
 tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
-model = AutoModelForCausalLM.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
 ```
 
 To tune the model with the [Abirate/english_quotes](https://huggingface.co/datasets/Abirate/english_quotes) dataset, you can load it and obtain the `quote` column:
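
For reference, the dataset loading step mentioned at the end of the hunk above looks roughly like this (a minimal sketch; the guide's own snippet may differ):

```python
from datasets import load_dataset

# Load the quotes dataset; the `quote` column holds the text used for tuning.
dataset = load_dataset("Abirate/english_quotes", split="train")
print(dataset[0]["quote"])
```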
8 changes: 6 additions & 2 deletions optimum/tpu/fsdp_v2.py
@@ -82,9 +82,11 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
     model_type = model.config.model_type
     matched_model = False
     if model_type == "gemma":
+        from transformers import GemmaForCausalLM as HFGemmaForCausalLLM
+
         from .modeling_gemma import GemmaForCausalLM
 
-        if isinstance(model, GemmaForCausalLM):
+        if isinstance(model, GemmaForCausalLM) or isinstance(model, HFGemmaForCausalLLM):
             logger = logging.get_logger(__name__)
             from torch_xla import __version__ as xla_version
             if xla_version == "2.3.0":
@@ -94,9 +96,11 @@ def get_fsdp_training_args(model: PreTrainedModel) -> Dict:
             cls_to_wrap = "GemmaDecoderLayer"
             matched_model = True
     elif model_type == "llama":
+        from transformers import LlamaForCausalLM as HFLlamaForCausalLLM
+
         from .modeling_llama import LlamaForCausalLM
 
-        if isinstance(model, LlamaForCausalLM):
+        if isinstance(model, LlamaForCausalLM) or isinstance(model, HFLlamaForCausalLLM):
             cls_to_wrap = "LlamaDecoderLayer"
             matched_model = True

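
The practical effect of this change: `get_fsdp_training_args` now also matches models loaded directly through `transformers`, not just the `optimum.tpu` re-implementations. A hedged usage sketch:

```python
import torch
from transformers import AutoModelForCausalLM

from optimum.tpu import fsdp_v2

fsdp_v2.use_fsdp_v2()

# A model loaded with plain transformers is now recognized by the isinstance
# checks above and mapped to the "GemmaDecoderLayer" wrapping class.
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
```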
