Code and Demo for Redirection (#250)
Co-authored-by: seanzhangkx8 <[email protected]>
1 parent af8adcd · commit 39e84fa
Showing 16 changed files with 1,280 additions and 5 deletions.
@@ -0,0 +1,2 @@
from .likelihoodModel import *
from .redirection import *
@@ -0,0 +1,33 @@
from peft import LoraConfig
from transformers import BitsAndBytesConfig
import torch

DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

DEFAULT_LORA_CONFIG = LoraConfig(
    r=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

DEFAULT_TRAIN_CONFIG = {
    "output_dir": "checkpoints",
    "logging_dir": "logging",
    "logging_steps": 25,
    "eval_steps": 50,
    "num_train_epochs": 2,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "evaluation_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 50,
    "optim": "paged_adamw_8bit",
    "learning_rate": 2e-4,
    "max_seq_length": 512,
    "load_best_model_at_end": True,
}
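These module-level defaults feed the likelihood model added further down in this diff. As a minimal sketch of how a caller could override the training defaults instead of editing the module (the import path and the adjusted values below are assumptions for illustration, not shown in this commit):

from copy import deepcopy
# import path is an assumption; the scrape does not show the new files' paths
from convokit.redirection.config import DEFAULT_TRAIN_CONFIG

custom_train_config = deepcopy(DEFAULT_TRAIN_CONFIG)
custom_train_config["num_train_epochs"] = 1   # illustrative override
custom_train_config["learning_rate"] = 1e-4   # illustrative override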
@@ -0,0 +1,100 @@
from .preprocessing import default_speaker_prefixes


def default_previous_context_selector(convo):
    """
    Default function to compute previous contexts for Redirection. For
    actual contexts, uses the current utterance and the immediately preceding
    utterance from the speaker with the other role. For reference contexts,
    uses the previous utterance from the same-role speaker, instead of the
    current utterance, as the point of reference.
    :param convo: ConvoKit Conversation object to compute contexts over
    :return: Tuple of actual contexts and reference contexts
    """
    actual_contexts = {}
    reference_contexts = {}
    utts = [utt for utt in convo.iter_utterances()]
    roles = list({utt.meta["role"] for utt in utts})
    assert len(roles) == 2
    spk_prefixes = default_speaker_prefixes(roles)
    role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))}
    role_1 = roles[0]
    role_2 = roles[1]
    prev_spk = None
    prev_1, prev_2, cur_1, cur_2 = None, None, None, None
    for i, utt in enumerate(utts):
        utt_text = utt.text
        cur_spk = utt.meta["role"]
        if prev_spk is not None and cur_spk != prev_spk:
            if role_2 in cur_spk:
                prev_1 = cur_1
            else:
                prev_2 = cur_2

            if prev_1 is not None and prev_2 is not None:
                if role_2 in cur_spk:
                    prev = prev_1
                    prev_prev = prev_2
                else:
                    prev = prev_2
                    prev_prev = prev_1

                prev_prev_text, prev_prev_role = prev_prev
                prev_text, prev_role = prev

                prev_prev_data = role_to_prefix[prev_prev_role] + prev_prev_text
                prev_data = role_to_prefix[prev_role] + prev_text
                cur_data = role_to_prefix[cur_spk] + utt_text

                actual_contexts[utt.id] = [prev_data, cur_data]
                reference_contexts[utt.id] = [prev_data, prev_prev_data]

        if role_1 in cur_spk:
            cur_1 = (utt_text, cur_spk)
        if role_2 in cur_spk:
            cur_2 = (utt_text, cur_spk)

        prev_spk = cur_spk

    return actual_contexts, reference_contexts


def default_future_context_selector(convo):
    """
    Default function to compute future contexts for Redirection. Uses the
    immediately following utterance from the speaker with the other role.
    :param convo: ConvoKit Conversation object to compute contexts over
    :return: Dictionary of Utterance id to future contexts
    """
    future_contexts = {}
    cur_1 = None
    cur_2 = None
    utts = [utt for utt in convo.iter_utterances()]
    roles = list({utt.meta["role"] for utt in utts})
    assert len(roles) == 2
    spk_prefixes = default_speaker_prefixes(roles)
    role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))}
    role_1 = roles[0]
    role_2 = roles[1]
    n = len(utts)
    for i in range(n - 1, -1, -1):
        utt = utts[i]
        utt_text = utt.text
        cur_spk = utt.meta["role"]
        if role_2 in cur_spk:
            cur_2 = (utt_text, cur_spk)
            if cur_1 is not None:
                future_text, future_role = cur_1
                future_data = role_to_prefix[future_role] + future_text
                future_contexts[utt.id] = [future_data]
        else:
            cur_1 = (utt_text, cur_spk)
            if cur_2 is not None:
                future_text, future_role = cur_2
                future_data = role_to_prefix[future_role] + future_text
                future_contexts[utt.id] = [future_data]
    return future_contexts
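As a rough usage sketch (the conversation object is assumed to exist; only the two selector functions come from this commit), both selectors map utterance ids to lists of speaker-prefixed context strings:

# convo is assumed to be a ConvoKit Conversation whose utterances carry a
# "role" metadata field with exactly two distinct roles
actual_contexts, reference_contexts = default_previous_context_selector(convo)
future_contexts = default_future_context_selector(convo)

# per the docstrings above:
# actual_contexts[utt_id]    -> [previous turn by the other role, current utterance]
# reference_contexts[utt_id] -> [previous turn by the other role, previous turn by the same role]
# future_contexts[utt_id]    -> [next utterance by the other role]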
@@ -0,0 +1,152 @@
from .likelihoodModel import LikelihoodModel
import torch
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from trl import SFTTrainer
from .config import DEFAULT_TRAIN_CONFIG, DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG


class GemmaLikelihoodModel(LikelihoodModel):
    """
    Likelihood model backed by Gemma, used to compute utterance likelihoods.
    :param hf_token: Hugging Face authentication token
    :param model_id: Gemma model id version
    :param device: Device to use
    :param train_config: Training config for fine-tuning
    :param bnb_config: bitsandbytes config for quantization
    :param lora_config: LoRA config for fine-tuning
    """

    def __init__(
        self,
        hf_token,
        model_id="google/gemma-2b",
        device="cuda" if torch.cuda.is_available() else "cpu",
        train_config=DEFAULT_TRAIN_CONFIG,
        bnb_config=DEFAULT_BNB_CONFIG,
        lora_config=DEFAULT_LORA_CONFIG,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, token=hf_token, padding_side="right"
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, quantization_config=bnb_config, device_map="auto", token=hf_token
        )
        self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
        self.hf_token = hf_token
        self.device = device
        self.train_config = train_config
        self.lora_config = lora_config
        self.bnb_config = bnb_config
        self.max_length = self.train_config["max_seq_length"]

    def name(self):
        return self.__class__.__name__

    def fit(self, train_data, val_data):
        """
        Fine-tunes the Gemma model on the provided `train_data` and validates
        on `val_data`.
        :param train_data: Data to fine-tune model
        :param val_data: Data to validate model
        """
        training_args = TrainingArguments(
            output_dir=self.train_config["output_dir"],
            logging_dir=self.train_config["logging_dir"],
            logging_steps=self.train_config["logging_steps"],
            eval_steps=self.train_config["eval_steps"],
            num_train_epochs=self.train_config["num_train_epochs"],
            per_device_train_batch_size=self.train_config["per_device_train_batch_size"],
            per_device_eval_batch_size=self.train_config["per_device_eval_batch_size"],
            evaluation_strategy=self.train_config["evaluation_strategy"],
            save_strategy=self.train_config["save_strategy"],
            save_steps=self.train_config["save_steps"],
            optim=self.train_config["optim"],
            learning_rate=self.train_config["learning_rate"],
            load_best_model_at_end=self.train_config["load_best_model_at_end"],
        )

        trainer = SFTTrainer(
            model=self.model,
            train_dataset=train_data,
            eval_dataset=val_data,
            args=training_args,
            peft_config=self.lora_config,
            max_seq_length=self.train_config["max_seq_length"],
        )
        trainer.train()

    def _calculate_likelihood_prob(self, past_context, future_context):
        """
        Computes the utterance likelihood given the previous context to
        condition on and the future context to predict.
        :param past_context: Context to condition on
        :param future_context: Context to predict the likelihood of
        :return: Log-likelihood of the future context given the past context
        """
        past_context = "\n\n".join(past_context)
        future_context = "\n\n".join(future_context)

        context_ids = self.tokenizer.encode(
            past_context,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        future_ids = self.tokenizer.encode(
            future_context,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = torch.cat([context_ids, future_ids], dim=1)
        if input_ids.shape[1] > self.max_length:
            input_ids = input_ids[:, -self.max_length :]
        input_ids = input_ids.to(self.device)
        with torch.no_grad():
            probs = torch.nn.functional.softmax(self.model(input_ids)[0], dim=-1)
        cond_log_probs = []
        for i, future_id in enumerate(future_ids[0]):
            # index of the token preceding this future token; its next-token
            # distribution scores future_id
            index = i + (input_ids.shape[1] - future_ids.shape[1]) - 1
            logprob = torch.log(probs[0, index, future_id])
            cond_log_probs.append(logprob.item())
        result = sum(cond_log_probs)
        return result

    def transform(self, test_data, verbosity=5):
        """
        Computes the utterance likelihoods for the provided `test_data`.
        :param test_data: Data to compute likelihoods over
        :param verbosity: Number of conversations between progress updates
        :return: Likelihoods of the `test_data`
        """
        prev_contexts, future_contexts = test_data
        likelihoods = []
        for i in range(len(prev_contexts)):
            if i % verbosity == 0 and i > 0:
                print(i, "/", len(prev_contexts))
            convo_likelihoods = {}
            convo_prev_contexts = prev_contexts[i]
            convo_future_contexts = future_contexts[i]
            for utt_id in convo_prev_contexts:
                if utt_id not in convo_future_contexts:
                    continue
                utt_prev_context = convo_prev_contexts[utt_id]
                utt_future_context = convo_future_contexts[utt_id]
                convo_likelihoods[utt_id] = self._calculate_likelihood_prob(
                    past_context=utt_prev_context, future_context=utt_future_context
                )
            likelihoods.append(convo_likelihoods)
        return likelihoods
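Putting the pieces together, a hedged end-to-end sketch (the token, datasets, and context dictionaries below are placeholders, not part of this commit):

likelihood_model = GemmaLikelihoodModel(hf_token="hf_...", model_id="google/gemma-2b")

# train_dataset / val_dataset are assumed to be text datasets prepared for SFTTrainer
likelihood_model.fit(train_data=train_dataset, val_data=val_dataset)

# prev_contexts and future_contexts are assumed to be lists (one entry per conversation)
# of dicts mapping utterance id -> list of context strings, e.g. as produced by the
# context selectors above
likelihoods = likelihood_model.transform((prev_contexts, future_contexts))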
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Callable


class LikelihoodModel(ABC):
    """
    Abstract class representing a model to compute utterance likelihoods
    based on provided context. Different models (Gemma, Llama, Mistral, etc.)
    can be supported by inheriting from this base class.
    """

    def __init__(self):
        self._name = None

    @property
    def name(self):
        """
        Name of the likelihood model.
        """
        return self._name

    @name.setter
    def name(self, name):
        """
        Sets the name of the likelihood model.
        :param name: Name of model
        """
        self._name = name

    @abstractmethod
    def fit(self, train_data, val_data):
        """
        Fine-tunes the likelihood model on the provided `train_data` and
        validates on `val_data`.
        :param train_data: Data to fine-tune model
        :param val_data: Data to validate model
        """
        pass

    @abstractmethod
    def transform(self, test_data):
        """
        Computes the utterance likelihoods for the provided `test_data`.
        :param test_data: Data to compute likelihoods over
        :return: Likelihoods of the `test_data`
        """
        pass
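A minimal sketch of extending this interface with another backend (the class below is hypothetical and only illustrates the required methods, not anything added by this commit):

class ConstantLikelihoodModel(LikelihoodModel):
    """Hypothetical model that assigns a fixed score to every utterance."""

    def __init__(self):
        super().__init__()
        self.name = "constant"

    def fit(self, train_data, val_data):
        pass  # nothing to fine-tune

    def transform(self, test_data):
        prev_contexts, future_contexts = test_data
        # one dict of utterance id -> likelihood per conversation
        return [{utt_id: 0.0 for utt_id in convo_prev} for convo_prev in prev_contexts]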