Discussion: Update dataloader to skip rows that don't require training #2344

Open · felipemello1 opened this issue Feb 5, 2025 · 4 comments
Labels: best practice, discussion, triage review

felipemello1 (Contributor) commented Feb 5, 2025

#2341

When a) train_on_input=False and b) the message is so long that the output is truncated, a batch may contain no trainable tokens, which raises an error in the loss because of division by zero.

Beyond surfacing an inconvenient bug, this wastes compute, and patching the loss would treat the symptom rather than the root cause.

Should the dataloader skip rows that don't have any trainable tokens?
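A minimal repro of the failure mode, assuming prompt/truncated tokens are masked with the usual -100 ignore index and the loss is normalized by the number of unmasked tokens:

import torch
import torch.nn.functional as F

IGNORE_IDX = -100  # assumed ignore index for non-trainable tokens

logits = torch.randn(1, 8, 32)           # (batch, seq_len, vocab)
labels = torch.full((1, 8), IGNORE_IDX)  # every token masked: nothing to train on

loss_sum = F.cross_entropy(
	logits.view(-1, 32), labels.view(-1),
	ignore_index=IGNORE_IDX, reduction="sum",
)
num_tokens = (labels != IGNORE_IDX).sum()
print(loss_sum / num_tokens)  # tensor(nan): 0 / 0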

RdoubleA (Contributor) commented Feb 5, 2025

It's difficult to know if this is the case from the DataLoader's perspective. You would have to catch this and raise/skip in the recipe, or add extra preprocessing to your dataset to prevent it; we don't have the ability to skip an item within the SFTDataset's __getitem__.

felipemello1 (Contributor, Author) commented Feb 5, 2025

I see. It could be added to the recipe:

for batch in dataloader:
	if not check_batch_requires_grad(batch):
		continue  # skip batches with no trainable tokens
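A hypothetical check_batch_requires_grad (not an existing torchtune helper), assuming labels mask non-trainable tokens with the usual -100 ignore index, could be as simple as:

CROSS_ENTROPY_IGNORE_IDX = -100  # assumed masking convention

def check_batch_requires_grad(batch: dict) -> bool:
	# Trainable iff at least one label survives the masking/truncation.
	return bool((batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).any())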

I am sure @ebsmothers will love the idea!

EugenHotaj (Contributor) commented

Maybe slightly orthogonal, but it would also be great to expose a way to do left truncation instead of right truncation. I think this becomes even more important once the "mask all previous turns" strategy is introduced.

(I'm actually a bit surprised to see that torchtune supports right truncation only, since I thought left truncation was the standard practice for llama models.)
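For reference, a minimal sketch of the two truncation strategies on a list of token ids (truncate_left being the behavior requested here):

def truncate_right(tokens: list[int], max_len: int) -> list[int]:
	# Keep the beginning of the sequence, drop the end.
	return tokens[:max_len]

def truncate_left(tokens: list[int], max_len: int) -> list[int]:
	# Drop the beginning, keep the end (preserves the most recent turns).
	return tokens[-max_len:] if max_len > 0 else []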

felipemello1 added the discussion, best practice, and triage review labels on Feb 5, 2025
pocca2048 commented

> You would have to catch this and raise/skip in the recipe

I agree. In fact, I implemented the skipping logic in the recipe without fixing the loss.

felipemello1 self-assigned this Feb 7, 2025