Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rationalization #117

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 69 additions & 31 deletions actions/explanation/rationalize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pandas as pd

from actions.prediction.predict import prediction_with_custom_input
from timeout import timeout
import json
import pandas as pd
Expand Down Expand Up @@ -58,9 +59,73 @@ def formalize_output(dataset_name, text):
return return_s


@timeout(60)
def get_few_shot_result(few_shot_str, text, intro, instruction, conversation, pred_str):
    """Generate a few-shot rationale for one instance and format it as HTML.

    The prompt is built by concatenating the few-shot examples, the instance
    text, the intro and the instruction (the last three each followed by a
    newline). It is tokenized with ``conversation.decoder.gpt_tokenizer``,
    passed to ``conversation.decoder.gpt_model.generate`` and decoded again.
    The explanation is the part of the decoded text that follows the first
    occurrence of ``instruction``.

    Args:
        few_shot_str: Few-shot examples prepended to the prompt.
        text: Text of the instance to explain; also echoed in the result.
        intro: Line stating the prediction, placed before the instruction.
        instruction: Line asking for the explanation; also used as the
            delimiter when extracting the explanation from the output.
        conversation: Object exposing ``decoder.gpt_tokenizer`` and
            ``decoder.gpt_model``.
        pred_str: Predicted label, echoed in the result.

    Returns:
        An HTML snippet containing the original text, the prediction and the
        generated explanation.
    """
    prompt = f"{few_shot_str}{text}\n{intro}\n{instruction}\n"
    print(f"[rationalize operation]\n=== PROMPT ===\n{prompt}")

    decoder = conversation.decoder
    tokenizer = decoder.gpt_tokenizer

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(device="cpu")

    # NOTE(review): if this is a Hugging Face ``generate`` call, temperature
    # and top_p presumably only take effect with do_sample=True - confirm
    # whether sampling was intended here.
    generation = decoder.gpt_model.generate(
        input_ids,
        max_length=2048,
        no_repeat_ngram_size=2,
        temperature=0.7,
        top_p=0.7,
    )
    decoded_generation = tokenizer.decode(generation[0], skip_special_tokens=True)

    # The decoded text may not contain the instruction verbatim (e.g. after
    # truncation or whitespace normalisation). Indexing [1] unconditionally
    # would raise IndexError in that case, and an empty instruction would make
    # str.split raise ValueError, so fall back to the full decoded text.
    if instruction and instruction in decoded_generation:
        explanation = decoded_generation.split(instruction)[1]
    else:
        explanation = decoded_generation

    return (
        f"<b>Original text:</b> {text}"
        f"<br><b>Prediction:</b> {pred_str}"
        f"<br><b>Explanation:</b> {explanation}"
    )


# @timeout(60)
def rationalize_operation(conversation, parse_text, i, simulation, data_path="./cache/", **kwargs):
# TODO: Custom input – if conversation.used and conversation.custom_input:
dataset_name = conversation.describe.get_dataset_name()

if dataset_name == "boolq":
gpt_rationales = "cache/boolq/GPT-3.5_rationales_BoolQ_val_400.csv"
elif dataset_name == "olid":
gpt_rationales = "cache/olid/GPT-4_rationales_OLID_val_132.csv"
elif dataset_name == "daily_dialog":
gpt_rationales = "cache/daily_dialog/GPT-4_rationales_DD_test_200.csv"
else:
raise NotImplementedError(f"Dataset {dataset_name} is not supported!")

if conversation.custom_input is not None and conversation.used is False:
if not conversation.decoder.gpt_parser_initialized:
return f"Rationalize operation not enabled for {conversation.decoder.parser_name}"

few_shot_str = get_few_shot_str(gpt_rationales)

res = prediction_with_custom_input(conversation)
df = pd.read_csv(f"./cache/{dataset_name}/{dataset_name}_custom_input.csv")
prediction = df["Prediction"][df['Prediction'].index.to_list()[-1]]

pred_str = conversation.class_names[prediction]

if dataset_name == "boolq":
intro = f"Answer: {pred_str}"
instruction = "Please explain the answer: "
elif dataset_name == "olid":
other_class_names = ", ".join(
[conversation.class_names[c] for c in conversation.class_names if conversation.class_names[c] not in [pred_str, "dummy"]])
intro = f"The dialogue act of this text has been classified as {pred_str} (over {other_class_names})."
instruction = "Please explain why: "
elif dataset_name == "daily_dialog":
intro = f"The tweet has been classified as {pred_str}."
instruction = "Please explain why: "

return_s = get_few_shot_result(few_shot_str, conversation.custom_input, intro, instruction, conversation, pred_str)
return return_s, 1

id_list = []
for item in parse_text:
Expand All @@ -70,13 +135,13 @@ def rationalize_operation(conversation, parse_text, i, simulation, data_path="./
except ValueError:
pass

dataset_name = conversation.describe.get_dataset_name()
dataset = conversation.temp_dataset.contents["X"]
model = conversation.get_var("model").contents

if len(conversation.temp_dataset.contents["X"]) == 0:
return "There are no instances that meet this description!", 0
results = get_results(dataset_name, data_path)

# Few-shot setting
few_shot = True

Expand All @@ -95,8 +160,6 @@ def rationalize_operation(conversation, parse_text, i, simulation, data_path="./
pred_str = label_dict[pred]
intro = f"Answer: {pred_str}"
instruction = "Please explain the answer: "
gpt_rationales = "cache/boolq/GPT-3.5_rationales_BoolQ_val_400.csv"

elif dataset_name == "daily_dialog":
text = "Dialog: '" + instance[0] + "'"
label_dict = conversation.class_names
Expand All @@ -105,16 +168,12 @@ def rationalize_operation(conversation, parse_text, i, simulation, data_path="./
[label_dict[c] for c in conversation.class_names if label_dict[c] not in [pred, "dummy"]])
intro = f"The dialogue act of this text has been classified as {pred_str} (over {other_class_names})."
instruction = "Please explain why: "
gpt_rationales = "cache/daily_dialog/GPT-4_rationales_DD_test_200.csv"

elif dataset_name == "olid":
text = "Tweet: '" + instance[0] + "'"
label_dict = {0: "non-offensive", 1: "offensive"}
pred_str = label_dict[pred]
intro = f"The tweet has been classified as {pred_str}."
instruction = "Please explain why: "
gpt_rationales = "cache/olid/GPT-4_rationales_OLID_val_132.csv"

else:
return f"Dataset {dataset_name} currently not supported by rationalize operation", 1

Expand All @@ -131,27 +190,6 @@ def rationalize_operation(conversation, parse_text, i, simulation, data_path="./
+ "<br><br><b>Prediction:</b> " + pred_str \
+ "<br><br><b>Explanation:</b> " + explanation
else:
prompt = f"{few_shot_str}" \
f"{text}\n" \
f"{intro}\n" \
f"{instruction}\n"
print(f"[rationalize operation]\n=== PROMPT ===\n{prompt}")

input_ids = conversation.decoder.gpt_tokenizer(prompt, return_tensors="pt").input_ids
input_ids = input_ids.to(device="cpu")
generation = conversation.decoder.gpt_model.generate(
input_ids,
max_length=2048,
no_repeat_ngram_size=2,
temperature=0.7,
top_p=0.7
)
decoded_generation = conversation.decoder.gpt_tokenizer.decode(generation[0], skip_special_tokens=True)

explanation = decoded_generation.split(instruction)[1]

return_s += "<b>Original text:</b> " + text \
+ "<br><b>Prediction:</b> " + pred_str \
+ "<br><b>Explanation:</b> " + explanation
return_s += get_few_shot_result(few_shot_str, text, intro, instruction, conversation, pred_str)

return return_s, 1
32 changes: 32 additions & 0 deletions prompts/explanation/custom_input_rationalization.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
User: explain this in natural language
Parsed: rationalize [E]

User: explain this with a rationale
Parsed: rationalize [E]

User: generate a natural language explanation for this data point
Parsed: rationalize [E]

User: rationalize the prediction for it
Parsed: rationalize [E]

User: provide an explanation for this custom input text in everyday language
Parsed: rationalize [E]

User: interpret this in plain language
Parsed: rationalize [E]

User: for this new data point, deliver a natural language explanation
Parsed: rationalize [E]

User: give a rationale for the input text
Parsed: rationalize [E]

User: can you explain the model behavior on this instance in natural language?
Parsed: rationalize [E]

User: could you apply reasoning to this new id?
Parsed: rationalize [E]

User: can you simplify this for me?
Parsed: rationalize [E]