Correct LoRA weights merging
Correction of the merging code between the model's original layer weights and the LoRA model weights.
This respects the LoRA principle of discarding the adapter layers once we no longer plan to train them further, and, more importantly, it allows us to save and load the model as a ".keras" file.
BastienHot committed Mar 9, 2024
1 parent faf8ec1 commit eff5786
Showing 1 changed file with 4 additions and 0 deletions.
@@ -587,6 +587,10 @@ def call(self, inputs):
B_weights = value_lora_layer.B.kernel # (1, 12, 64) (b, c, d)
increment_weights = tf.einsum("ab,bcd->acd", A_weights, B_weights) * (ALPHA / RANK)
value_lora_layer.original_layer.kernel.assign_add(increment_weights)

# Put back in place the original layers with updated weights
self_attention_layer._query_dense = query_lora_layer.original_layer
self_attention_layer._value_dense = value_lora_layer.original_layer

"""
We are now all set to generate text with our LoRA model :).
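For reference, the merge performed in this diff follows the standard LoRA update W' = W + (ALPHA / RANK) * A @ B: once the low-rank increment is folded into the original kernel, the LoRA wrappers can be dropped and the plain layers restored. The snippet below is a minimal sketch of the same idea for a single Dense layer; the layer sizes and the ALPHA / RANK values are illustrative assumptions, not the repository's actual configuration.

import tensorflow as tf
from tensorflow import keras

ALPHA, RANK = 32.0, 4  # assumed LoRA hyperparameters (illustrative only)

# Frozen original layer and its trained LoRA factors (random / zero here for demo).
dense = keras.layers.Dense(64)
dense.build((None, 128))                   # kernel shape: (128, 64)
A = tf.random.normal((128, RANK))          # LoRA down-projection
B = tf.zeros((RANK, 64))                   # LoRA up-projection (zero-initialised)

# Fold the low-rank update into the frozen kernel: W' = W + (ALPHA / RANK) * A @ B.
# After this, the LoRA wrapper is no longer needed and the plain Dense layer
# can be serialised as part of a ".keras" file.
dense.kernel.assign_add(tf.matmul(A, B) * (ALPHA / RANK))

Merging in place like this keeps inference identical to the base model (no extra matmul per call) and avoids having to register custom LoRA layer classes just to serialise the model.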
