Merge pull request #17 from EricLBuehler/set_topk
Add set topk
EricLBuehler authored Feb 12, 2024
2 parents 1f7abed + ce71700 commit 07c4b00
Showing 4 changed files with 54 additions and 0 deletions.
14 changes: 14 additions & 0 deletions README.md
@@ -224,6 +224,15 @@ model.set_global_scaling_weight(2)
res = model.get_global_scaling_weight()
```

### Setting and getting the top-k LoRA value
```python
# Use the top 2 LoRA experts
model_created.set_topk_lora(2)

# Returns 2
res = model_created.get_topk_lora()
```

## API
The X-LoRA API is composed of 3 parts: the "Global API", the "Model API", and the "Utility API". Generally, the Global API is used to create X-LoRA models, the Model API is used to interface with those models, and the Utility API provides supporting utility functions.
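As a minimal sketch of how the Model API is used (only methods documented in this diff, called on a model that has already been created through the Global API and named `model_created` as in the examples below):
```python
# Model API calls are plain methods on the created X-LoRA model.
model_created.set_global_scaling_weight(2)          # scale each adapter's output by 2
model_created.set_topk_lora(2)                      # route through only the top 2 LoRA experts
print(model_created.get_topk_lora())                # 2
print(model_created.get_use_trainable_adapters())   # current adapter trainability state
```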

@@ -332,6 +341,11 @@ Args:
- Set the trainability of the adapters. This is reflected in the config.
- `xLoraModel.get_use_trainable_adapters(self) -> bool`
- Get whether the adapters are currently trainable.
#### Top-k
- `xLoraModel.set_topk_lora(self, value: Optional[int])`
- Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense. This is reflected in the config (see the usage sketch after this list).
- `xLoraModel.get_topk_lora(self) -> Optional[int]`
- Get the current top_k LoRA experts value.
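A brief usage sketch of the two Top-k methods above, including the `None` (dense) case; `model_created` is the model name used in the earlier examples:
```python
# Sparsely route through only the 2 highest-scoring LoRA experts.
model_created.set_topk_lora(2)
assert model_created.get_topk_lora() == 2

# Restore the default dense mixing over all experts.
model_created.set_topk_lora(None)
assert model_created.get_topk_lora() is None
```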
## Original paper and citation
20 changes: 20 additions & 0 deletions examples/simple.ipynb
@@ -180,6 +180,26 @@
"model_created.print_trainable_parameters()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setting and getting the top-k lora value"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Use the top 2 lora experts\n",
"model_created.set_topk_lora(2)\n",
"\n",
"# Returns 2\n",
"res = model_created.get_topk_lora()"
]
},
{
"cell_type": "markdown",
"metadata": {},
6 changes: 6 additions & 0 deletions src/xlora/xlora.py
@@ -216,6 +216,12 @@ def hook(module, *args, **kwargs) -> None:
assert not hasattr(model_peft, "get_global_scaling_weight")
model_peft.get_global_scaling_weight = peft_model_wrapper.get_global_scaling_weight # type: ignore

assert not hasattr(model_peft, "set_topk_lora")
model_peft.set_topk_lora = peft_model_wrapper.set_topk_lora # type: ignore

assert not hasattr(model_peft, "get_topk_lora")
model_peft.get_topk_lora = peft_model_wrapper.get_topk_lora # type: ignore

model_peft.get_nb_trainable_parameters = peft_model_wrapper.get_nb_trainable_parameters # type: ignore

model_peft.print_trainable_parameters = peft_model_wrapper.print_trainable_parameters # type: ignore
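The hunk above exposes the wrapper's new accessors on the PeftModel instance by first asserting the attribute names are free and then attaching the bound methods. A standalone sketch of that assert-then-attach pattern, using hypothetical `_Wrapper` and `_Model` classes rather than the library's own types:
```python
from typing import Optional


class _Wrapper:
    """Hypothetical holder of the real implementations."""

    def __init__(self) -> None:
        self._top_k: Optional[int] = None

    def set_topk_lora(self, value: Optional[int]) -> None:
        self._top_k = value

    def get_topk_lora(self) -> Optional[int]:
        return self._top_k


class _Model:
    """Hypothetical model object that gains the wrapper's methods."""


wrapper = _Wrapper()
model = _Model()

# Refuse to shadow an existing attribute, then forward the bound methods.
assert not hasattr(model, "set_topk_lora")
model.set_topk_lora = wrapper.set_topk_lora

assert not hasattr(model, "get_topk_lora")
model.get_topk_lora = wrapper.get_topk_lora

model.set_topk_lora(2)
assert model.get_topk_lora() == 2
```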
14 changes: 14 additions & 0 deletions src/xlora/xlora_insertion.py
@@ -228,6 +228,20 @@ def generate(self, *args, **kwargs):
param.requires_grad = False
return res

def set_topk_lora(self, value: Optional[int]):
"""
Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense. This is reflected in the config.
"""
classifier: xLoRAClassifier = self.model.internal_xlora_classifier # type: ignore
classifier.config.top_k_lora = value

def get_topk_lora(self) -> Optional[int]:
"""
Get the current top_k LoRA experts value.
"""
classifier: xLoRAClassifier = self.model.internal_xlora_classifier # type: ignore
return classifier.config.top_k_lora

def set_global_scaling_weight(self, weight: float):
"""
Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is by default 1. This is reflected in the config.
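The `set_topk_lora` docstring in the hunk above describes replacing dense mixing with selection of only the top-k experts. As an illustration of that selection step alone (not the library's internal forward pass), masking a hypothetical vector of per-expert scalings with `torch.topk` could look like:
```python
import torch

# Hypothetical per-expert scalings produced by a classifier for 4 LoRA experts.
scalings = torch.tensor([0.10, 0.45, 0.05, 0.40])
top_k = 2

# Keep only the top-k entries; zero out the rest (sparse selection).
_, indices = torch.topk(scalings, top_k, dim=-1)
mask = torch.zeros_like(scalings)
mask.scatter_(-1, indices, 1.0)
sparse_scalings = scalings * mask  # tensor([0.0000, 0.4500, 0.0000, 0.4000])
```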
