From 162d1b61afef373f3607b78b22c791ecc937df34 Mon Sep 17 00:00:00 2001
From: Immortalise
Date: Tue, 20 Feb 2024 20:21:17 -0800
Subject: [PATCH] [fix] bugs in prompt attack

---
 promptbench/prompt_attack/attack.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/promptbench/prompt_attack/attack.py b/promptbench/prompt_attack/attack.py
index 0ec9b08..3d1eb61 100644
--- a/promptbench/prompt_attack/attack.py
+++ b/promptbench/prompt_attack/attack.py
@@ -63,10 +63,22 @@ class Attack(object):
 
 
     def __init__(self, model, attack_name, dataset, prompt, eval_func, unmodifiable_words=None, verbose=True):
+        """
+        model: the model to attack
+        attack_name: the name of the attack, e.g. "textfooler", "textbugger", "deepwordbug", "bertattack", "checklist", "stresstest", "semantic"
+        dataset: the dataset for the prompt attack
+        prompt: the prompt to attack
+        eval_func: the evaluation function used to score a prompt. Its interface is eval_func(prompt, dataset, model): obtain the model's predictions under the prompt, check their correctness, and return the model's accuracy on that prompt.
+        unmodifiable_words: the words that are not allowed to be attacked
+        verbose: whether to print the attack process
+
+        return: None
+        """
         self.model = model
         self.attack_name = attack_name
         self.dataset = dataset
         self.prompt = prompt
+        self.eval_func = eval_func
         self.goal_function = AdvPromptGoalFunction(self.model,
                                                    self.dataset,
                                                    eval_func,
@@ -205,8 +217,7 @@ def attack(self):
         for language in prompts_dict.keys():
             prompts = prompts_dict[language]
             for prompt in prompts:
-                from ..utils import inference_total_dataset
-                acc = inference_total_dataset(prompt, self.model, self.dataset)
+                acc = self.eval_func(prompt, self.dataset, self.model)
                 results[prompt] = acc
         return results
 
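
For reference, below is a minimal sketch of an eval_func that satisfies the interface documented in the new docstring, i.e. eval_func(prompt, dataset, model) returning accuracy. The dataset field names ("content", "label"), the callable-model convention, and the label-matching logic are illustrative assumptions, not part of this patch; adapt them to the dataset and model wrapper actually in use.

    def eval_func(prompt, dataset, model):
        correct = 0
        for example in dataset:
            # Fill the prompt with the example text (assumed field name: "content").
            input_text = prompt + "\n" + example["content"]
            # Query the model; the model object is assumed to be callable and to
            # return its prediction as raw text.
            raw_pred = model(input_text)
            # Compare the raw output with the gold label (assumed field name:
            # "label"); the output-to-label mapping is task-specific.
            if raw_pred.strip().lower() == str(example["label"]).lower():
                correct += 1
        # Accuracy of the model under this prompt.
        return correct / len(dataset)

Such a function can then be passed to the constructor documented above, e.g. Attack(model, "textfooler", dataset, prompt, eval_func), and is reused in attack() via self.eval_func.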