Update CPT documentation #2229
base: main
Conversation
… created _cpt_forward for readability, updated copyright to 2024, renamed class to CPTPromptInit, changed config variables to lowercase and list[int], removed exception catch from tests, added assertion docs, removed batch_size=1 test, and renamed test file to test_cpt.py.
…lization in config. Renamed cpt_prompt_tuning_init to cpt_prompt_init. Changed the class from PeftConfig to PromptLearningConfig. model: Removed check_config function. peft_model: Fixed bugs. tests: Added PeftTestConfigManagerForDecoderModels in test_decoder_models.py and testing_common.py.
…dded into _toctree.yml.
Thanks for quickly following up with the example. Overall, this is a very nice notebook. I have some smaller comments, please check them out.
A tip: You should be able to run ruff on the notebook if you want to auto-format it: https://docs.astral.sh/ruff/faq/#does-ruff-support-jupyter-notebooks.
A question I had: is the code for `CPTDataset`, `CPTDataCollatorForLanguageModeling` etc. specifically for this example, or is there some reference code? If the latter, it would make sense to put a link to the reference code.
In addition, I had a bit of trouble getting this notebook to run due to OOM errors after I increased the dataset size. I tried a few common steps to mitigate this but nothing worked:
- Smaller model: `bigscience/bloom-560m`
- Reduced `MAX_INPUT_LENGTH` to 64
- Set `per_device_eval_batch_size` to 1
- Quantized the model with bitsandbytes (both 8- and 4-bit)
Even with all those steps combined, I couldn't train with 24GB of memory. Somewhat surprisingly, I had to reduce the number of train samples to 50 for training to run. I'd think that with a batch size of 1, it shouldn't matter that much for memory if I have 50 samples or 500. Do you have any idea what could be going on here?
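One possible explanation (my assumption, not confirmed in this thread): CPT concatenates the training examples into the context before the input, so the effective sequence length grows with the number of train samples, and attention activation memory grows roughly quadratically with that length even at batch size 1. A back-of-the-envelope sketch, where every constant is an illustrative assumption rather than a measured value:

```python
# Rough estimate of attention-score activation memory at batch size 1.
# All constants here are illustrative assumptions, not measured values.

def attention_score_bytes(n_samples, tokens_per_example=64,
                          n_layers=24, n_heads=16, bytes_per_value=2):
    """Bytes for the (L x L) attention score matrices across all layers/heads."""
    seq_len = n_samples * tokens_per_example  # context = concatenated examples
    return n_layers * n_heads * seq_len * seq_len * bytes_per_value

gb = 1024 ** 3
small = attention_score_bytes(50)    # ~50 examples in the context
large = attention_score_bytes(500)   # ~500 examples in the context

print(f"50 samples:  {small / gb:.1f} GiB")
print(f"500 samples: {large / gb:.1f} GiB")
# 10x more samples -> ~100x more attention-score memory
```

If this is what is happening, it would explain why shrinking the batch size and quantizing the weights barely helped, while cutting the number of train samples did.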
@@ -90,4 +90,4 @@ In CPT, only specific context token embeddings are optimized, while the rest of
 To prevent overfitting and maintain stability, CPT uses controlled perturbations to limit the allowed changes to context embeddings within a defined range.
 Additionally, to address the phenomenon of recency bias—where examples near the end of the context tend to be prioritized over earlier ones—CPT applies a decay loss factor.

-Take a look at [Context-Aware Prompt Tuning for few-shot classification](../task_guides/cpt-few-shot-classification) for a step-by-step guide on how to train a model with CPT.
+Take a look at [Example](../../../examples/cpt_finetuning/README.md) for a step-by-step guide on how to train a model with CPT.
Hmm, I'm not sure if this link is going to work from the built docs. It's better if you link directly to the README, i.e. https://github.com/huggingface/peft/blob/main/examples/cpt_finetuning/README.md
(of course, the link won't point anywhere right now, but after merging it will be valid).
@@ -9,6 +9,8 @@ Unless required by applicable law or agreed to in writing, software distributed
 an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
 specific language governing permissions and limitations under the License.
+
+
Remove?
@@ -21,6 +23,9 @@ The abstract from the paper is:

 *Traditional fine-tuning is effective but computationally intensive, as it requires updating billions of parameters. CPT, inspired by ICL, PT, and adversarial attacks, refines context embeddings in a parameter-efficient manner. By optimizing context tokens and applying a controlled gradient descent, CPT achieves superior accuracy across various few-shot classification tasks, showing significant improvement over existing methods such as LoRA, PT, and ICL.*

+Take a look at [Example](../../../examples/cpt_finetuning/README.md) for a step-by-step guide on how to train a model with CPT.
Same argument about the link.
To overcome these challenges, we introduce Context-aware Prompt Tuning (CPT), a method inspired by ICL, Prompt Tuning (PT), and adversarial attacks.
CPT builds on the ICL strategy of concatenating examples before the input, extending it by incorporating PT-like learning to refine the context embedding through iterative optimization, extracting deeper insights from the training examples. Our approach carefully modifies specific context tokens, considering the unique structure of the examples within the context.

In addition to updating the context with PT-like optimization, CPT draws inspiration from adversarial attacks, adjusting the input based on the labels present in the context while preserving the inherent value of the user-provided data.
To ensure robustness and stability during optimization, we employ a projected gradient descent algorithm, constraining token embeddings to remain close to their original values and safeguarding the quality of the context.
Our method has demonstrated superior accuracy across multiple classification tasks using various LLM models, outperforming existing baselines and effectively addressing the overfitting challenge in few-shot learning.
In this section, you use a lot of "we" and "our". Let's try to word it in a more neutral way, as for the reader it could appear like "we" refers to the PEFT maintainers :) So use "The approach" instead of "Our approach" etc.
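A side note on the projected gradient descent mentioned in the excerpt above: the projection step can be sketched in isolation. This is an illustrative reconstruction of the idea, not PEFT's actual implementation — after each gradient update, each context-token embedding is clamped back into an epsilon-ball (here an L-infinity box) around its original value.

```python
# Illustrative sketch of the projection step in projected gradient descent:
# after an update, clamp each embedding coordinate back to within `eps`
# of its original value. Plain Python lists stand in for real tensors;
# this is NOT PEFT's implementation.

def project(updated, original, eps):
    """Clamp each coordinate of `updated` to within eps of `original`."""
    return [
        [max(o - eps, min(o + eps, u)) for u, o in zip(urow, orow)]
        for urow, orow in zip(updated, original)
    ]

original = [[0.10, -0.20], [0.30, 0.40]]   # frozen reference embeddings
updated  = [[0.50, -0.21], [0.31, -0.90]]  # embeddings after a gradient step
eps = 0.05

projected = project(updated, original, eps)
print(projected)  # coordinates that drifted too far are pulled back into the box
```

Small updates pass through unchanged; large ones are pulled back to the boundary, which is what "constraining token embeddings to remain close to their original values" amounts to.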
- Refer to **Section 3.1** of the paper, where template-based tokenization is described as a critical step in structuring inputs for CPT.

#### How it Helps
Templates provide context-aware structure, ensuring the model does not overfit by utilizing structured input-output formats. Using cpt_tokens_type_mask, we gain fine-grained information about the roles of different tokens in the input-output structure. This enables the model to:
-Templates provide context-aware structure, ensuring the model does not overfit by utilizing structured input-output formats. Using cpt_tokens_type_mask, we gain fine-grained information about the roles of different tokens in the input-output structure. This enables the model to:
+Templates provide context-aware structure, ensuring the model does not overfit by utilizing structured input-output formats. Using `cpt_tokens_type_mask`, we gain fine-grained information about the roles of different tokens in the input-output structure. This enables the model to:
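To illustrate what a token-type mask like `cpt_tokens_type_mask` might encode — the role codes below are my assumption for illustration, not PEFT's actual values — each token position is tagged by its role, so the loss and the projection can treat input tokens, separators, and output tokens differently:

```python
# Hypothetical sketch of a token-type mask for one templated few-shot
# example. The role codes (1=input, 2=separator, 3=output) are
# illustrative assumptions, not PEFT's actual encoding.

def build_type_mask(input_len, sep_len, output_len):
    """Tag each token position with its role in the template."""
    return [1] * input_len + [2] * sep_len + [3] * output_len

# e.g. "great movie" (2 tokens) + " -> " (1 token) + "positive" (1 token)
mask = build_type_mask(input_len=2, sep_len=1, output_len=1)
print(mask)  # [1, 1, 2, 3]
```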
"cell_type": "markdown", | ||
"source": [ | ||
"# CPT Training and Inference\n", | ||
"This notebook demonstrates the training and evaluation process of Context-Aware Prompt Tuning (CPT) using the Hugging Face Trainer.\n", |
It could be helpful to link the paper here.
"}\n", | ||
"\n", | ||
"# Initialize the dataset\n", | ||
"CPT_train_dataset = CPTDataset(train_dataset, tokenizer, templates)\n", |
Let's not capitalize here: `cpt_train_dataset`
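For illustration, a minimal self-contained stand-in for that initialization with the lowercase name applied. The `CPTDataset` internals here are a hypothetical sketch (the real class also tokenizes and builds token-type masks), only the naming and call shape match the notebook:

```python
# Hypothetical stand-in for the notebook's CPTDataset: it only applies
# a template to each raw example. The real class does more (tokenization,
# token-type masks); this sketch just shows the renamed variable.

class CPTDataset:
    def __init__(self, examples, templates):
        self.samples = [
            templates["example"].format(sentence=ex["sentence"],
                                        label=ex["label"])
            for ex in examples
        ]

    def __len__(self):
        return len(self.samples)

train_examples = [{"sentence": "great movie", "label": "positive"}]
templates = {"example": "input: {sentence} output: {label}"}

# lowercase variable name, as suggested above
cpt_train_dataset = CPTDataset(train_examples, templates)
print(cpt_train_dataset.samples[0])
```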
"source": [ | ||
"# Load a pre-trained causal language model\n", | ||
"base_model = AutoModelForCausalLM.from_pretrained(\n", | ||
" 'bigscience/bloom-1b7',\n", |
Instead of hard-coding the model id here, can we re-use the `tokenizer_name_or_path` variable? Of course, it should be renamed in this case, e.g. to `model_id`.
"train_dataset = dataset['train'].select(range(4)).map(add_string_labels)\n", | ||
"\n", | ||
"# Subset and process the validation dataset\n", | ||
"test_dataset = dataset['validation'].select(range(20)).map(add_string_labels)\n" |
I assume you chose small subsets to make the notebook run fast. But maybe a little bit more would also be okay? Also, let's add a sentence here that for proper testing, users should use the whole dataset. Maybe there can even be a toggle that users can enable to use the full datasets.
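One possible shape for that toggle (a sketch; `dataset` is the notebook's name, the flag and subset size are my additions, and a plain list stands in for the real HF dataset — with a real `datasets.Dataset`, `select(range(n))` would replace the slice):

```python
# Sketch of a user-facing toggle between a fast subset and the full
# dataset. A plain list stands in for the notebook's HF dataset.

USE_FULL_DATASET = False  # set True for proper evaluation
TRAIN_SUBSET_SIZE = 4

dataset = {"train": list(range(100))}  # stand-in for the real dataset

if USE_FULL_DATASET:
    train_split = dataset["train"]
else:
    train_split = dataset["train"][:TRAIN_SUBSET_SIZE]

print(len(train_split))  # 4 with the toggle off
```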
" trust_remote_code=True,\n", | ||
" local_files_only=False,\n", |
Can we delete these arguments?
Currently, the CPT model lacks a code example.
In this pull request, I provide an explanation and a code example to address this.
Thanks,
Tsachi