Add manual (global) LoRA weight #16

Merged · 6 commits · Feb 12, 2024
README.md: 25 additions & 9 deletions
@@ -39,6 +39,7 @@ Excerpt from [this](./examples/simple.ipynb) example.
- [Trainable parameters](README.md#trainable-parameters)
- [Setting trainability of adapters dynamically](README.md#setting-trainability-of-adapters-dynamically)
- [Setting and resetting the scaling pass value](README.md#setting-and-resetting-the-scaling-pass-value)
- [Setting and getting the global LoRA weight](README.md#setting-and-getting-the-global-lora-weight)

### Converting a model
```python
@@ -212,6 +213,17 @@ model.set_scaling_pass_value(0)
model.set_scaling_pass_value(None)
```

### Setting and getting the global LoRA weight
```python
model: xLoRAModel = ... # Load the model

# Multiply the output of each LoRA adapter by 2, in addition to the scalings.
model.set_global_scaling_weight(2)

# Returns 2
res = model.get_global_scaling_weight()
```
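The weight is applied multiplicatively inside every swapped LoRA layer, on top of the per-adapter LoRA scaling and the learned per-token scalings (see the `xlora_insertion.py` changes below). A toy sketch of the arithmetic; the tensor names here are illustrative, not the library's internals:

```python
import torch

# Toy illustration only (not the library's implementation): how the global
# scaling weight combines with a per-adapter LoRA scaling inside one layer.
base_out = torch.randn(2, 4)      # output of the frozen base layer
adapter_out = torch.randn(2, 4)   # stands in for lora_B(lora_A(dropout(x_mod)))
scaling = 0.5                     # per-adapter LoRA scaling (typically lora_alpha / r)
global_scaling_weight = 2.0       # the knob added in this PR, default 1.0

result = base_out + adapter_out * scaling * global_scaling_weight
print(result.shape)  # torch.Size([2, 4])
```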

## API
The X-LoRA API is composed of 3 parts: the "Global API", the "Model API", and the "Utility API". Generally, the Global API is used to create X-LoRA models, the Model API is used to interface with the models, and the Utility API provides utility functions.

@@ -222,20 +234,21 @@ The X-LoRA API is composed of 3 parts: the "Global API", the "Model API" and the
- `xlora.xlora_utils.load_scalings_log`
- `xlora.xlora_utils.load_model`
- [Model API](README.md#model-api): `xLoraModel.*`
- [Scalings](README.md#scalings-1)
- [Scalings](README.md#scalings)
- `xLoraModel.disable_scalings_logging`
- `xLoraModel.enable_scalings_logging`
- `xLoraModel.flush_log_scalings`
- `xLoraModel.get_scalings_log`
- `xLoraModel.set_scaling_pass_value`
- `xLoraModel.get_latest_scalings`
- `xLoraModel.set_global_scaling_weight`
- `xLoraModel.get_global_scaling_weight`
- [Trainable parameters](README.md#trainable-parameters-1)
- `xLoraModel.get_nb_trainable_parameters`
- `xLoraModel.print_trainable_parameters`
- [Trainable adapters](README.md#setting-the-trainable-adapters)
- `xLoraModel.set_use_trainable_adapters`
- `xLoraModel.get_use_trainable_adapters`
- [Scalings](README.md#scalings-1)
- `xLoraModel.set_scaling_pass_value`
- `xLoraModel.get_latest_scalings`

### X-LoRA Config
The X-LoRA Config saves the full configuration of an X-LoRA model.
@@ -301,6 +314,14 @@ Args:
- `xLoraModel.get_scalings_log(self) -> List[Tensor]`
- Returns a shallow copy (only copying the list itself, not the tensors) of the list containing the scalings log. Editing the list does not change the underlying log.
The tensors are of shape (batch_size, seq_len, n_layers, n_classes). The seq_len dim may vary with the input length.
- `xLoraModel.set_scaling_pass_value(self, value: Union[Number, None])`
- Manually set the scalings to a specific value during the scaling pass, forever. Call this function with None to enable the default scalings. This is reflected in the config.
- `xLoraModel.get_latest_scalings(self) -> Optional[Tensor]`
- Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape (batch_size, seq_len, n_layers, n_classes).
- `xLoraModel.set_global_scaling_weight(self, weight: float)`
- Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is reflected in the config.
- `xLoraModel.get_global_scaling_weight(self) -> float`
- Get the global LoRA weight.
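A short usage sketch tying these scalings methods together; the `model` placeholder follows the loading pattern used earlier in this README, and the generation step in between is elided:

```python
model: xLoRAModel = ...  # Load the X-LoRA model as shown earlier

model.enable_scalings_logging()        # start recording scalings predictions
model.set_global_scaling_weight(1.5)   # multiply every adapter's output by 1.5
# ... run generation or a forward pass here ...
latest = model.get_latest_scalings()   # (batch_size, seq_len, n_layers, n_classes), or None
log = model.get_scalings_log()         # shallow copy of the logged scalings tensors
model.set_scaling_pass_value(None)     # restore the default scaling-pass behavior
```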
#### Trainable parameters
- `xLoraModel.get_nb_trainable_parameters() -> Tuple[int, int]`
- Return a tuple `(num_trainable, num_all_params)`
@@ -311,11 +332,6 @@ Args:
- Set the trainability of the adapters. This is reflected in the config.
- `xLoraModel.get_use_trainable_adapters(self) -> bool`
- Get whether the adapters are trainable.
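A brief sketch of the trainable-parameter helpers, reusing the `model` placeholder from above (the boolean argument to `set_use_trainable_adapters` is assumed here, since its full signature is not shown in this excerpt):

```python
num_trainable, num_all = model.get_nb_trainable_parameters()
print(f"{num_trainable} / {num_all} parameters are trainable")

model.print_trainable_parameters()        # prints the same summary directly

model.set_use_trainable_adapters(True)    # make the adapters trainable (reflected in the config)
assert model.get_use_trainable_adapters()
```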
#### Scalings
- `xLoraModel.set_scaling_pass_value(self, value: Union[Number, None])`
- Manually set the scalings to a specific value during the scaling pass, forever. Call this function with None to enable the default scalings. This is reflected in the config.
- `xLoraModel.get_latest_scalings(self) -> Optional[Tensor]`
- Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape (batch_size, seq_len, n_layers, n_classes).

## Original paper and citation

examples/simple.ipynb: 27 additions & 1 deletion
@@ -108,6 +108,26 @@
"model_created.set_scaling_pass_value(None)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting and getting the global LoRA weight"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Multiply the output of each LoRA adapter by 2, additionally to the scalings.\n",
"model_created.set_global_scaling_weight(2)\n",
"\n",
"# Returns 2\n",
"res = model_created.get_global_scaling_weight()"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -236,8 +256,14 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch-3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"name": "python",
"version": "3.12.1"
}
},
"nbformat": 4,
src/xlora/xlora.py: 5 additions & 5 deletions
@@ -32,7 +32,7 @@ def __new__(cls):
def convert_layers_to_xlora(
base: PeftModel,
verbose: bool,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> int:
"""
Returns the number of swapped layers.
@@ -53,7 +53,7 @@ def convert_layers_to_xlora(
target=module,
target_forward=module.forward,
layer_number=total_swapped,
top_k_lora=top_k_lora,
config=config,
)
module.forward = new_layer.forward # type: ignore[method-assign]
total_swapped += 1
@@ -64,7 +64,7 @@ def convert_layers_to_xlora(
target=module,
target_forward=module.forward,
layer_number=total_swapped,
top_k_lora=top_k_lora,
config=config,
)
module.forward = new_layer.forward # type: ignore[method-assign]
total_swapped += 1
@@ -75,7 +75,7 @@ def convert_layers_to_xlora(
target=module,
target_forward=module.forward,
layer_number=total_swapped,
top_k_lora=top_k_lora,
config=config,
)
module.forward = new_layer.forward # type: ignore[method-assign]
total_swapped += 1
@@ -169,7 +169,7 @@ def hook(module, *args, **kwargs) -> None:
total_swapped = convert_layers_to_xlora(
model_peft,
verbose,
xlora_config.top_k_lora,
xlora_config,
)

n_classes = len(adapters)
src/xlora/xlora_config.py: 3 additions & 0 deletions
@@ -41,6 +41,8 @@ class xLoRAConfig:
Make the adapters trainable.
scaling_pass_value (`float`, *optional*, defaults to 0):
Scaling pass value.
global_scaling_weight (`float`, *optional*, defaults to 1):
Weight to multiply the output of each LoRA adapter by.
"""

model_type = "xlora"
@@ -60,6 +62,7 @@ class xLoRAConfig:
softmax_temperature: float = 1.0
top_k_lora: Optional[int] = None
scaling_pass_value: float = 0.0
global_scaling_weight: float = 1.0

def __post_init__(self):
if self.enable_softmax_topk and self.top_k_lora is None:
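The new field defaults to 1.0, i.e. no extra scaling. It is normally driven through the Model API added in `xlora_insertion.py` below, which writes the value onto the classifier's config rather than requiring the config to be edited directly. A brief sketch, reusing the `model` placeholder from the README:

```python
model.set_global_scaling_weight(2.0)              # stored as config.global_scaling_weight
assert model.get_global_scaling_weight() == 2.0   # read back from the classifier's config
```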
src/xlora/xlora_insertion.py: 28 additions & 14 deletions
@@ -20,22 +20,22 @@ class xLoRALayer:
xLoRA algorithm.
"""

__slots__ = {"model", "target_forward", "target", "layer_number", "disabled", "top_k_lora"}
__slots__ = {"model", "target_forward", "target", "layer_number", "disabled", "config"}

def __init__(
self,
model: PeftModel,
target: lora.LoraLayer,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
self.model = model
self.target_forward = target_forward
self.target = target
self.layer_number = layer_number
self.disabled = False # TODO(EricLBuehler): Pending removal following analysis
self.top_k_lora = top_k_lora
self.config = config

@staticmethod
def apply_scalings_to_x(x: torch.Tensor, scalings_layer: torch.Tensor, adapter: int) -> torch.Tensor:
@@ -48,8 +48,8 @@ def get_maybe_topk_scalings(self) -> torch.Tensor:
# xlora_scalings = [batch_size, seq_len, n_classes]
xlora_scalings: Tensor = self.model.internal_xlora_scalings[:, :, self.layer_number, :] # type: ignore

if self.top_k_lora is not None:
_, topk_indices = torch.topk(xlora_scalings, k=self.top_k_lora, dim=1)
if self.config.top_k_lora is not None:
_, topk_indices = torch.topk(xlora_scalings, k=self.config.top_k_lora, dim=1)

# Mask the topk to True, the rest to False
mask = torch.zeros_like(xlora_scalings, dtype=torch.bool)
@@ -73,9 +73,9 @@ def __init__(
target: lora.Linear,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
super().__init__(model, target, target_forward, layer_number, top_k_lora)
super().__init__(model, target, target_forward, layer_number, config)

def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""
@@ -104,7 +104,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
scaling = self.target.scaling[active_adapter]
x = x.to(lora_A.weight.dtype) # type: ignore
x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n)
result += lora_B(lora_A(dropout(x_mod))) * scaling
result += lora_B(lora_A(dropout(x_mod))) * scaling * self.config.global_scaling_weight

result = result.to(previous_dtype)
return result
@@ -117,9 +117,9 @@ def __init__(
target: lora.Embedding,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
super().__init__(model, target, target_forward, layer_number, top_k_lora)
super().__init__(model, target, target_forward, layer_number, config)

def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""
@@ -146,7 +146,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
scaling = self.target.scaling[active_adapter]
x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n)
after_A = self.target._embed(x_mod, embedding_A) # type: ignore
result += (after_A @ embedding_B) * scaling
result += (after_A @ embedding_B) * scaling * self.config.global_scaling_weight

return result

@@ -158,9 +158,9 @@ def __init__(
target: lora.Conv2d,
target_forward: Callable[..., Any],
layer_number: int,
top_k_lora: Optional[int],
config: xLoRAConfig,
) -> None:
super().__init__(model, target, target_forward, layer_number, top_k_lora)
super().__init__(model, target, target_forward, layer_number, config)

def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
"""
@@ -188,7 +188,7 @@ def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
scaling = self.target.scaling[active_adapter]
x = x.to(lora_A.weight.dtype) # type: ignore
x_mod = self.apply_scalings_to_x(x, xlora_scalings, adapter_n)
result += lora_B(lora_A(dropout(x_mod))) * scaling
result += lora_B(lora_A(dropout(x_mod))) * scaling * self.config.global_scaling_weight

result = result.to(previous_dtype)
return result
@@ -228,6 +228,20 @@ def generate(self, *args, **kwargs):
param.requires_grad = False
return res

def set_global_scaling_weight(self, weight: float):
"""
Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is reflected in the config.
"""
classifier: xLoRAClassifier = self.model.internal_xlora_classifier # type: ignore
classifier.config.global_scaling_weight = weight

def get_global_scaling_weight(self) -> float:
"""
Get the global LoRA weight.
"""
classifier: xLoRAClassifier = self.model.internal_xlora_classifier # type: ignore
return classifier.config.global_scaling_weight

def get_latest_scalings(self) -> Optional[Tensor]:
"""
Returns the latest scalings prediction, or None if no scalings have been predicted. The tensor is of shape (batch_size, seq_len, n_layers, n_classes).