
Update CPT documentation #2229

Open · wants to merge 33 commits into main

Conversation

tsachiblau (Contributor):

Currently, the CPT model lacks a code example.
In this pull request, I provide an explanation and a code example to address this.

Thanks,
Tsachi

tsachiblau and others added 30 commits October 22, 2024 10:57
… created _cpt_forward for readability, updated copyright to 2024, renamed class to CPTPromptInit, changed config variables to lowercase and list[int], removed exception catch from tests, added assertion docs, removed batch_size=1 test, and renamed test file to test_cpt.py.
…lization in config. Renamed cpt_prompt_tuning_init to cpt_prompt_init. Changed the class from PeftConfig to PromptLearningConfig.

model: Removed check_config function.

peft_model: Fixed bugs.

tests: Added PeftTestConfigManagerForDecoderModels in test_decoder_models.py and testing_common.py.
Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks. (Powered by ReviewNB)


@BenjaminBossan (Member) left a comment:

Thanks for quickly following up with the example. Overall, this is a very nice notebook. I have some smaller comments; please check them out.

A tip: you should be able to run ruff on the notebook if you want to auto-format it: https://docs.astral.sh/ruff/faq/#does-ruff-support-jupyter-notebooks.

A question I had: is the code for `CPTDataset`, `CPTDataCollatorForLanguageModeling`, etc. written specifically for this example, or is there some reference code? If the latter, it would make sense to add a link to that reference code.

In addition, I had a bit of trouble getting this notebook to run due to OOM errors after I increased the dataset size. I tried a few common steps to mitigate this but nothing worked:

- Smaller model: `bigscience/bloom-560m`
- Reduced `MAX_INPUT_LENGTH` to 64
- Set `per_device_eval_batch_size` to 1
- Quantized the model with bitsandbytes (both 8-bit and 4-bit)

Even with all those steps combined, I couldn't train with 24GB of memory. Somewhat surprisingly, I had to reduce the number of train samples to 50 for training to run. I'd think that with a batch size of 1, it shouldn't matter that much for memory if I have 50 samples or 500. Do you have any idea what could be going on here?
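
For context, the kind of reduced setup described above would look roughly like the sketch below (illustrative only, not the notebook's actual code; it assumes `bitsandbytes` is installed, and the output directory name is made up):

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

# Smaller model + 4-bit quantization + batch size of 1, i.e. the mitigation steps listed above.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

base_model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m",
    quantization_config=bnb_config,
)

training_args = TrainingArguments(
    output_dir="cpt-oom-debug",  # hypothetical output directory
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
)
```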

@@ -90,4 +90,4 @@ In CPT, only specific context token embeddings are optimized, while the rest of
To prevent overfitting and maintain stability, CPT uses controlled perturbations to limit the allowed changes to context embeddings within a defined range.
Additionally, to address the phenomenon of recency bias—where examples near the end of the context tend to be prioritized over earlier ones—CPT applies a decay loss factor.

Take a look at [Context-Aware Prompt Tuning for few-shot classification](../task_guides/cpt-few-shot-classification) for a step-by-step guide on how to train a model with CPT.
Take a look at [Example](../../../examples/cpt_finetuning/README.md) for a step-by-step guide on how to train a model with CPT.

Hmm, I'm not sure if this link is going to work from the built docs. It's better if you link directly to the README, i.e. https://github.com/huggingface/peft/blob/main/examples/cpt_finetuning/README.md (of course, the link won't point anywhere right now, but after merging it will be valid).

@@ -9,6 +9,8 @@ Unless required by applicable law or agreed to in writing, software distributed
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.



Remove?

@@ -21,6 +23,9 @@ The abstract from the paper is:

*Traditional fine-tuning is effective but computationally intensive, as it requires updating billions of parameters. CPT, inspired by ICL, PT, and adversarial attacks, refines context embeddings in a parameter-efficient manner. By optimizing context tokens and applying a controlled gradient descent, CPT achieves superior accuracy across various few-shot classification tasks, showing significant improvement over existing methods such as LoRA, PT, and ICL.*

Take a look at [Example](../../../examples/cpt_finetuning/README.md) for a step-by-step guide on how to train a model with CPT.

Same argument about the link.

Comment on lines +6 to +11
To overcome these challenges, we introduce Context-aware Prompt Tuning (CPT), a method inspired by ICL, Prompt Tuning (PT), and adversarial attacks.
CPT builds on the ICL strategy of concatenating examples before the input, extending it by incorporating PT-like learning to refine the context embedding through iterative optimization, extracting deeper insights from the training examples. Our approach carefully modifies specific context tokens, considering the unique structure of the examples within the context.

In addition to updating the context with PT-like optimization, CPT draws inspiration from adversarial attacks, adjusting the input based on the labels present in the context while preserving the inherent value of the user-provided data.
To ensure robustness and stability during optimization, we employ a projected gradient descent algorithm, constraining token embeddings to remain close to their original values and safeguarding the quality of the context.
Our method has demonstrated superior accuracy across multiple classification tasks using various LLM models, outperforming existing baselines and effectively addressing the overfitting challenge in few-shot learning.

In this section, you use a lot of "we" and "our". Let's try to word it in a more neutral way, since to the reader it could appear that "we" refers to the PEFT maintainers :) So use "The approach" instead of "Our approach", etc.
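
As an aside, the projected gradient descent constraint described in the quoted passage can be pictured with a small sketch (purely illustrative, not PEFT's actual CPT implementation; the function name, shapes, and epsilon value are made up):

```python
import torch

def project_context_embeddings(updated, original, epsilon):
    """Keep updated context-token embeddings within an L-infinity ball of
    radius `epsilon` around their original values (illustrative only)."""
    delta = torch.clamp(updated - original, min=-epsilon, max=epsilon)
    return original + delta

# Toy usage after an optimizer step on the context embeddings:
original = torch.randn(10, 768)                   # frozen copy of the initial embeddings
updated = original + 0.1 * torch.randn(10, 768)   # embeddings after a gradient update
projected = project_context_embeddings(updated, original, epsilon=0.05)
```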

- Refer to **Section 3.1** of the paper, where template-based tokenization is described as a critical step in structuring inputs for CPT.

#### How it Helps
Templates provide context-aware structure, ensuring the model does not overfit by utilizing structured input-output formats. Using cpt_tokens_type_mask, we gain fine-grained information about the roles of different tokens in the input-output structure. This enables the model to:

Suggested change
Templates provide context-aware structure, ensuring the model does not overfit by utilizing structured input-output formats. Using cpt_tokens_type_mask, we gain fine-grained information about the roles of different tokens in the input-output structure. This enables the model to:
Templates provide context-aware structure, ensuring the model does not overfit by utilizing structured input-output formats. Using `cpt_tokens_type_mask`, we gain fine-grained information about the roles of different tokens in the input-output structure. This enables the model to:
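
To picture what such a token-type mask conveys, here is a tiny made-up example; the template, tokens, and integer codes are hypothetical and do not necessarily match PEFT's actual encoding:

```python
# Hypothetical illustration: each token in the templated context gets an
# integer code describing its role in the input-output structure.
tokens = ["input:", "great", "movie", "output:", "positive"]
cpt_tokens_type_mask = [1, 2, 2, 3, 4]
# 1 = input template, 2 = input text, 3 = output template, 4 = output label (codes invented for illustration)
```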

"cell_type": "markdown",
"source": [
"# CPT Training and Inference\n",
"This notebook demonstrates the training and evaluation process of Context-Aware Prompt Tuning (CPT) using the Hugging Face Trainer.\n",

It could be helpful to link the paper here.

"}\n",
"\n",
"# Initialize the dataset\n",
"CPT_train_dataset = CPTDataset(train_dataset, tokenizer, templates)\n",

Let's not capitalize here: `cpt_train_dataset`

"source": [
"# Load a pre-trained causal language model\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" 'bigscience/bloom-1b7',\n",

Instead of hard-coding the model id here, can we re-use the `tokenizer_name_or_path` variable? Of course, it should be renamed in this case, e.g. to `model_id`.
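
i.e. something along the lines of this sketch (names are suggestions, not the notebook's current code):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "bigscience/bloom-1b7"  # defined once, reused for both tokenizer and model

tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)
```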

Comment on lines +234 to +237
"train_dataset = dataset['train'].select(range(4)).map(add_string_labels)\n",
"\n",
"# Subset and process the validation dataset\n",
"test_dataset = dataset['validation'].select(range(20)).map(add_string_labels)\n"

I assume you chose small subsets to make the notebook run fast. But maybe a little bit more would also be okay? Also, let's add a sentence here that for proper testing, users should use the whole dataset. Maybe there can even be a toggle that users can enable to use the full datasets.
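
For example, a toggle roughly like this sketch (variable names are suggestions; `dataset` and `add_string_labels` are the objects already defined in the notebook):

```python
USE_FULL_DATASET = False  # set to True for a proper evaluation run

train_split = dataset["train"]
val_split = dataset["validation"]

if not USE_FULL_DATASET:
    # Small subsets keep the notebook fast; results will be noisy.
    train_split = train_split.select(range(4))
    val_split = val_split.select(range(20))

train_dataset = train_split.map(add_string_labels)
test_dataset = val_split.map(add_string_labels)
```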

Comment on lines +529 to +530
" trust_remote_code=True,\n",
" local_files_only=False,\n",

Can we delete these arguments?
