
feat: Support for gradient_checkpointing for Sentence Transformers training #5030

Closed
wants to merge 11 commits (update-st-training → main, base later changed to v1.x)

Conversation

sjrl
Contributor

@sjrl sjrl commented May 26, 2023

Related Issues

  • fixes N/A

Proposed Changes:

  • Added support for gradient_checkpointing for Sentence Transformers training. This can greatly reduce GPU memory usage, allowing much larger batch sizes at a modest cost in training time. This is particularly worthwhile for Multiple Negatives Ranking Loss (MNRL) training in Sentence Transformers, since larger batch sizes (upwards of 128) have been shown to significantly boost retrieval metrics when using MNRL.
  • More details on gradient checkpointing can be found here
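The mechanism behind this feature can be sketched independently of Haystack. Gradient checkpointing discards intermediate activations during the forward pass and recomputes them during backward, trading extra compute for lower peak memory. The module below is illustrative (the class name and flag are assumptions, not Haystack's actual API); it uses `torch.utils.checkpoint`, which is also what Hugging Face models delegate to when `gradient_checkpointing_enable()` is called.

```python
# Minimal sketch of gradient checkpointing with torch.utils.checkpoint.
# CheckpointedMLP and use_checkpointing are illustrative names, not part of
# Haystack or sentence-transformers.
import torch
from torch.utils.checkpoint import checkpoint


class CheckpointedMLP(torch.nn.Module):
    def __init__(self, dim=32, use_checkpointing=True):
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Linear(dim, dim),
            torch.nn.ReLU(),
            torch.nn.Linear(dim, dim),
        )
        self.use_checkpointing = use_checkpointing

    def forward(self, x):
        if self.use_checkpointing and x.requires_grad:
            # Activations inside self.block are not stored; they are
            # recomputed during the backward pass to save memory.
            return checkpoint(self.block, x, use_reentrant=False)
        return self.block(x)


model = CheckpointedMLP()
x = torch.randn(8, 32, requires_grad=True)
loss = model(x).sum()
loss.backward()  # gradients flow through the recomputed block as usual
```

Gradients are numerically identical with and without checkpointing; only memory/compute trade-offs differ, which is why larger MNRL batches fit on the same GPU.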

How did you test it?

  • Added two new integration tests to test SentenceTransformer retriever training

Notes for the reviewer

Checklist

  • I have read the contributors guidelines and the code of conduct
  • I have updated the related issue with new insights and changes
  • I added tests that demonstrate the correct behavior of the change
  • I've used one of the conventional commit types for my PR title: fix:, feat:, build:, chore:, ci:, docs:, style:, refactor:, perf:, test:.
  • I documented my code
  • I ran pre-commit hooks and fixed any issue

@sjrl sjrl requested a review from a team as a code owner May 26, 2023 08:51
@sjrl sjrl requested review from ZanSara and removed request for a team May 26, 2023 08:51
@coveralls
Collaborator

coveralls commented May 26, 2023

Pull Request Test Coverage Report for Build 5256397223

  • 0 of 0 changed or added relevant lines in 0 files are covered.
  • 133 unchanged lines in 4 files lost coverage.
  • Overall coverage increased (+0.07%) to 42.0%

Files with coverage reduction                  New missed lines      %
nodes/prompt/invocation_layer/cohere.py               4           75.61%
nodes/prompt/invocation_layer/hugging_face.py         7           87.2%
nodes/retriever/dense.py                             55           25.98%
nodes/retriever/_embedding_encoder.py                67           36.47%

Totals: change from base Build 5253128824: 0.07%
Covered Lines: 9421
Relevant Lines: 22431

💛 - Coveralls

Contributor

@ZanSara ZanSara left a comment


A bit puzzled by one bit of code specifically, let's talk about it before moving forward.

In addition, please next time open an issue before the PR, so we can have a discussion on the issue about the feature's design, and on the PR about the implementation details. I'm curious where the demand for this specific feature comes from 🙂

Review comment threads (all outdated, resolved) on:

  • haystack/nodes/retriever/_embedding_encoder.py
  • haystack/nodes/retriever/dense.py (two threads)
  • test/nodes/test_retriever.py (two threads)
@sjrl sjrl force-pushed the update-st-training branch from c553e33 to 72156d1 Compare June 6, 2023 14:05
@sjrl sjrl requested a review from ZanSara June 6, 2023 14:46
Contributor

@ZanSara ZanSara left a comment


One more small thing to fix and this is ready to merge

Review comment thread (outdated, resolved) on: test/nodes/test_retriever.py
@sjrl sjrl requested a review from ZanSara June 7, 2023 10:54
Comment on lines +655 to +659
retriever = EmbeddingRetriever(
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
model_format="sentence_transformers",
use_gpu=False,
)
Contributor


I'm really sorry I didn't notice it earlier 🙈 but is this initialization loading the model? If yes, can we mock the model as well so that we don't actually download and load the weights from hf?

If we can do that we'll be able to add the @pytest.mark.unit marker to this test. Otherwise it won't run in CI.
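The mocking the reviewer asks for can be sketched in a self-contained way. In the real test, the patch target would be wherever `SentenceTransformer` is imported inside Haystack (an assumption; the exact import path is not verified here). The stand-in module below only demonstrates the pattern: patch the loader at its point of use so the constructor never downloads weights, which is what makes a `@pytest.mark.unit` marker viable.

```python
# Self-contained demonstration of mocking a model loader with unittest.mock.
# fake_encoder, load_model, and FakeRetriever are stand-ins for the real
# Haystack module, loader, and EmbeddingRetriever.
import sys
import types
from unittest.mock import MagicMock, patch

# Stand-in for the module that would normally import SentenceTransformer.
encoder_mod = types.ModuleType("fake_encoder")


def load_model(name):
    raise RuntimeError("would download weights from the Hugging Face Hub")


encoder_mod.load_model = load_model
sys.modules["fake_encoder"] = encoder_mod


class FakeRetriever:
    def __init__(self, embedding_model):
        import fake_encoder

        # Without mocking, this call would hit the network.
        self.model = fake_encoder.load_model(embedding_model)


def build_without_download():
    # Patch at the point of use so __init__ never touches the network.
    with patch("fake_encoder.load_model", return_value=MagicMock()) as mocked:
        retriever = FakeRetriever("sentence-transformers/all-MiniLM-L6-v2")
    mocked.assert_called_once_with("sentence-transformers/all-MiniLM-L6-v2")
    return retriever
```

With the loader patched, the test exercises only the retriever's wiring, so it can run in CI as a unit test without GPU, network, or model weights.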

@masci masci changed the base branch from main to v1.x November 24, 2023 12:04
@masci masci added the 1.x label Nov 24, 2023
@ZanSara
Contributor

ZanSara commented Nov 29, 2023

As agreed with @sjrl we can close this for now and re-implement it in v2 later, if still relevant

@sjrl sjrl closed this Dec 11, 2023
@sjrl sjrl deleted the update-st-training branch June 3, 2024 08:38
4 participants