From 32f3b9201514e3d11e67d7b2bd9e8cc6b0b0912b Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Mon, 12 Feb 2024 06:32:34 -0500
Subject: [PATCH 1/5] Add getter, setter for topk lora

---
 src/xlora/xlora_insertion.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/xlora/xlora_insertion.py b/src/xlora/xlora_insertion.py
index dfca700..0cc5578 100644
--- a/src/xlora/xlora_insertion.py
+++ b/src/xlora/xlora_insertion.py
@@ -228,6 +228,20 @@ def generate(self, *args, **kwargs):
             param.requires_grad = False
         return res
 
+    def set_topk_lora(self, value: Optional[int]):
+        """
+        Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense.
+        """
+        classifier: xLoRAClassifier = self.model.internal_xlora_classifier  # type: ignore
+        classifier.config.top_k_lora = value
+
+    def get_topk_lora(self) -> Optional[int]:
+        """
+        Get the current top_k LoRA experts value.
+        """
+        classifier: xLoRAClassifier = self.model.internal_xlora_classifier  # type: ignore
+        return classifier.config.top_k_lora
+
     def set_global_scaling_weight(self, weight: float):
         """
         Set the global LoRA weight, a scalar to multiply the output of each LoRA adapter by. This is reflected in the config.

From efa8df41361e37c5beba8c2cfec9994635f5ec5e Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Mon, 12 Feb 2024 06:35:04 -0500
Subject: [PATCH 2/5] Add to docs

---
 README.md | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/README.md b/README.md
index 68a72dd..ea92c28 100644
--- a/README.md
+++ b/README.md
@@ -332,6 +332,11 @@ Args:
   - Set the trainability of the adapters. This is reflected in the config.
 - `xLoraModel.get_use_trainable_adapters(self) -> bool`
   - Get the trainable or not trainable state of the adapters.
+#### Top-k
+- `xLoraModel.set_topk_lora(self, value: Optional[int])`
+  - Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense.
+- `xLoraModel.get_topk_lora(self) -> Optional[int]`
+  - Get the current top_k LoRA experts value.
 
 ## Original paper and citation

From d0bbdeb51b688f272fbec7ae085f4f07b1b8520d Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Mon, 12 Feb 2024 06:35:32 -0500
Subject: [PATCH 3/5] Make public

---
 src/xlora/xlora.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/xlora/xlora.py b/src/xlora/xlora.py
index 7686065..76d235d 100644
--- a/src/xlora/xlora.py
+++ b/src/xlora/xlora.py
@@ -216,6 +216,12 @@ def hook(module, *args, **kwargs) -> None:
     assert not hasattr(model_peft, "get_global_scaling_weight")
     model_peft.get_global_scaling_weight = peft_model_wrapper.get_global_scaling_weight  # type: ignore
 
+    assert not hasattr(model_peft, "set_topk_lora")
+    model_peft.set_topk_lora = peft_model_wrapper.set_topk_lora  # type: ignore
+
+    assert not hasattr(model_peft, "get_topk_lora")
+    model_peft.get_topk_lora = peft_model_wrapper.get_topk_lora  # type: ignore
+
     model_peft.get_nb_trainable_parameters = peft_model_wrapper.get_nb_trainable_parameters  # type: ignore
 
     model_peft.print_trainable_parameters = peft_model_wrapper.print_trainable_parameters  # type: ignore

From 6d9c520bea378508373d732dffd2acb613315e11 Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Mon, 12 Feb 2024 06:37:26 -0500
Subject: [PATCH 4/5] Add to examples

---
 README.md             |  9 +++++++++
 examples/simple.ipynb | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/README.md b/README.md
index ea92c28..f53fd03 100644
--- a/README.md
+++ b/README.md
@@ -224,6 +224,15 @@ model.set_global_scaling_weight(2)
 res = model.get_global_scaling_weight()
 ```
 
+### Setting and getting the top-k lora value
+```python
+# Use the top 2 lora experts
+model_created.set_topk_lora(2)
+
+# Returns 2
+res = model_created.get_topk_lora()
+```
+
 ## API
 
 The X-LoRA API is composed of 3 parts: the "Global API", the "Model API" and the "Utility API". Generally the global API is used to create X-LoRA models and the model API is used to interface with the models while the Utility API provides useful utility functions.
diff --git a/examples/simple.ipynb b/examples/simple.ipynb
index 006acbc..4ae3c72 100644
--- a/examples/simple.ipynb
+++ b/examples/simple.ipynb
@@ -180,6 +180,26 @@
     "model_created.print_trainable_parameters()"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Setting and getting the top-k lora value"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Use the top 2 lora experts\n",
+    "model_created.set_topk_lora(2)\n",
+    "\n",
+    "# Returns 2\n",
+    "res = model_created.get_topk_lora()"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {},

From ce717007bf823e33874dd51db97f49331dc7191f Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Mon, 12 Feb 2024 06:38:19 -0500
Subject: [PATCH 5/5] Update docs

---
 README.md                    | 2 +-
 src/xlora/xlora_insertion.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index f53fd03..c27fe8f 100644
--- a/README.md
+++ b/README.md
@@ -343,7 +343,7 @@ Args:
   - Set the trainability of the adapters. This is reflected in the config.
 - `xLoraModel.get_use_trainable_adapters(self) -> bool`
   - Get the trainable or not trainable state of the adapters.
 #### Top-k
 - `xLoraModel.set_topk_lora(self, value: Optional[int])`
-  - Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense.
+  - Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense. This is reflected in the config.
 - `xLoraModel.get_topk_lora(self) -> Optional[int]`
   - Get the current top_k LoRA experts value.
diff --git a/src/xlora/xlora_insertion.py b/src/xlora/xlora_insertion.py
index 0cc5578..7bfc8c0 100644
--- a/src/xlora/xlora_insertion.py
+++ b/src/xlora/xlora_insertion.py
@@ -230,7 +230,7 @@ def generate(self, *args, **kwargs):
 
     def set_topk_lora(self, value: Optional[int]):
         """
-        Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense.
+        Sparsely select the specified top_k LoRA experts instead of the default dense method. Set to None to use dense. This is reflected in the config.
         """
         classifier: xLoRAClassifier = self.model.internal_xlora_classifier  # type: ignore
         classifier.config.top_k_lora = value
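
For context on what top_k_lora controls: the X-LoRA classifier produces a scaling for every LoRA adapter, and a top-k setting keeps only the k largest scalings instead of using all of them (the dense default). The sketch below is illustrative only and is not xlora's internal implementation; the function name, tensor shape, and mask-only strategy (no renormalization) are assumptions.

from typing import Optional

import torch


def topk_sparsify_scalings(scalings: torch.Tensor, top_k: Optional[int]) -> torch.Tensor:
    """Keep only the top_k largest scalings along the last (adapter) dimension.

    top_k=None leaves the scalings dense, mirroring the default behaviour
    described in the patches above. Hypothetical helper, for illustration.
    """
    if top_k is None:
        return scalings  # dense: every adapter contributes
    # Indices of the top_k largest scalings per position.
    _, topk_indices = torch.topk(scalings, k=top_k, dim=-1)
    # Build a 0/1 mask and zero out everything that was not selected.
    mask = torch.zeros_like(scalings).scatter(-1, topk_indices, 1.0)
    return scalings * mask


# Toy example: 4 adapters, keep the top 2.
scalings = torch.tensor([[0.1, 0.5, 0.2, 0.15]])
print(topk_sparsify_scalings(scalings, top_k=2))
# -> tensor([[0.0000, 0.5000, 0.2000, 0.0000]])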