
Update colab examples #86

Merged: 7 commits merged into huggingface:main on Aug 22, 2024

Conversation

wenxindongwork (Contributor)

Use transformers' AutoModelForCausalLM instead of optimum-tpu's AutoModelForCausalLM for finetuning.

The `from optimum.tpu` version imports models that are specifically optimized for inference. While the Colab example works for smaller models, it fails with an HBM OOM error for llama3-70b (on a v4-256). Changing the import statement as shown below solved the problem.
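For reference, the swap described above looks roughly like this (a sketch; the model id is illustrative and the notebook's surrounding code may differ):

```python
# Before: optimum-tpu's class, which loads inference-optimized models
# from optimum.tpu import AutoModelForCausalLM

# After: the stock transformers class, suitable for FSDP fine-tuning
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-70B")
```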

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

tengomucho (Collaborator) left a comment:

You are right @wenxindongwork, at some point I added some improvements for tuning, but with FSDP the models are essentially the same in transformers and in optimum-tpu, so it might well be easier just to import the transformers version.
Note that, on the other hand, this means the model might be loaded in float32 and end up using more memory, whereas optimum-tpu's models load in bfloat16 by default. If you do this, you can end up with OOM on some configurations.
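The memory gap is easy to quantify: 70B parameters take roughly 280 GB in float32 but roughly 140 GB in bfloat16 (4 vs 2 bytes per weight). A minimal sketch of requesting bfloat16 explicitly when loading through transformers (model id illustrative):

```python
import torch
from transformers import AutoModelForCausalLM

# Ask for bfloat16 explicitly so memory use matches what optimum-tpu's
# loader gives by default, instead of falling back to float32.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-70B",
    torch_dtype=torch.bfloat16,
)
```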

Comment on lines 300 to 304:

```python
cls_to_wrap = "GemmaDecoderLayer"
fsdp_training_args = {
    "fsdp": "full_shard",
    "fsdp_config": fsdp_v2.get_fsdp_config(cls_to_wrap),
}
```
tengomucho (Collaborator)

Well, that was the point of using get_fsdp_training_args: you do not need to know which classes to wrap for supported models. I would revert this bit.
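For context, the intended call shape, as described in this thread, is along these lines (a sketch; only fsdp_v2 and get_fsdp_training_args are named in the discussion):

```python
from optimum.tpu import fsdp_v2

# The helper infers the decoder-layer class to wrap from the model itself,
# so the notebook does not have to hard-code "GemmaDecoderLayer".
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
```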

wenxindongwork (Contributor, PR author)

Thanks for the quick review! get_fsdp_training_args accepted only the optimum.tpu model class, not the transformers one. I updated the get_fsdp_training_args function, so it should now work.
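One plausible shape for that update (a hypothetical sketch, not the actual diff in this PR) is to resolve the decoder-layer class name from the model config, which transformers and optimum-tpu models both expose:

```python
def get_fsdp_training_args(model):
    # Hypothetical: map the config's model_type (e.g. "gemma", "llama")
    # to the matching decoder-layer class name, so both transformers and
    # optimum-tpu model classes are accepted.
    class_names = {
        "gemma": "GemmaDecoderLayer",
        "llama": "LlamaDecoderLayer",
    }
    cls_to_wrap = class_names[model.config.model_type]
    return {
        "fsdp": "full_shard",
        "fsdp_config": get_fsdp_config(cls_to_wrap),
    }
```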

examples/language-modeling/llama_tuning.md (review thread resolved)
tengomucho (Collaborator) left a comment:

Nice, with your change it looks fine. 🤗
One last question: I was wondering about using the AutoModelForCausalLM class from transformers rather than from optimum.tpu.
As I mentioned in a comment, in Optimum TPU models we set the dtype to bfloat16, because TPUs support it and it uses less memory, e.g. here.
I haven't tested the scripts with float32, but if you do and they work fine, please let me know so we can just merge this.

wenxindongwork (Contributor, PR author)

Just updated the examples to load the models in bf16 instead, hope that works!

HuggingFaceDocBuilderDev (bot)

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

tengomucho (Collaborator)

The code style workflow is failing; can you run it locally (make style) and push again so we can merge this, please?

wenxindongwork (Contributor, PR author)

Just installed ruff and ran make style. Thanks!

tengomucho merged commit d4e2294 into huggingface:main on Aug 22, 2024.
2 checks passed