
Commit

[fix] bugs in prompt attack
Immortalise committed Feb 21, 2024
1 parent fff2d1f commit 162d1b6
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions promptbench/prompt_attack/attack.py
@@ -63,10 +63,22 @@

 class Attack(object):
     def __init__(self, model, attack_name, dataset, prompt, eval_func, unmodifiable_words=None, verbose=True):
+        """
+        model: the model to attack
+        attack_name: the name of the attack, e.g., "textfooler", "textbugger", "deepwordbug", "bertattack", "checklist", "stresstest", "semantic"
+        dataset: the dataset used for the prompt attack
+        prompt: the prompt to attack
+        eval_func: the evaluation function used to score a prompt. Its interface is eval_func(prompt, dataset, model): it should obtain the model's predictions under the given prompt, check their correctness, and return the model's accuracy on that prompt.
+        unmodifiable_words: words that the attack is not allowed to modify
+        verbose: whether to print the attack process
+        return: None
+        """
         self.model = model
         self.attack_name = attack_name
         self.dataset = dataset
         self.prompt = prompt
+        self.eval_func = eval_func
         self.goal_function = AdvPromptGoalFunction(self.model,
                                                    self.dataset,
                                                    eval_func,
@@ -205,8 +217,7 @@ def attack(self):
             for language in prompts_dict.keys():
                 prompts = prompts_dict[language]
                 for prompt in prompts:
-                    from ..utils import inference_total_dataset
-                    acc = inference_total_dataset(prompt, self.model, self.dataset)
+                    acc = self.eval_func(prompt, self.dataset, self.model)
                     results[prompt] = acc
 
             return results
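Note on the change itself: the second hunk stops importing a hardcoded inference_total_dataset helper and instead routes the semantic-attack scoring through the user-supplied self.eval_func, matching the contract documented in the new docstring. As a rough illustration of that contract, here is a minimal sketch of an eval_func a caller might pass; the (text, label) dataset layout and the plain-callable model are assumptions for this sketch, not part of the commit or promptbench's API.

# Illustrative only: a minimal eval_func matching the documented
# interface eval_func(prompt, dataset, model). The (text, label)
# dataset layout and the callable model are assumptions here.
def eval_func(prompt, dataset, model):
    correct, total = 0, 0
    for text, label in dataset:
        pred = model(prompt + text)             # model prediction under the (attacked) prompt
        correct += int(pred.strip() == label)   # naive exact-match scoring
        total += 1
    return correct / max(total, 1)              # accuracy for this prompt

# Hypothetical usage with the Attack class from the diff:
# attack = Attack(model, "semantic", dataset, prompt, eval_func)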
