Merge pull request #16 from EricLBuehler/manual_lora_weight
Add manual (global) LoRA weight
EricLBuehler authored Feb 12, 2024
2 parents 9c78fe6 + aec2115 commit 45900f4
Showing 5 changed files with 88 additions and 29 deletions.
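
At its core, the change introduces a single scalar, `global_scaling_weight`, that multiplies the output of every LoRA adapter in addition to the per-token scalings predicted by the X-LoRA classifier. A minimal sketch of one adapter's contribution after this change (all variable names below are illustrative, not taken from the diff):

```python
import torch
import torch.nn as nn

# Illustrative sketch only: one LoRA adapter's contribution to a linear layer,
# with the new global weight applied on top of the per-token X-LoRA scalings.
hidden_size, rank = 16, 4
x = torch.randn(2, 8, hidden_size)             # (batch_size, seq_len, hidden_size)
base_output = torch.randn(2, 8, hidden_size)   # output of the frozen base layer
lora_A = nn.Linear(hidden_size, rank, bias=False)
lora_B = nn.Linear(rank, hidden_size, bias=False)
dropout = nn.Dropout(p=0.0)

xlora_scaling = torch.rand(2, 8, 1)   # per-token weight for this adapter from the classifier
lora_scaling = 2.0                    # the adapter's own lora_alpha / r
global_weight = 2.0                   # the new config.global_scaling_weight (defaults to 1.0)

x_mod = x * xlora_scaling             # conceptually what apply_scalings_to_x does
result = base_output + lora_B(lora_A(dropout(x_mod))) * lora_scaling * global_weight
```
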
34 changes: 25 additions & 9 deletions README.md
@@ -39,6 +39,7 @@ Excerpt from [this](./examples/simple.ipynb) example.
- [Trainable parameters](README.md#trainable-parameters)
- [Setting trainability of adapters dynamically](README.md#setting-trainability-of-adapters-dynamically)
- [Setting and resetting the scaling pass value](README.md#setting-and-resetting-the-scaling-pass-value)
- [Setting and getting the global LoRA weight](README.md#setting-and-getting-the-global-lora-weight)

### Converting a model
```python
@@ -212,6 +213,17 @@ model.set_scaling_pass_value(0)
model.set_scaling_pass_value(None)
```

### Setting and getting the global LoRA weight
```python
model: xLoRAModel = ... # Load the model

# Multiply the output of each LoRA adapter by 2, in addition to the scalings.
model.set_global_scaling_weight(2)

# Returns 2
res = model.get_global_scaling_weight()
```

## API
The X-LoRA API is composed of 3 parts: the "Global API", the "Model API", and the "Utility API". Generally, the Global API is used to create X-LoRA models, the Model API is used to interface with the models, and the Utility API provides useful utility functions.

@@ -222,20 +234,21 @@ The X-LoRA API is composed of 3 parts: the "Global API", the "Model API" and the
- `xlora.xlora_utils.load_scalings_log`
- `xlora.xlora_utils.load_model`
- [Model API](README.md#model-api): `xLoraModel.*`
- [Scalings](README.md#scalings-1)
- [Scalings](README.md#scalings)
- `xLoraModel.disable_scalings_logging`
- `xLoraModel.enable_scalings_logging`
- `xLoraModel.flush_log_scalings`
- `xLoraModel.get_scalings_log`
- `xLoraModel.set_scaling_pass_value`
- `xLoraModel.get_latest_scalings`
- `xLoraModel.set_global_scaling_weight`
- `xLoraModel.get_global_scaling_weight`
- [Trainable parameters](README.md#trainable-parameters-1)
- `xLoraModel.get_nb_trainable_parameters`
- `xLoraModel.print_trainable_parameters`
- [Trainable adapters](README.md#setting-the-trainable-adapters)
- `xLoraModel.set_use_trainable_adapters`
- `xLoraModel.get_use_trainable_adapters`
- [Scalings](README.md#scalings-1)
- `xLoraModel.set_scaling_pass_value`
- `xLoraModel.get_latest_scalings`

### X-LoRA Config
The X-LoRA Config saves the full configuration of an X-LoRA model.
@@ -301,6 +314,14 @@ Args:
- `xLoraModel.get_scalings_log(self) -> List[Tensor]`
- Returns a shallow (only copying the list itself not the tensors) copy of the list containing the scalings log. Editing the list does not change the underlying log.
The tensors are of shape (batch_size, seq_len, n_layers, n_classes). The seq_len dim may vary with input dimension.
- `xLoraModel.set_scaling_pass_value(self, value: Union[Number, None])`
- Manually set the scalings to a specific value during the scaling pass, forever. Call this function with None to enable the default scalings. This is reflected in the config.
- `xLoraModel.get_latest_scalings(self) -> Optional[Tensor]`
- Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape (batch_size, seq_len, n_layers, n_classes).
- `xLoraModel.set_global_scaling_weight(self, weight: float)`
- Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is reflected in the config.
- `xLoraModel.get_global_scaling_weight(self) -> float`
- Get the global LoRA weight.
#### Trainable parameters
- `xLoraModel.get_nb_trainable_parameters() -> Tuple[int, int]`
- Return a tuple `(num_trainable, num_all_params)`
@@ -311,11 +332,6 @@ Args:
- Set the trainability of the adapters. This is reflected in the config.
- `xLoraModel.get_use_trainable_adapters(self) -> bool`
- Get the trainable or not trainable state of the adapters.
#### Scalings
- `xLoraModel.set_scaling_pass_value(self, value: Union[Number, None])`
- Manually set the scalings to a specific value during the scaling pass, forever. Call this function with None to enable the default scalings. This is reflected in the config.
- `xLoraModel.get_latest_scalings(self) -> Optional[Tensor]`
- Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape (batch_size, seq_len, n_layers, n_classes).
## Original paper and citation
28 changes: 27 additions & 1 deletion examples/simple.ipynb
@@ -108,6 +108,26 @@
"model_created.set_scaling_pass_value(None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting and getting the global LoRA weight"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Multiply the output of each LoRA adapter by 2, additionally to the scalings.\n",
"model_created.set_global_scaling_weight(2)\n",
"\n",
"# Returns 2\n",
"res = model_created.get_global_scaling_weight()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -236,8 +256,14 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch-3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.12.1"
}
},
"nbformat": 4,
10 changes: 5 additions & 5 deletions src/xlora/xlora.py
@@ -32,7 +32,7 @@ def __new__(cls):
def convert_layers_to_xlora(
base: PeftModel,
verbose: bool,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> int:
"""
Returns the number of swapped layers.
@@ -53,7 +53,7 @@ def convert_layers_to_xlora(
target=module,
target_forward=module.forward,
layer_number=total_swapped,
top_k_lora=top_k_lora,
config=config,
)
module.forward = new_layer.forward # type: ignore[method-assign]
total_swapped += 1
@@ -64,7 +64,7 @@ def convert_layers_to_xlora(
target=module,
target_forward=module.forward,
layer_number=total_swapped,
top_k_lora=top_k_lora,
config=config,
)
module.forward = new_layer.forward # type: ignore[method-assign]
total_swapped += 1
@@ -75,7 +75,7 @@ def convert_layers_to_xlora(
target=module,
target_forward=module.forward,
layer_number=total_swapped,
top_k_lora=top_k_lora,
config=config,
)
module.forward = new_layer.forward # type: ignore[method-assign]
total_swapped += 1
@@ -169,7 +169,7 @@ def hook(module, *args, **kwargs) -> None:
total_swapped = convert_layers_to_xlora(
model_peft,
verbose,
xlora_config.top_k_lora,
xlora_config,
)

n_classes = len(adapters)
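
The change to `convert_layers_to_xlora` above replaces the single `top_k_lora` argument with the whole `xLoRAConfig`, so the swapped layers can read `top_k_lora`, the new `global_scaling_weight`, and any future field from one place. A simplified sketch of the pattern (stand-in classes only; the real constructors take more arguments):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SketchConfig:                      # stand-in for xLoRAConfig
    top_k_lora: Optional[int] = None
    global_scaling_weight: float = 1.0

class SketchLayer:                       # stand-in for an xLoRALayer wrapper
    def __init__(self, layer_number: int, config: SketchConfig) -> None:
        self.layer_number = layer_number
        self.config = config             # keep a reference; new fields need no new parameters

    def effective_scaling(self, lora_scaling: float) -> float:
        # the adapter's own scaling, multiplied by the global weight
        return lora_scaling * self.config.global_scaling_weight

layer = SketchLayer(layer_number=0, config=SketchConfig(global_scaling_weight=2.0))
print(layer.effective_scaling(0.5))      # 1.0
```
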
3 changes: 3 additions & 0 deletions src/xlora/xlora_config.py
@@ -41,6 +41,8 @@ class xLoRAConfig:
Make the adapters trainable.
scaling_pass_value (`float`, *optional*, defaults to 0):
Scaling pass value.
global_scaling_weight (`float`, *optional*, defaults to 1):
Weight to multiply output of each LoRA adapter by.
"""

model_type = "xlora"
@@ -60,6 +62,7 @@ class xLoRAConfig:
softmax_temperature: float = 1.0
top_k_lora: Optional[int] = None
scaling_pass_value: float = 0.0
global_scaling_weight: float = 1.0

def __post_init__(self):
if self.enable_softmax_topk and self.top_k_lora is None:
42 changes: 28 additions & 14 deletions src/xlora/xlora_insertion.py
@@ -20,22 +20,22 @@ class xLoRALayer:
xLoRA algorithm.
"""

__slots__ = {"model", "target_forward", "target", "layer_number", "disabled", "top_k_lora"}
__slots__ = {"model", "target_forward", "target", "layer_number", "disabled", "config"}

def __init__(
self,
model: PeftModel,
target: lora.LoraLayer,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
self.model = model
self.target_forward = target_forward
self.target = target
self.layer_number = layer_number
self.disabled = False # TODO(EricLBuehler): Pending removal following analysis
self.top_k_lora = top_k_lora
self.config = config

@staticmethod
def apply_scalings_to_x(x: torch.Tensor, scalings_layer: torch.Tensor, adapter: int) -> torch.Tensor:
@@ -48,8 +48,8 @@ def get_maybe_topk_scalings(self) -> torch.Tensor:
# xlora_scalings = [batch_size, seq_len, n_classes]
xlora_scalings: Tensor = self.model.internal_xlora_scalings[:, :, self.layer_number, :] # type: ignore

if self.top_k_lora is not None:
_, topk_indices = torch.topk(xlora_scalings, k=self.top_k_lora, dim=1)
if self.config.top_k_lora is not None:
_, topk_indices = torch.topk(xlora_scalings, k=self.config.top_k_lora, dim=1)

# Mask the topk to True, the rest to False
mask = torch.zeros_like(xlora_scalings, dtype=torch.bool)
@@ -73,9 +73,9 @@ def __init__(
target: lora.Linear,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
super().__init__(model, target, target_forward, layer_number, top_k_lora)
super().__init__(model, target, target_forward, layer_number, config)

def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""
@@ -104,7 +104,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
scaling = self.target.scaling[active_adapter]
x = x.to(lora_A.weight.dtype) # type: ignore
x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n)
result += lora_B(lora_A(dropout(x_mod))) * scaling
result += lora_B(lora_A(dropout(x_mod))) * scaling * self.config.global_scaling_weight

result = result.to(previous_dtype)
return result
@@ -117,9 +117,9 @@ def __init__(
target: lora.Embedding,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
super().__init__(model, target, target_forward, layer_number, top_k_lora)
super().__init__(model, target, target_forward, layer_number, config)

def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""
@@ -146,7 +146,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
scaling = self.target.scaling[active_adapter]
x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n)
after_A = self.target._embed(x_mod, embedding_A) # type: ignore
result += (after_A @ embedding_B) * scaling
result += (after_A @ embedding_B) * scaling * self.config.global_scaling_weight

return result

@@ -158,9 +158,9 @@ def __init__(
target: lora.Conv2d,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
super().__init__(model, target, target_forward, layer_number, top_k_lora)
super().__init__(model, target, target_forward, layer_number, config)

def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""
@@ -188,7 +188,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
scaling = self.target.scaling[active_adapter]
x = x.to(lora_A.weight.dtype) # type: ignore
x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n)
result += lora_B(lora_A(dropout(x_mod))) * scaling
result += lora_B(lora_A(dropout(x_mod))) * scaling * self.config.global_scaling_weight

result = result.to(previous_dtype)
return result
@@ -228,6 +228,20 @@ def generate(self, *args, **kwargs):
param.requires_grad = False
return res

def set_global_scaling_weight(self, weight: float):
"""
Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is reflected in the config.
"""
classifier: xLoRAClassifier = self.model.internal_xlora_classifier # type: ignore
classifier.config.global_scaling_weight = weight

def get_global_scaling_weight(self) -> float:
"""
Get the global LoRA weight.
"""
classifier: xLoRAClassifier = self.model.internal_xlora_classifier # type: ignore
return classifier.config.global_scaling_weight

def get_latest_scalings(self) -> Optional[Tensor]:
"""
Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape (batch_size, seq_len, n_layers, n_classes).
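
Since `set_global_scaling_weight` and `get_global_scaling_weight` read and write the classifier's `xLoRAConfig`, which appears to be the same config instance the swapped layers consult in their forward passes (both receive `xlora_config`), a change takes effect on the next forward pass without re-converting the model. A short usage sketch (assuming `model` is an already-converted X-LoRA model; the value 2.0 is arbitrary):

```python
# Sketch: adjust the global LoRA weight at inference time.
model.set_global_scaling_weight(2.0)             # every adapter's output is now multiplied by 2.0
assert model.get_global_scaling_weight() == 2.0

# ... run generation or a forward pass here; the new weight is applied ...

model.set_global_scaling_weight(1.0)             # restore the default behaviour
```
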
