Skip to content

Commit

Permalink
Code and Demo for Redirection (#250)
Browse files Browse the repository at this point in the history
Co-authored-by: seanzhangkx8 <[email protected]>
  • Loading branch information
vianxnguyen and seanzhangkx8 authored Dec 30, 2024
1 parent af8adcd commit 39e84fa
Show file tree
Hide file tree
Showing 16 changed files with 1,280 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
[![Discord Community](https://img.shields.io/static/v1?logo=discord&style=flat&color=red&label=discord&message=community)](https://discord.gg/WMFqMWgz6P)


This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a [single unified interface](https://convokit.cornell.edu/documentation/architecture.html) inspired by (and compatible with) scikit-learn. Several large [conversational datasets](https://github.com/CornellNLP/ConvoKit#datasets) are included together with scripts exemplifying the use of the toolkit on these datasets. The latest version is [3.0.2](https://github.com/CornellNLP/ConvoKit/releases/tag/v3.0.2) (released December 27, 2024); follow the [project on GitHub](https://github.com/CornellNLP/ConvoKit) to keep track of updates.
This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a [single unified interface](https://convokit.cornell.edu/documentation/architecture.html) inspired by (and compatible with) scikit-learn. Several large [conversational datasets](https://github.com/CornellNLP/ConvoKit#datasets) are included together with scripts exemplifying the use of the toolkit on these datasets. The latest version is [3.1.0](https://github.com/CornellNLP/ConvoKit/releases/tag/v3.1.0) (released December 30, 2024); follow the [project on GitHub](https://github.com/CornellNLP/ConvoKit) to keep track of updates.

Join our [Discord community](https://discord.gg/WMFqMWgz6P) to stay informed, connect with fellow developers, and be part of an engaging space where we share progress, discuss features, and tackle issues together.

Expand Down
2 changes: 2 additions & 0 deletions convokit/redirection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .likelihoodModel import *
from .redirection import *
33 changes: 33 additions & 0 deletions convokit/redirection/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from peft import LoraConfig
from transformers import BitsAndBytesConfig
import torch

DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)

DEFAULT_LORA_CONFIG = LoraConfig(
r=16,
lora_dropout=0.05,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
task_type="CAUSAL_LM",
)

DEFAULT_TRAIN_CONFIG = {
"output_dir": "checkpoints",
"logging_dir": "logging",
"logging_steps": 25,
"eval_steps": 50,
"num_train_epochs": 2,
"per_device_train_batch_size": 1,
"per_device_eval_batch_size": 1,
"evaluation_strategy": "steps",
"save_strategy": "steps",
"save_steps": 50,
"optim": "paged_adamw_8bit",
"learning_rate": 2e-4,
"max_seq_length": 512,
"load_best_model_at_end": True,
}
100 changes: 100 additions & 0 deletions convokit/redirection/contextSelector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from .preprocessing import default_speaker_prefixes


def default_previous_context_selector(convo):
"""
Default function to compute previous contexts for Redirection. For
actual contexts, uses the current utterance and immediate previous
utterance by speaker with different role. For reference contexts, uses
the previous utterance by the same role speaker instead of the current
utterance as a point of reference.
:param convo: ConvoKit Conversation object to compute contexts over
:return: Tuple of actual contexts and reference contexts
"""
actual_contexts = {}
reference_contexts = {}
utts = [utt for utt in convo.iter_utterances()]
roles = list({utt.meta["role"] for utt in utts})
assert len(roles) == 2
spk_prefixes = default_speaker_prefixes(roles)
role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))}
role_1 = roles[0]
role_2 = roles[1]
prev_spk = None
prev_1, prev_2, cur_1, cur_2 = None, None, None, None
for i, utt in enumerate(utts):
utt_text = utt.text
cur_spk = utt.meta["role"]
if prev_spk is not None and cur_spk != prev_spk:
if role_2 in cur_spk:
prev_1 = cur_1
else:
prev_2 = cur_2

if prev_1 and prev_2 is not None:
if role_2 in cur_spk:
prev = prev_1
prev_prev = prev_2
else:
prev = prev_2
prev_prev = prev_1

prev_prev_text, prev_prev_role = prev_prev
prev_text, prev_role = prev

prev_prev_data = role_to_prefix[prev_prev_role] + prev_prev_text
prev_data = role_to_prefix[prev_role] + prev_text
cur_data = role_to_prefix[cur_spk] + utt_text

actual_contexts[utt.id] = [prev_data, cur_data]
reference_contexts[utt.id] = [prev_data, prev_prev_data]

if role_1 in cur_spk:
cur_1 = (utt_text, cur_spk)
if role_2 in cur_spk:
cur_2 = (utt_text, cur_spk)

prev_spk = cur_spk

return actual_contexts, reference_contexts


def default_future_context_selector(convo):
"""
Default function to compute future contexts for Redirection. Uses the
immediate successor utterance from a different role speaker.
:param convo: ConvoKit Conversation object to compute contexts over
:return: Dictionary of Utterance id to future contexts
"""
future_contexts = {}
cur_1 = None
cur_2 = None
utts = [utt for utt in convo.iter_utterances()]
roles = list({utt.meta["role"] for utt in utts})
assert len(roles) == 2
spk_prefixes = default_speaker_prefixes(roles)
role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))}
role_1 = roles[0]
role_2 = roles[1]
n = len(utts)
for i in range(n - 1, -1, -1):
utt = utts[i]
utt_text = utt.text
cur_spk = utt.meta["role"]
if role_2 in cur_spk:
cur_2 = (utt_text, cur_spk)
if cur_1 is not None:
future_text, future_role = cur_1
future_data = role_to_prefix[future_role] + future_text
future_contexts[utt.id] = [future_data]
else:
cur_1 = (utt_text, cur_spk)
if cur_2 is not None:
future_text, future_role = cur_2
future_data = role_to_prefix[future_role] + future_text
future_contexts[utt.id] = [future_data]
return future_contexts
152 changes: 152 additions & 0 deletions convokit/redirection/gemmaLikelihoodModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from .likelihoodModel import LikelihoodModel
import torch
from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
BitsAndBytesConfig,
DataCollatorForLanguageModeling,
TrainingArguments,
)
from trl import SFTTrainer
from .config import DEFAULT_TRAIN_CONFIG, DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG


class GemmaLikelihoodModel(LikelihoodModel):
"""
Likelihood model supported by Gemma, used to compute utterance likelihoods.
:param hf_token: Huggingface authentication token
:param model_id: Gemma model id version
:param device: Device to use
:param train_config: Training config for fine-tuning
:param bnb_config: bitsandbytes config for quantization
:param lora_config: LoRA config for fine-tuning
"""

def __init__(
self,
hf_token,
model_id="google/gemma-2b",
device="cuda" if torch.cuda.is_available() else "cpu",
train_config=DEFAULT_TRAIN_CONFIG,
bnb_config=DEFAULT_BNB_CONFIG,
lora_config=DEFAULT_LORA_CONFIG,
):
self.tokenizer = AutoTokenizer.from_pretrained(
model_id, token=hf_token, padding_side="right"
)
self.model = AutoModelForCausalLM.from_pretrained(
model_id, quantization_config=bnb_config, device_map="auto", token=hf_token
)
self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False)
self.hf_token = hf_token
self.device = device
self.train_config = train_config
self.lora_config = lora_config
self.bnb_config = bnb_config
self.max_length = self.train_config["max_seq_length"]

def name(self):
return self.__class__.name

def fit(self, train_data, val_data):
"""
Fine-tunes the Gemma model on the provided `train_data` and validates
on `val_data`.
:param train_data: Data to fine-tune model
:param val_data: Data to validate model
"""
training_args = TrainingArguments(
output_dir=self.train_config["output_dir"],
logging_dir=self.train_config["logging_dir"],
logging_steps=self.train_config["logging_steps"],
eval_steps=self.train_config["eval_steps"],
num_train_epochs=self.train_config["num_train_epochs"],
per_device_train_batch_size=self.train_config["per_device_train_batch_size"],
per_device_eval_batch_size=self.train_config["per_device_eval_batch_size"],
evaluation_strategy=self.train_config["evaluation_strategy"],
save_strategy=self.train_config["save_strategy"],
save_steps=self.train_config["save_steps"],
optim=self.train_config["optim"],
learning_rate=self.train_config["learning_rate"],
load_best_model_at_end=self.train_config["load_best_model_at_end"],
)

trainer = SFTTrainer(
model=self.model,
train_dataset=train_data,
eval_dataset=val_data,
args=training_args,
peft_config=self.lora_config,
max_seq_length=self.train_config["max_seq_length"],
)
trainer.train()

def _calculate_likelihood_prob(self, past_context, future_context):
"""
Computes the utterance likelihoods given the previous context to
condition on and the future context to predict.
:param past_context: Context to condition
:param future_context: Context to predict likelihood
:return: Likelihoods of contexts
"""
past_context = "\n\n".join(past_context)
future_context = "\n\n".join(future_context)

context_ids = self.tokenizer.encode(
past_context,
truncation=True,
max_length=self.max_length,
return_tensors="pt",
)
future_ids = self.tokenizer.encode(
future_context,
truncation=True,
max_length=self.max_length,
return_tensors="pt",
)
input_ids = torch.cat([context_ids, future_ids], dim=1)
if input_ids.shape[1] > self.max_length:
input_ids = input_ids[:, -self.max_length :]
input_ids = input_ids.to(self.device)
with torch.no_grad():
probs = torch.nn.functional.softmax(self.model(input_ids)[0], dim=-1)
cond_log_probs = []
for i, future_id in enumerate(future_ids[0]):
index = i + (input_ids.shape[1] - future_ids.shape[1]) - 1
logprob = torch.log(probs[0, index, future_id])
cond_log_probs.append(logprob.item())
result = sum(cond_log_probs)
return result

def transform(self, test_data, verbosity=5):
"""
Computes the utterance likelihoods for the provided `test_data`.
:param test_data: Data to compute likelihoods over
:param verbosity: Verbosity to print updated messages
:return: Likelihoods of the `test_data`
"""
prev_contexts, future_contexts = test_data
likelihoods = []
for i in range(len(prev_contexts)):
if i % verbosity == 0 and i > 0:
print(i, "/", len(test_data))
convo_likelihoods = {}
convo_prev_contexts = prev_contexts[i]
convo_future_contexts = future_contexts[i]
for utt_id in convo_prev_contexts:
if utt_id not in convo_future_contexts:
continue
utt_prev_context = convo_prev_contexts[utt_id]
utt_future_context = convo_future_contexts[utt_id]
convo_likelihoods[utt_id] = self._calculate_likelihood_prob(
past_context=utt_prev_context, future_context=utt_future_context
)
likelihoods.append(convo_likelihoods)
return likelihoods
51 changes: 51 additions & 0 deletions convokit/redirection/likelihoodModel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Callable


class LikelihoodModel(ABC):
"""
Abstract class representing a model to compute utterance likelihoods
based on provided context. Different models (Gemma, Llama, Mistral, etc.)
can be supported by inheriting from this base class.
"""

def __init__(self):
self._name = None

@property
def name(self):
"""
Name of the likelihood model.
"""
return self._name

@name.setter
def name(self, name):
"""
Sets the name of the likelihood model.
:param name: Name of model
"""
self._name = name

@abstractmethod
def fit(self, train_data, val_data):
"""
Fine-tunes the likelihood model on the provided `train_data` and
validates on `val_data`.
:param train_data: Data to fine-tune model
:param val_data: Data to validate model
"""
pass

@abstractmethod
def transform(self, test_data):
"""
Computes the utterance likelihoods for the provided `test_data`.
:param test_data: Data to compute likelihoods over
:return: Likelihoods of the `test_data`
"""
pass
Loading

0 comments on commit 39e84fa

Please sign in to comment.