Questions about your LoRA codes #162
Comments
Your thought is absolutely right. But the multiple A's don't mean multiple A's across devices. It's just "A for query", "A for key" and "A for value", since our implementation of query_key_value is a single linear. I think the A of q, k, v should be independent, although most implementations don't take this into account (even peft doesn't).
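A minimal illustrative sketch of that idea in plain PyTorch (class name, sizes and hyperparameters are made up here; this is not the lora2.py code): one fused query_key_value Linear, with a separate LoRA (A, B) pair for each of query, key and value.

```python
import torch
import torch.nn as nn

class FusedQKVLoRA(nn.Module):
    """Illustrative only: a fused qkv Linear plus an independent LoRA pair
    per projection. Not the SAT implementation."""
    def __init__(self, hidden, r=8, alpha=16):
        super().__init__()
        self.qkv = nn.Linear(hidden, 3 * hidden, bias=False)   # fused q/k/v
        self.scale = alpha / r
        # Three independent A's (and B's), one per projection.
        self.lora_A = nn.ParameterList(
            [nn.Parameter(torch.randn(r, hidden) * 0.01) for _ in range(3)])
        self.lora_B = nn.ParameterList(
            [nn.Parameter(torch.zeros(hidden, r)) for _ in range(3)])

    def forward(self, x):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        outs = []
        for i, base in enumerate((q, k, v)):
            delta = (x @ self.lora_A[i].t()) @ self.lora_B[i].t() * self.scale
            outs.append(base + delta)
        return torch.cat(outs, dim=-1)

x = torch.randn(2, 5, 64)
print(FusedQKVLoRA(64)(x).shape)   # torch.Size([2, 5, 192])
```

The point is only that the three A's are independent of each other, not that A is duplicated across model-parallel devices.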
Yeah, you're right! So you used a
If I want to use LoRA with model-parallel training in your sat lib, maybe I need to modify the code of the LoraLinear class (or even more) to ensure that all model-parallel processes (which divide a whole linear layer into several ColumnParallelLinear or RowParallelLinear layers) use only one LoRA A & B matrix for each of q/k/v. Otherwise, I could also divide the LoRA A & B matrices among all model-parallel processes, but that scheme may have a potential problem: the value of LoRA's hyperparameter r may not be evenly divisible by, or may even be smaller than, the number of processes in a model-parallel group.
No, you don't need to modify anything. This LoraMixin supports the model-parallel setting. Just build your model with model parallelism, and then add the mixin:
model = YourModule()  # This should be an SAT model
model.add_mixin("lora", LoraMixin(xxx))
But after reading your code carefully again, I still have a question.
The above are all my questions and related thoughts. I still don't know whether what I said is right or wrong; if wrong, where is the flaw in my reasoning...
Good question. This is why this line of code contains a
For each process, B differs because it should be different. You can see it as split across devices, as indicated by
For A, it's the same across devices, because of the existence of
Now that the initialization of A is the same (which is handled in SAT as shown below) and the gradient is the same all the time, A stays the same during training.
SwissArmyTransformer/sat/training/deepspeed_training.py, lines 183 to 198 (aa1277e)
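A tiny single-process simulation of that argument (all names, sizes and the loss are made up; the backward all-reduce is emulated here by summing the per-rank gradients of A, which is its net effect): start every "rank" with the same A and a different shard of B, apply the summed gradient on each rank, and A stays bit-identical across ranks while the B shards diverge.

```python
import torch

torch.manual_seed(0)
p, d_in, d_out, r, lr = 2, 8, 8, 4, 0.1   # 2 simulated model-parallel ranks (sizes made up)
x = torch.randn(3, d_in)                  # the input is replicated on every rank

# Every rank starts from the SAME A (same seed / broadcast at init) ...
A0 = torch.randn(r, d_in)
A = [A0.clone().requires_grad_() for _ in range(p)]
# ... but owns a DIFFERENT shard of B (column-parallel split of the output dim).
# Real LoRA zero-inits B; random init here just produces non-zero grads for A.
B = [torch.randn(d_out // p, r).requires_grad_() for _ in range(p)]

for step in range(3):
    # Rank i's slice of the LoRA update: (x A_i^T) B_i^T
    shards = [(x @ A[i].t()) @ B[i].t() for i in range(p)]
    loss = torch.cat(shards, dim=-1).pow(2).mean()    # any scalar loss will do
    loss.backward()

    # Net effect of the all-reduce in the backward pass: every rank sees the
    # SUM of all ranks' gradients for A.
    gA = sum(a.grad for a in A)
    with torch.no_grad():
        for i in range(p):
            A[i] -= lr * gA          # identical update on every rank
            B[i] -= lr * B[i].grad   # B shards evolve independently
        for t in A + B:
            t.grad = None

assert all(torch.allclose(A[0], A[i]) for i in range(p))   # A stayed identical
assert not torch.allclose(B[0], B[1])                      # B shards differ
print("A replicas stayed identical across simulated ranks")
```

This is the same invariant the broadcast in deepspeed_training.py (lines 183 to 198) establishes at initialization: if A starts equal on every rank and every rank applies an equal update, it remains equal.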
Thanks a lot!!! I've understood it thoroughly. Taking LoRA matrix A as an example, it takes the same input
Cool.
Original issue:

I read your LoRA code in sat/model/finetune/lora2.py carefully, but I really have some questions about the LoRA code when using Model Parallel to train/test. In ColumnParallelLinear, the original weight matrix of the linear layer (matrix W) is partitioned in the following manner: W = [W_1, ..., W_p].
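For reference, a toy PyTorch illustration of that partition (sizes made up): the weight is split along the output dimension, and concatenating the per-rank outputs reproduces the full linear.

```python
import torch

torch.manual_seed(0)
p, d_in, d_out = 2, 6, 8
x = torch.randn(3, d_in)
W = torch.randn(d_out, d_in)              # full weight of the linear layer

# ColumnParallelLinear-style split: W = [W_1, ..., W_p] along the output dim.
shards = W.chunk(p, dim=0)
full = x @ W.t()
gathered = torch.cat([x @ Wi.t() for Wi in shards], dim=-1)
print(torch.allclose(full, gathered))     # True
```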
My thought is:
Although the original weight matrix W is partitioned across the model-parallel process group, there should be only one LoRA matrix A for the original weight.
If my thought is right, there is a conflict between my thought and your LoRA code implementation: in the code at line 101, you use the partitioned weight of matrix W to create the LoRA matrix A. If the original weight matrix W is partitioned/divided into n parts, there are also n different LoRA matrices A, each of which is located in a different model-parallel process. What's more, the n LoRA matrices A on different model-parallel processes may have completely different values. The same applies to the LoRA matrix B.
SwissArmyTransformer/sat/model/finetune/lora2.py, line 101 (aa1277e)
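One detail worth spelling out (a sketch of the per-rank shapes, assuming the usual ΔW = B·A factorization and a 4-way split; the numbers are made up and this is not a quote of lora2.py): building the LoRA pair next to a ColumnParallelLinear shard of shape (d_out/p, d_in) gives every rank an A of the same full shape, while only B is actually sharded.

```python
import torch.nn as nn

p, d_in, d_out, r = 4, 1024, 3072, 8      # 4-way model parallel; sizes are made up

# LoRA built next to a ColumnParallelLinear shard of shape (d_out/p, d_in):
A_per_rank = nn.Linear(d_in, r, bias=False)        # weight (r, d_in): same full shape on every rank
B_per_rank = nn.Linear(r, d_out // p, bias=False)  # weight (d_out/p, r): a different shard per rank
print(A_per_rank.weight.shape, B_per_rank.weight.shape)
# torch.Size([8, 1024]) torch.Size([768, 8])
```

So creating A from the partitioned weight does not make A rank-dependent in shape; whether its values stay in sync across ranks is exactly the initialization-and-gradient argument given in the reply above.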
After the multiplication between input x and LoRA matrix A, you apply the copy_to_model_parallel_region function on the multiplication result. This function uses an all_reduce collective operation on the gradient during the backward pass. Since the LoRA matrix A in every model-parallel process is different from the others, i.e. the output of the multiplication between input x and LoRA matrix A is different, can we directly use the all_reduce during the backward pass?
SwissArmyTransformer/sat/model/finetune/lora2.py, line 131 (aa1277e)
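For reference, a minimal sketch of what a Megatron-style copy-to-model-parallel-region primitive does (this assumes an initialized torch.distributed process group and is a sketch, not the SAT source): identity in the forward pass, all-reduce of the incoming gradient in the backward pass.

```python
import torch
import torch.distributed as dist

class CopyToModelParallelRegion(torch.autograd.Function):
    """Sketch of the primitive the question refers to: forward is the identity,
    backward sums (all-reduces) the incoming gradient over the model-parallel
    group. Assumes dist.init_process_group(...) has been called and `group`
    is the model-parallel group."""

    @staticmethod
    def forward(ctx, x, group=None):
        ctx.group = group
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.group)
        return grad, None

# Applied to (x @ A.T): the forward value is unchanged, but in backward every
# rank receives the SUM of all ranks' gradients for that tensor, even though
# each rank's B shard (and hence its local output slice) is different.
```

Per the reply above, A is in fact identical on every rank (same broadcast initialization, same all-reduced gradient), which is what makes this use of all_reduce sound.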
Looking forward to your reply!