From 39e84fae809f4ad8efa0c69ad754184ea56a3727 Mon Sep 17 00:00:00 2001 From: vivian <46759999+vianxnguyen@users.noreply.github.com> Date: Mon, 30 Dec 2024 16:25:23 -0500 Subject: [PATCH] Code and Demo for Redirection (#250) Co-authored-by: seanzhangkx8 <106214464+seanzhangkx8@users.noreply.github.com> --- README.md | 2 +- convokit/redirection/__init__.py | 2 + convokit/redirection/config.py | 33 ++ convokit/redirection/contextSelector.py | 100 ++++ convokit/redirection/gemmaLikelihoodModel.py | 152 ++++++ convokit/redirection/likelihoodModel.py | 51 ++ convokit/redirection/preprocessing.py | 87 +++ convokit/redirection/redirection.py | 163 ++++++ convokit/redirection/redirectionDemo.ipynb | 498 ++++++++++++++++++ convokit/utteranceLikelihood/__init__.py | 1 + .../utteranceLikelihood.py | 153 ++++++ docs/source/analysis.rst | 2 + docs/source/conf.py | 5 +- docs/source/index.rst | 2 +- .../redirectionAndUtteranceLikelihood.rst | 24 + setup.py | 10 +- 16 files changed, 1280 insertions(+), 5 deletions(-) create mode 100644 convokit/redirection/__init__.py create mode 100644 convokit/redirection/config.py create mode 100644 convokit/redirection/contextSelector.py create mode 100644 convokit/redirection/gemmaLikelihoodModel.py create mode 100644 convokit/redirection/likelihoodModel.py create mode 100644 convokit/redirection/preprocessing.py create mode 100644 convokit/redirection/redirection.py create mode 100644 convokit/redirection/redirectionDemo.ipynb create mode 100644 convokit/utteranceLikelihood/__init__.py create mode 100644 convokit/utteranceLikelihood/utteranceLikelihood.py create mode 100644 docs/source/redirectionAndUtteranceLikelihood.rst diff --git a/README.md b/README.md index e5258f71..f59ce210 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ [![Discord Community](https://img.shields.io/static/v1?logo=discord&style=flat&color=red&label=discord&message=community)](https://discord.gg/WMFqMWgz6P) -This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a [single unified interface](https://convokit.cornell.edu/documentation/architecture.html) inspired by (and compatible with) scikit-learn. Several large [conversational datasets](https://github.com/CornellNLP/ConvoKit#datasets) are included together with scripts exemplifying the use of the toolkit on these datasets. The latest version is [3.0.2](https://github.com/CornellNLP/ConvoKit/releases/tag/v3.0.2) (released December 27, 2024); follow the [project on GitHub](https://github.com/CornellNLP/ConvoKit) to keep track of updates. +This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a [single unified interface](https://convokit.cornell.edu/documentation/architecture.html) inspired by (and compatible with) scikit-learn. Several large [conversational datasets](https://github.com/CornellNLP/ConvoKit#datasets) are included together with scripts exemplifying the use of the toolkit on these datasets. The latest version is [3.1.0](https://github.com/CornellNLP/ConvoKit/releases/tag/v3.1.0) (released December 30, 2024); follow the [project on GitHub](https://github.com/CornellNLP/ConvoKit) to keep track of updates. Join our [Discord community](https://discord.gg/WMFqMWgz6P) to stay informed, connect with fellow developers, and be part of an engaging space where we share progress, discuss features, and tackle issues together. diff --git a/convokit/redirection/__init__.py b/convokit/redirection/__init__.py new file mode 100644 index 00000000..e77b7f19 --- /dev/null +++ b/convokit/redirection/__init__.py @@ -0,0 +1,2 @@ +from .likelihoodModel import * +from .redirection import * diff --git a/convokit/redirection/config.py b/convokit/redirection/config.py new file mode 100644 index 00000000..6473b18e --- /dev/null +++ b/convokit/redirection/config.py @@ -0,0 +1,33 @@ +from peft import LoraConfig +from transformers import BitsAndBytesConfig +import torch + +DEFAULT_BNB_CONFIG = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=torch.bfloat16, +) + +DEFAULT_LORA_CONFIG = LoraConfig( + r=16, + lora_dropout=0.05, + target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"], + task_type="CAUSAL_LM", +) + +DEFAULT_TRAIN_CONFIG = { + "output_dir": "checkpoints", + "logging_dir": "logging", + "logging_steps": 25, + "eval_steps": 50, + "num_train_epochs": 2, + "per_device_train_batch_size": 1, + "per_device_eval_batch_size": 1, + "evaluation_strategy": "steps", + "save_strategy": "steps", + "save_steps": 50, + "optim": "paged_adamw_8bit", + "learning_rate": 2e-4, + "max_seq_length": 512, + "load_best_model_at_end": True, +} diff --git a/convokit/redirection/contextSelector.py b/convokit/redirection/contextSelector.py new file mode 100644 index 00000000..598f9ae0 --- /dev/null +++ b/convokit/redirection/contextSelector.py @@ -0,0 +1,100 @@ +from .preprocessing import default_speaker_prefixes + + +def default_previous_context_selector(convo): + """ + Default function to compute previous contexts for Redirection. For + actual contexts, uses the current utterance and immediate previous + utterance by speaker with different role. For reference contexts, uses + the previous utterance by the same role speaker instead of the current + utterance as a point of reference. + + :param convo: ConvoKit Conversation object to compute contexts over + + :return: Tuple of actual contexts and reference contexts + """ + actual_contexts = {} + reference_contexts = {} + utts = [utt for utt in convo.iter_utterances()] + roles = list({utt.meta["role"] for utt in utts}) + assert len(roles) == 2 + spk_prefixes = default_speaker_prefixes(roles) + role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))} + role_1 = roles[0] + role_2 = roles[1] + prev_spk = None + prev_1, prev_2, cur_1, cur_2 = None, None, None, None + for i, utt in enumerate(utts): + utt_text = utt.text + cur_spk = utt.meta["role"] + if prev_spk is not None and cur_spk != prev_spk: + if role_2 in cur_spk: + prev_1 = cur_1 + else: + prev_2 = cur_2 + + if prev_1 and prev_2 is not None: + if role_2 in cur_spk: + prev = prev_1 + prev_prev = prev_2 + else: + prev = prev_2 + prev_prev = prev_1 + + prev_prev_text, prev_prev_role = prev_prev + prev_text, prev_role = prev + + prev_prev_data = role_to_prefix[prev_prev_role] + prev_prev_text + prev_data = role_to_prefix[prev_role] + prev_text + cur_data = role_to_prefix[cur_spk] + utt_text + + actual_contexts[utt.id] = [prev_data, cur_data] + reference_contexts[utt.id] = [prev_data, prev_prev_data] + + if role_1 in cur_spk: + cur_1 = (utt_text, cur_spk) + if role_2 in cur_spk: + cur_2 = (utt_text, cur_spk) + + prev_spk = cur_spk + + return actual_contexts, reference_contexts + + +def default_future_context_selector(convo): + """ + Default function to compute future contexts for Redirection. Uses the + immediate successor utterance from a different role speaker. + + :param convo: ConvoKit Conversation object to compute contexts over + + :return: Dictionary of Utterance id to future contexts + """ + future_contexts = {} + cur_1 = None + cur_2 = None + utts = [utt for utt in convo.iter_utterances()] + roles = list({utt.meta["role"] for utt in utts}) + assert len(roles) == 2 + spk_prefixes = default_speaker_prefixes(roles) + role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))} + role_1 = roles[0] + role_2 = roles[1] + n = len(utts) + for i in range(n - 1, -1, -1): + utt = utts[i] + utt_text = utt.text + cur_spk = utt.meta["role"] + if role_2 in cur_spk: + cur_2 = (utt_text, cur_spk) + if cur_1 is not None: + future_text, future_role = cur_1 + future_data = role_to_prefix[future_role] + future_text + future_contexts[utt.id] = [future_data] + else: + cur_1 = (utt_text, cur_spk) + if cur_2 is not None: + future_text, future_role = cur_2 + future_data = role_to_prefix[future_role] + future_text + future_contexts[utt.id] = [future_data] + return future_contexts diff --git a/convokit/redirection/gemmaLikelihoodModel.py b/convokit/redirection/gemmaLikelihoodModel.py new file mode 100644 index 00000000..9bd5a787 --- /dev/null +++ b/convokit/redirection/gemmaLikelihoodModel.py @@ -0,0 +1,152 @@ +from .likelihoodModel import LikelihoodModel +import torch +from peft import LoraConfig, get_peft_model, AutoPeftModelForCausalLM, PeftModel +from transformers import ( + AutoTokenizer, + AutoModelForCausalLM, + BitsAndBytesConfig, + DataCollatorForLanguageModeling, + TrainingArguments, +) +from trl import SFTTrainer +from .config import DEFAULT_TRAIN_CONFIG, DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG + + +class GemmaLikelihoodModel(LikelihoodModel): + """ + Likelihood model supported by Gemma, used to compute utterance likelihoods. + + :param hf_token: Huggingface authentication token + :param model_id: Gemma model id version + :param device: Device to use + :param train_config: Training config for fine-tuning + :param bnb_config: bitsandbytes config for quantization + :param lora_config: LoRA config for fine-tuning + """ + + def __init__( + self, + hf_token, + model_id="google/gemma-2b", + device="cuda" if torch.cuda.is_available() else "cpu", + train_config=DEFAULT_TRAIN_CONFIG, + bnb_config=DEFAULT_BNB_CONFIG, + lora_config=DEFAULT_LORA_CONFIG, + ): + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, token=hf_token, padding_side="right" + ) + self.model = AutoModelForCausalLM.from_pretrained( + model_id, quantization_config=bnb_config, device_map="auto", token=hf_token + ) + self.data_collator = DataCollatorForLanguageModeling(tokenizer=self.tokenizer, mlm=False) + self.hf_token = hf_token + self.device = device + self.train_config = train_config + self.lora_config = lora_config + self.bnb_config = bnb_config + self.max_length = self.train_config["max_seq_length"] + + def name(self): + return self.__class__.name + + def fit(self, train_data, val_data): + """ + Fine-tunes the Gemma model on the provided `train_data` and validates + on `val_data`. + + :param train_data: Data to fine-tune model + :param val_data: Data to validate model + """ + training_args = TrainingArguments( + output_dir=self.train_config["output_dir"], + logging_dir=self.train_config["logging_dir"], + logging_steps=self.train_config["logging_steps"], + eval_steps=self.train_config["eval_steps"], + num_train_epochs=self.train_config["num_train_epochs"], + per_device_train_batch_size=self.train_config["per_device_train_batch_size"], + per_device_eval_batch_size=self.train_config["per_device_eval_batch_size"], + evaluation_strategy=self.train_config["evaluation_strategy"], + save_strategy=self.train_config["save_strategy"], + save_steps=self.train_config["save_steps"], + optim=self.train_config["optim"], + learning_rate=self.train_config["learning_rate"], + load_best_model_at_end=self.train_config["load_best_model_at_end"], + ) + + trainer = SFTTrainer( + model=self.model, + train_dataset=train_data, + eval_dataset=val_data, + args=training_args, + peft_config=self.lora_config, + max_seq_length=self.train_config["max_seq_length"], + ) + trainer.train() + + def _calculate_likelihood_prob(self, past_context, future_context): + """ + Computes the utterance likelihoods given the previous context to + condition on and the future context to predict. + + :param past_context: Context to condition + :param future_context: Context to predict likelihood + + :return: Likelihoods of contexts + """ + past_context = "\n\n".join(past_context) + future_context = "\n\n".join(future_context) + + context_ids = self.tokenizer.encode( + past_context, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + future_ids = self.tokenizer.encode( + future_context, + truncation=True, + max_length=self.max_length, + return_tensors="pt", + ) + input_ids = torch.cat([context_ids, future_ids], dim=1) + if input_ids.shape[1] > self.max_length: + input_ids = input_ids[:, -self.max_length :] + input_ids = input_ids.to(self.device) + with torch.no_grad(): + probs = torch.nn.functional.softmax(self.model(input_ids)[0], dim=-1) + cond_log_probs = [] + for i, future_id in enumerate(future_ids[0]): + index = i + (input_ids.shape[1] - future_ids.shape[1]) - 1 + logprob = torch.log(probs[0, index, future_id]) + cond_log_probs.append(logprob.item()) + result = sum(cond_log_probs) + return result + + def transform(self, test_data, verbosity=5): + """ + Computes the utterance likelihoods for the provided `test_data`. + + :param test_data: Data to compute likelihoods over + :param verbosity: Verbosity to print updated messages + + :return: Likelihoods of the `test_data` + """ + prev_contexts, future_contexts = test_data + likelihoods = [] + for i in range(len(prev_contexts)): + if i % verbosity == 0 and i > 0: + print(i, "/", len(test_data)) + convo_likelihoods = {} + convo_prev_contexts = prev_contexts[i] + convo_future_contexts = future_contexts[i] + for utt_id in convo_prev_contexts: + if utt_id not in convo_future_contexts: + continue + utt_prev_context = convo_prev_contexts[utt_id] + utt_future_context = convo_future_contexts[utt_id] + convo_likelihoods[utt_id] = self._calculate_likelihood_prob( + past_context=utt_prev_context, future_context=utt_future_context + ) + likelihoods.append(convo_likelihoods) + return likelihoods diff --git a/convokit/redirection/likelihoodModel.py b/convokit/redirection/likelihoodModel.py new file mode 100644 index 00000000..4551cc1a --- /dev/null +++ b/convokit/redirection/likelihoodModel.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Callable + + +class LikelihoodModel(ABC): + """ + Abstract class representing a model to compute utterance likelihoods + based on provided context. Different models (Gemma, Llama, Mistral, etc.) + can be supported by inheriting from this base class. + """ + + def __init__(self): + self._name = None + + @property + def name(self): + """ + Name of the likelihood model. + """ + return self._name + + @name.setter + def name(self, name): + """ + Sets the name of the likelihood model. + + :param name: Name of model + """ + self._name = name + + @abstractmethod + def fit(self, train_data, val_data): + """ + Fine-tunes the likelihood model on the provided `train_data` and + validates on `val_data`. + + :param train_data: Data to fine-tune model + :param val_data: Data to validate model + """ + pass + + @abstractmethod + def transform(self, test_data): + """ + Computes the utterance likelihoods for the provided `test_data`. + + :param test_data: Data to compute likelihoods over + + :return: Likelihoods of the `test_data` + """ + pass diff --git a/convokit/redirection/preprocessing.py b/convokit/redirection/preprocessing.py new file mode 100644 index 00000000..2059006d --- /dev/null +++ b/convokit/redirection/preprocessing.py @@ -0,0 +1,87 @@ +from datasets import Dataset + + +def default_speaker_prefixes(roles): + """ + Gemerates speaker prefixes for speaker roles. + + :param roles: Roles to generate prefixes for. + + :return: List of speaker prefixes + """ + number_of_roles = len(roles) + speakers = ["Speaker " + chr(65 + (i % 26)) + ": " for i in range(number_of_roles)] + return speakers + + +def format_conversations(convos): + """ + Format the conversations used for fine-tuning and inference. + + :param convos: List of conversations to format + + :return: Formatted conversations + """ + formatted_convos = [] + for convo in convos: + utts = [utt for utt in convo.iter_utterances()] + roles = list({utt.meta["role"] for utt in utts}) + spk_prefixes = default_speaker_prefixes(roles) + role_to_prefix = {roles[i]: spk_prefixes[i] for i in range(len(roles))} + formatted_utts = [] + for utt in utts: + utt_text = role_to_prefix[utt.meta["role"]] + utt.text + formatted_utts.append(utt_text) + formatted_convo = "\n\n".join(formatted_utts) + formatted_convos.append(formatted_convo) + return formatted_convos + + +def get_chunk_dataset(tokenizer, convos, max_tokens=512, overlap_tokens=50): + """ + Generate a chunked dataset for training given max sequence length + and overlap length. + + :param tokenizer: Tokenizer of model + :param convos: List of conversations to generate dataset + :param max_tokens: Max sequence length + :param overlap_tokens: Number of overlap tokens for chunks + + :return: Chunk dataset + """ + chunks = [] + for convo in convos: + convo_chunks = chunk_text_with_overlap( + tokenizer, + convo, + max_tokens=max_tokens, + overlap_tokens=overlap_tokens, + ) + chunks += convo_chunks + + data_dict = {"text": chunks} + dataset = Dataset.from_dict(data_dict) + return dataset + + +def chunk_text_with_overlap(tokenizer, text, max_tokens=512, overlap_tokens=50): + """ + Split conversation into chunks for training. + + :param tokenizer: Tokenizer of model + :param text: Text to chunk + :param max_tokens: Max sequence length + :param overlap_tokens: Number of overlap tokens for chunks + + :return: Chunk of texts + """ + tokens = tokenizer.encode(text) + chunks = [] + start = 0 + while start < len(tokens): + end = min(start + max_tokens, len(tokens)) + overlap_end = max(start + max_tokens - overlap_tokens, start) + chunk = tokens[start:overlap_end] + chunks.append(tokenizer.decode(chunk)) + start = overlap_end + return chunks diff --git a/convokit/redirection/redirection.py b/convokit/redirection/redirection.py new file mode 100644 index 00000000..4421ad06 --- /dev/null +++ b/convokit/redirection/redirection.py @@ -0,0 +1,163 @@ +from convokit import Transformer +from .likelihoodModel import LikelihoodModel +from .contextSelector import default_previous_context_selector, default_future_context_selector +import torch +import random +from .preprocessing import format_conversations, get_chunk_dataset +import numpy as np + + +class Redirection(Transformer): + """ + ConvoKit transformer to compute redirection scores, derived from + utterance probabilities from `likelihood_model`. The contexts used + to compute redirection can be defined using `previous_context_selector` + and `future_context_selector`, which are by default the immediate previous + and future contexts from different speaker roles. + + :param likelihood_model: Likelihood model to compute utterance likelihoods + :param previous_context_selector: Computes tuple of actual, reference contexts + used for redirection + :param future_context_selector: Computes future contexts used for redirection + :param redirection_attribute_name: Name of meta-data attribute to + save redirection scores + """ + + def __init__( + self, + likelihood_model, + previous_context_selector=default_previous_context_selector, + future_context_selector=default_future_context_selector, + redirection_attribute_name="redirection", + ): + self.likelihood_model = likelihood_model + self.tokenizer = self.likelihood_model.tokenizer + self.previous_context_selector = previous_context_selector + self.future_context_selector = future_context_selector + self.redirection_attribute_name = redirection_attribute_name + + def fit(self, corpus, train_selector=lambda convo: True, val_selector=lambda convo: True): + """ + Fits the redirection transformer to the corpus by generating the training + and validation data and fine-tuning the likelihood model. + + :param corpus: Corpus to fit transformer + :param train_selector: Selector for train conversations + :param val_selector: Selector for val conversations + """ + train_convos = [convo for convo in corpus.iter_conversations() if train_selector(convo)] + val_convos = [convo for convo in corpus.iter_conversations() if val_selector(convo)] + train_convos_formatted = format_conversations(train_convos) + val_convos_formatted = format_conversations(val_convos) + train_data = get_chunk_dataset( + self.tokenizer, train_convos_formatted, max_tokens=self.likelihood_model.max_length + ) + val_data = get_chunk_dataset( + self.tokenizer, val_convos_formatted, max_tokens=self.likelihood_model.max_length + ) + self.likelihood_model.fit(train_data=train_data, val_data=val_data) + return self + + def transform(self, corpus, selector=lambda convo: True, verbosity=5): + """ + Populates the corpus test data with redirection scores, by computing + previous and future contexts, determining actual and reference likelihoods, + and calculating redirection scores. + + :param corpus: Corpus to transform + :param selector: Selector for test data + :param verbosity: Verbosity for update messages + + :return: Corpus where test data is labeled with redirection scores + """ + test_convos = [convo for convo in corpus.iter_conversations() if selector(convo)] + actual_contexts = [] + reference_contexts = [] + future_contexts = [] + print("Computing contexts") + for i, convo in enumerate(test_convos): + if i % verbosity == 0 and i > 0: + print(i, "/", len(test_convos)) + actual, reference = self.previous_context_selector(convo) + future = self.future_context_selector(convo) + actual_contexts.append(actual) + reference_contexts.append(reference) + future_contexts.append(future) + + print("Computing actual likelihoods") + test_data = (actual_contexts, future_contexts) + actual_likelihoods = self.likelihood_model.transform(test_data, verbosity=verbosity) + + print("Computing reference likelihoods") + test_data = (reference_contexts, future_contexts) + reference_likelihoods = self.likelihood_model.transform(test_data, verbosity=verbosity) + + print("Computing redirection scores") + for i, convo in enumerate(test_convos): + if i % verbosity == 0 and i > 0: + print(i, "/", len(test_convos)) + convo_actual_likelihoods = actual_likelihoods[i] + convo_reference_likelihoods = reference_likelihoods[i] + for utt in convo.iter_utterances(): + if utt.id in convo_actual_likelihoods and utt.id in convo_reference_likelihoods: + actual_prob = convo_actual_likelihoods[utt.id] + reference_prob = convo_reference_likelihoods[utt.id] + redirection = ( + actual_prob + - np.log(1 - np.exp(actual_prob)) + - reference_prob + + np.log(1 - np.exp(reference_prob)) + ) + utt.meta[self.redirection_attribute_name] = redirection + + return corpus + + def fit_transform( + self, + train_selector=lambda convo: True, + val_selector=lambda convo: True, + test_selector=lambda convo: True, + verbosity=10, + ): + """ + Fit and transform the model. + + :param corpus: Corpus to transform + :param train_selector: Selector for train data + :param val_selector: Selector for val data + :param test_selector: Selector for test data + :param verbosity: Verbosity for update messages + + :return: Corpus where test data is labeled with redirection scores + """ + self.fit(corpus, train_selector=train_selector, val_selector=val_selector) + return self.transform(corpus, selector=test_selector, verbosity=verbosity) + + def summarize(self, corpus, top_sample_size=10, bottom_sample_size=10): + """ + Summarizes redirection transformer with high and low redirecting + utterances. + + :param corpus: Corpus to analyze + :param top_sample_size: Number of utterances to print for high redirection + :param bottom_sample_size: Number of utterances to print for low redirection + """ + utts = [ + utt for utt in corpus.iter_utterances() if self.redirection_attribute_name in utt.meta + ] + sorted_utts = sorted(utts, key=lambda utt: utt.meta[self.redirection_attribute_name]) + top_sample_size = min(top_sample_size, len(sorted_utts)) + bottom_sample_size = min(bottom_sample_size, len(sorted_utts)) + print("[high]" + self.redirection_attribute_name) + for i in range(-1, -1 - top_sample_size, -1): + utt = sorted_utts[i] + print(utt.speaker.id, ":", utt.text, "\n") + + print() + + print("[low]" + self.redirection_attribute_name) + for i in range(bottom_sample_size): + utt = sorted_utts[i] + print(utt.speaker.id, ":", utt.text, "\n") + + return self diff --git a/convokit/redirection/redirectionDemo.ipynb b/convokit/redirection/redirectionDemo.ipynb new file mode 100644 index 00000000..6aebebf2 --- /dev/null +++ b/convokit/redirection/redirectionDemo.ipynb @@ -0,0 +1,498 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a6a0fb8e", + "metadata": {}, + "source": [ + "# Redirection Demo in US Supreme Court oral arguments" + ] + }, + { + "cell_type": "markdown", + "id": "c4894fbb", + "metadata": {}, + "source": [ + "This notebook demonstrates our redirection framework introduced this paper: **Taking a turn for the better: Conversation redirection throughout the course of mental-health therapy.** In the paper, we define redirection as the extent to which speakers shift the immediate focus of the conversation and applied our measure in the context of long-term messaging therapy. In this demo, we provide an initial exploration into how our redirection framework can be applied in other domains in particular, to a publicly available dataset of U.S. Supreme Court oral arguments (Danescu-Niculescu-Mizil et al., 2012; Chang et al., 2020). Although court proceedings differ from therapy in terms of topics, goals, and interaction styles, their relatively unstructured and dynamic nature enables an initial exploration of how such discussions are redirected.\n", + "\n", + "In this setting, we focus on the interactions between justices and lawyers. The power dynamics between these distinct roles reflect the asymmetric relationship between therapists and patients in mental-health domains, where one party generally holds more influence over the direction of the conversation." + ] + }, + { + "cell_type": "markdown", + "id": "cfc154fb", + "metadata": {}, + "source": [ + "We first install and import all the necessary packages from Convokit including our wrapper models and config files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "db3314f1", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install git+https://github.com/vianxnguyen/ConvoKit.git\n", + "# !pip install -q convokit" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0bc1290", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from convokit import Corpus, download\n", + "from convokit.redirection.likelihoodModel import LikelihoodModel\n", + "from convokit.redirection.gemmaLikelihoodModel import GemmaLikelihoodModel\n", + "from convokit.redirection.redirection import Redirection\n", + "from convokit.redirection.config import DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG, DEFAULT_TRAIN_CONFIG\n", + "import random\n", + "from sklearn.model_selection import train_test_split\n", + "import numpy as np\n", + "from scipy.stats import wilcoxon" + ] + }, + { + "cell_type": "markdown", + "id": "7e6fc4d3", + "metadata": {}, + "source": [ + "We then download the `supreme-court` corpus we will be using for training and analysis. If you already have the corpus saved locally, you can specify the path to load the corpus from." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "536e4b44", + "metadata": {}, + "outputs": [], + "source": [ + "# If you already have the corpus saved locally, load the corpus from the saved path.\n", + "# DATA_DIR = '/Users/vian/.convokit/downloads/supreme-corpus'\n", + "# corpus = Corpus(DATA_DIR)\n", + "\n", + "# Otherwise download the corpus\n", + "corpus = Corpus(filename=download('supreme-corpus'))\n", + "corpus.print_summary_stats()" + ] + }, + { + "cell_type": "markdown", + "id": "470f39e0", + "metadata": {}, + "source": [ + "For the purposes of the demo, we will randomly sample a subset of 50 conversations (~20k utterances) for our analysis. Since in this demonstration, we focus on interactions between two distinct roles of justices and lawyers, we label the speaker role for each utterance (either justice or lawyer). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2496d36", + "metadata": {}, + "outputs": [], + "source": [ + "convos = [convo for convo in corpus.iter_conversations()]\n", + "sample_convos = random.sample(convos, 50)\n", + "print(len(sample_convos))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "99029a67", + "metadata": {}, + "outputs": [], + "source": [ + "for convo in sample_convos:\n", + " for utt in convo.iter_utterances():\n", + " if utt.speaker.id.startswith(\"j_\"):\n", + " utt.meta[\"role\"] = \"justice\"\n", + " else:\n", + " utt.meta[\"role\"] = \"lawyer\"" + ] + }, + { + "cell_type": "markdown", + "id": "1b427dbe", + "metadata": {}, + "source": [ + "We will use a 90/10/10 train/val/test split. We then label the conversations with their corresponding split." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a8ffa90", + "metadata": {}, + "outputs": [], + "source": [ + "train_convos, temp_convos = train_test_split(sample_convos, test_size=0.2, random_state=10)\n", + "val_convos, test_convos = train_test_split(temp_convos, test_size=0.5, random_state=10)\n", + "print(len(train_convos), len(val_convos), len(test_convos))\n", + "\n", + "for convo in train_convos:\n", + " convo.meta[\"train\"] = True\n", + "for convo in val_convos: \n", + " convo.meta[\"val\"] = True \n", + "for convo in test_convos:\n", + " convo.meta[\"test\"] = True " + ] + }, + { + "cell_type": "markdown", + "id": "56b1d3c6", + "metadata": {}, + "source": [ + "Now, we define our likelihood model responsible for computing utterance likelihoods based on provided context.The likelihood probabilities are later used to compute redirection scores for each utterance. Here, we define a likelihood model using the Gemma-2B model called `GemmaLikelihodModel` which inherits from a default `LikelihoodModel` interface. Different models (Gemma, Llama, Mistral, etc.) can be supported by inheriting from this base interface. \n", + "\n", + "Since in this demo, we are using Gemma-2B through HuggingFace, we need to provide an authentication token for access to the model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bf38bac", + "metadata": {}, + "outputs": [], + "source": [ + "gemma_likelihood_model = \\\n", + " GemmaLikelihoodModel(\n", + " hf_token = \"TODO: ADD HUGGINGFACE AUTH TOKEN\",\n", + " model_id = \"google/gemma-2b\", \n", + " train_config = DEFAULT_TRAIN_CONFIG,\n", + " bnb_config = DEFAULT_BNB_CONFIG,\n", + " lora_config = DEFAULT_LORA_CONFIG,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "38d38969", + "metadata": {}, + "source": [ + "We use the following default configs and parameters for fine-tuning. However, you may override these by defining your own configs and passing them to the `GemmaLikelihoodModel`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ef250fc", + "metadata": {}, + "outputs": [], + "source": [ + "\"\"\"\n", + "DEFAULT_BNB_CONFIG = BitsAndBytesConfig(\n", + " load_in_4bit=True,\n", + " bnb_4bit_quant_type=\"nf4\",\n", + " bnb_4bit_compute_dtype=torch.bfloat16\n", + ")\n", + "\n", + "DEFAULT_LORA_CONFIG = LoraConfig(\n", + " r=16,\n", + " lora_dropout=0.05,\n", + " target_modules=[\"q_proj\", \"o_proj\", \"k_proj\", \"v_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n", + " task_type=\"CAUSAL_LM\",\n", + ")\n", + "\n", + "DEFAULT_TRAIN_CONFIG = {\n", + " \"output_dir\": \"checkpoints\",\n", + " \"logging_dir\": \"logging\",\n", + " \"logging_steps\": 25,\n", + " \"eval_steps\": 50, \n", + " \"num_train_epochs\": 2, \n", + " \"per_device_train_batch_size\": 1, \n", + " \"per_device_eval_batch_size\": 1, \n", + " \"evaluation_strategy\": \"steps\",\n", + " \"save_strategy\": \"steps\",\n", + " \"save_steps\": 50,\n", + " \"optim\": \"paged_adamw_8bit\",\n", + " \"learning_rate\": 2e-4,\n", + " \"max_seq_length\": 512,\n", + " \"load_best_model_at_end\": True,\n", + "}\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "04001232", + "metadata": {}, + "source": [ + "Now we can define our redirection model, providing the initialized `gemma_likelihood_model` as our `LikelihoodModel`. The `redirection_attribute_name` represents the name of the meta-data field to save our redirection scores to in the corpus.\n", + "\n", + "We also note that it is possible to define your own `previous_context_selector` and `future_context_selector` to determine which contexts you would use to compute the likelihoods. The functions take as input an utterance and returns the previous (actual and reference) or future contexts for that particular utterance. By default, we use the immediate contexts described in our paper. Note that the default implementation for these contexts assumes we are working with two distinct speaker roles. You may write your own context selectors to customize them for more than two speaker types." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3c4a285c", + "metadata": {}, + "outputs": [], + "source": [ + "redirection = \\\n", + " Redirection(\n", + " likelihood_model = gemma_likelihood_model,\n", + " redirection_attribute_name = \"redirection\"\n", + "# previous_context_selector = , \n", + "# future_context_selector = ,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "69928ba4", + "metadata": {}, + "source": [ + "Now we can call the fit method to fine-tune our model on a subset of the conversations in the corpus. We use a selector function to only fine-tune on the `train` subset of our data. Alternatively, if you already have saved an existing model, you can load it into memory using `load_from_disk`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07b8022d", + "metadata": {}, + "outputs": [], + "source": [ + "redirection.fit(corpus, \n", + " train_selector=lambda convo: \"train\" in convo.meta, \n", + " val_selector=lambda convo: \"val\" in convo.meta\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "bc8a5d20", + "metadata": {}, + "source": [ + "After we have our fine-tuned model, we can then run inference on the test conversations in order to compute the redirection scores. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8f975087", + "metadata": {}, + "outputs": [], + "source": [ + "redirection.transform(corpus, selector=lambda convo: \"test\" in convo.meta)" + ] + }, + { + "cell_type": "markdown", + "id": "01427de1", + "metadata": {}, + "source": [ + "We can then call summarize to view examples of high and low redirecting utterances from each speaker." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a36bccba", + "metadata": {}, + "outputs": [], + "source": [ + "redirection.summarize(corpus)" + ] + }, + { + "cell_type": "markdown", + "id": "087cfcc0", + "metadata": {}, + "source": [ + "We can also perform a FightingWords analysis to see distinguishing bigrams indicating high vs. low redirection from both speakers." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "e48ab266", + "metadata": {}, + "outputs": [], + "source": [ + "from convokit import FightingWords" + ] + }, + { + "cell_type": "markdown", + "id": "3def9310", + "metadata": {}, + "source": [ + "We first label top 20% and bottom 20% of utterances from both speakers based on their redirection scores." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c96d367c", + "metadata": {}, + "outputs": [], + "source": [ + "justice_utts = []\n", + "lawyer_utts = []\n", + "\n", + "for convo in test_convos: \n", + " for utt in convo.iter_utterances():\n", + " if \"redirection\" in utt.meta:\n", + " if utt.meta[\"role\"] == \"justice\":\n", + " justice_utts.append(utt)\n", + " else:\n", + " lawyer_utts.append(utt)\n", + "\n", + "justice_utts = sorted(justice_utts, key=lambda utt: utt.meta[\"redirection\"])\n", + "lawyer_utts = sorted(lawyer_utts, key=lambda utt: utt.meta[\"redirection\"])\n", + "\n", + "justice_threshold = int(len(justice_utts) * 0.20)\n", + "lawyer_threshold = int(len(lawyer_utts) * 0.20)\n", + "\n", + "for utt in justice_utts[:justice_threshold]:\n", + " utt.meta['type'] = \"justice_low\"\n", + "for utt in justice_utts[-justice_threshold:]:\n", + " utt.meta['type'] = \"justice_high\"\n", + "\n", + "for utt in lawyer_utts[:lawyer_threshold]:\n", + " utt.meta['type'] = \"lawyer_low\"\n", + "for utt in lawyer_utts[-lawyer_threshold:]:\n", + " utt.meta['type'] = \"lawyer_high\"" + ] + }, + { + "cell_type": "markdown", + "id": "33167261", + "metadata": {}, + "source": [ + "Here we first show phrasings indicative of low redirection from justices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "49dbe4ed", + "metadata": {}, + "outputs": [], + "source": [ + "fw_justice = FightingWords(ngram_range=(2,2))\n", + "class1 = 'justice_high'\n", + "class2 = 'justice_low'\n", + "fw_justice.fit(corpus, class1_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class1, \n", + " class2_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class2)\n", + "justice = fw_justice.summarize(corpus, plot=False, class1_name=class1, class2_name=class2)\n", + "justice.head(20)" + ] + }, + { + "cell_type": "markdown", + "id": "f816627d", + "metadata": {}, + "source": [ + "Here we show phrasings indicative of high redirection from justices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ad93c86", + "metadata": {}, + "outputs": [], + "source": [ + "justice.tail(20)[::-1]" + ] + }, + { + "cell_type": "markdown", + "id": "34287b06", + "metadata": {}, + "source": [ + "We can perform the corresponding analysis for lawyers as well." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d6c101e7", + "metadata": {}, + "outputs": [], + "source": [ + "fw_lawyer = FightingWords(ngram_range=(2,2))\n", + "class1 = 'lawyer_high'\n", + "class2 = 'lawyer_low'\n", + "fw_lawyer.fit(corpus, class1_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class1, \n", + " class2_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class2)\n", + "lawyer = fw_lawyer.summarize(corpus, plot=False, class1_name=class1, class2_name=class2)\n", + "lawyer.head(20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fbe98f5a", + "metadata": {}, + "outputs": [], + "source": [ + "lawyer.tail(20)[::-1]" + ] + }, + { + "cell_type": "markdown", + "id": "39208446", + "metadata": {}, + "source": [ + "We can also compare the average redirection between justices and lawyers in the cases." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2d52baf8", + "metadata": {}, + "outputs": [], + "source": [ + "convo_justices = []\n", + "convo_lawyers = []\n", + "for convo in test_convos: \n", + " justice = []\n", + " lawyer = []\n", + " for utt in convo.iter_utterances():\n", + " if \"redirection\" in utt.meta:\n", + " if utt.meta[\"role\"] == \"justice\":\n", + " justice.append(utt.meta[\"redirection\"])\n", + " else:\n", + " lawyer.append(utt.meta[\"redirection\"])\n", + " convo_justices.append(np.mean(justice))\n", + " convo_lawyers.append(np.mean(lawyer))\n", + " \n", + "print(\"Average justice:\", np.mean(convo_justices))\n", + "print(\"Average lawyer:\", np.mean(convo_lawyers))\n", + "stat, p_value = wilcoxon(convo_justices, convo_lawyers)\n", + "print(f\"Statistic: {stat}, P-value: {p_value}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/convokit/utteranceLikelihood/__init__.py b/convokit/utteranceLikelihood/__init__.py new file mode 100644 index 00000000..76dcfe9f --- /dev/null +++ b/convokit/utteranceLikelihood/__init__.py @@ -0,0 +1 @@ +from .utteranceLikelihood import * diff --git a/convokit/utteranceLikelihood/utteranceLikelihood.py b/convokit/utteranceLikelihood/utteranceLikelihood.py new file mode 100644 index 00000000..37b943a8 --- /dev/null +++ b/convokit/utteranceLikelihood/utteranceLikelihood.py @@ -0,0 +1,153 @@ +from convokit import Transformer +from convokit.redirection.likelihoodModel import LikelihoodModel +from convokit.redirection.contextSelector import ( + default_previous_context_selector, + default_future_context_selector, +) +import torch +import random +from convokit.redirection.preprocessing import format_conversations, get_chunk_dataset +import numpy as np + + +class UtteranceLikelihood(Transformer): + """ + ConvoKit transformer to compute utterance log-likelihoods derived from + `likelihood_model`. The contexts used to compute the likelihoods can be + defined using `previous_context_selector` and `future_context_selector`, + which are by default the immediate previous and current contexts from + different speaker roles. + + :param likelihood_model: Likelihood model to compute utterance log-likelihoods + :param previous_context_selector: Computes previous contexts + :param future_context_selector: Computes future contexts + :param likelihood_attribute_name: Name of meta-data attribute to + save likelihoods + """ + + def __init__( + self, + likelihood_model, + previous_context_selector=None, + future_context_selector=None, + likelihood_attribute_name="utterance_likelihood", + ): + self.likelihood_model = likelihood_model + self.tokenizer = self.likelihood_model.tokenizer + self.previous_context_selector = previous_context_selector + self.future_context_selector = future_context_selector + self.likelihood_attribute_name = likelihood_attribute_name + + def fit(self, corpus, train_selector=lambda convo: True, val_selector=lambda convo: True): + """ + Fits the UtteranceLikelihood transformer to the corpus by generating the training + and validation data and fine-tuning the likelihood model. + + :param corpus: Corpus to fit transformer + :param train_selector: Selector for train conversations + :param val_selector: Selector for val conversations + """ + train_convos = [convo for convo in corpus.iter_conversations() if train_selector(convo)] + val_convos = [convo for convo in corpus.iter_conversations() if val_selector(convo)] + train_convos_formatted = format_conversations(train_convos) + val_convos_formatted = format_conversations(val_convos) + train_data = get_chunk_dataset( + self.tokenizer, train_convos_formatted, max_tokens=self.likelihood_model.max_length + ) + val_data = get_chunk_dataset( + self.tokenizer, val_convos_formatted, max_tokens=self.likelihood_model.max_length + ) + self.likelihood_model.fit(train_data=train_data, val_data=val_data) + return self + + def transform(self, corpus, selector=lambda convo: True, verbosity=5): + """ + Populates the corpus test data with utterance likelihoods, by first + computing previous and future contexts. + + :param corpus: Corpus to transform + :param selector: Selector for test data + :param verbosity: Verbosity for update messages + + :return: Corpus where test data is labeled with utterance likelihoods + """ + test_convos = [convo for convo in corpus.iter_conversations() if selector(convo)] + previous_contexts = [] + future_contexts = [] + print("Computing contexts") + for i, convo in enumerate(test_convos): + if i % verbosity == 0 and i > 0: + print(i, "/", len(test_convos)) + if self.previous_context_selector is None and self.future_context_selector is None: + contexts, _ = default_previous_context_selector(convo) + previous = {utt_id: pair[0] for utt_id, pair in contexts.items()} + future = {utt_id: pair[1] for utt_id, pair in contexts.items()} + else: + previous = self.previous_context_selector(convo) + future = self.future_context_selector(convo) + + previous_contexts.append(previous) + future_contexts.append(future) + + print("Computing utterance likelihoods") + test_data = (previous_contexts, future_contexts) + likelihoods = self.likelihood_model.transform(test_data, verbosity=verbosity) + + print("Labeling utterance likelihoods") + for i, convo in enumerate(test_convos): + if i % verbosity == 0 and i > 0: + print(i, "/", len(test_convos)) + likelihoods = likelihoods[i] + for utt in convo.iter_utterances(): + if utt.id in likelihoods: + utt.meta[self.likelihood_attribute_name] = likelihoods[utt.id] + + return corpus + + def fit_transform( + self, + train_selector=lambda convo: True, + val_selector=lambda convo: True, + test_selector=lambda convo: True, + verbosity=10, + ): + """ + Fit and transform the model. + + :param corpus: Corpus to transform + :param train_selector: Selector for train data + :param val_selector: Selector for val data + :param test_selector: Selector for test data + :param verbosity: Verbosity for update messages + + :return: Corpus where test data is labeled with utterance likelihoods + """ + self.fit(corpus, train_selector=train_selector, val_selector=val_selector) + return self.transform(corpus, selector=test_selector, verbosity=verbosity) + + def summarize(self, corpus, top_sample_size=10, bottom_sample_size=10): + """ + Summarizes UtteranceLikelihood transformer using utterances with + high and low probabilities. + + :param corpus: Corpus to analyze + :param top_sample_size: Number of utterances to print for high probabilities + :param bottom_sample_size: Number of utterances to print for low probabilities + """ + utts = [utt for utt in corpus.iter_utterances() if self.utterance in utt.meta] + sorted_utts = sorted(utts, key=lambda utt: utt.meta[self.likelihood_attribute_name]) + top_sample_size = min(top_sample_size, len(sorted_utts)) + bottom_sample_size = min(bottom_sample_size, len(sorted_utts)) + print("[high]" + self.likelihood_attribute_name) + for i in range(-1, -1 - top_sample_size, -1): + utt = sorted_utts[i] + print(utt.speaker.id, ":", utt.text, "\n") + + print() + + print("[low]" + self.likelihood_attribute_name) + for i in range(bottom_sample_size): + utt = sorted_utts[i] + print(utt.speaker.id, ":", utt.text, "\n") + + return self diff --git a/docs/source/analysis.rst b/docs/source/analysis.rst index f92d29cc..aedb725d 100644 --- a/docs/source/analysis.rst +++ b/docs/source/analysis.rst @@ -17,3 +17,5 @@ These are the transformers related to generating some analysis of the Corpus. PairedPrediction Ranker SpeakerConvoDiversity + Redirection + UtteranceLikelihood diff --git a/docs/source/conf.py b/docs/source/conf.py index f7c1ede1..417a9ae0 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -64,9 +64,9 @@ # built documents. # # The short X.Y version. -version = "3.0" +version = "3.1" # The full version, including alpha/beta/rc tags. -release = "3.0.2" +release = "3.1.0" # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -375,4 +375,5 @@ "yaml", "bson", "dnspython", + "datasets", ] diff --git a/docs/source/index.rst b/docs/source/index.rst index a73dfe39..89e72710 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -9,7 +9,7 @@ Cornell Conversational Analysis Toolkit (ConvoKit) Documentation This toolkit contains tools to extract conversational features and analyze social phenomena in conversations, using a `single unified interface `_ inspired by (and compatible with) scikit-learn. Several large `conversational datasets `_ are included together with scripts exemplifying the use of the toolkit on these datasets. -More information can be found at our `website `_. The latest version is `3.0.2 `_ (released December 27, 2024). +More information can be found at our `website `_. The latest version is `3.1.0 `_ (released December 30, 2024). Contents -------- diff --git a/docs/source/redirectionAndUtteranceLikelihood.rst b/docs/source/redirectionAndUtteranceLikelihood.rst new file mode 100644 index 00000000..969ab3c8 --- /dev/null +++ b/docs/source/redirectionAndUtteranceLikelihood.rst @@ -0,0 +1,24 @@ +Redirection and Utterance Likelihood +==================================== + +The `Redirection` transformer measures the extent to which utterances +redirect the flow of the conversation, +as described in this +`paper `_. +The redirection effect of an utterance is determined by comparing the likelihood +of its reply given the immediate conversation context vs. a reference context +representing the previous direction of the conversation. + +The `UtteranceLikelihood` transformer is a more generalized module that just +implements log-likelihoods of utterances given a defined conversation context. + +Example usage: `redirection in supreme court oral arguments `_ + +.. automodule:: convokit.redirection.redirection + :members: + +.. automodule:: convokit.redirection.likelihoodModel + :members: + +.. automodule:: convokit.utteranceLikelihood.utteranceLikelihood + :members: \ No newline at end of file diff --git a/setup.py b/setup.py index a7bcd5ca..8cf22b4e 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ author_email="cristian@cs.cornell.edu", url="https://github.com/CornellNLP/ConvoKit", description="ConvoKit", - version="3.0.2", + version="3.1.0", packages=[ "convokit", "convokit.bag_of_words", @@ -27,6 +27,7 @@ "convokit.politenessStrategies", "convokit.prompt_types", "convokit.ranker", + "convokit.redirection", "convokit.text_processing", "convokit.speaker_convo_helpers", "convokit.speakerConvoDiversity", @@ -63,6 +64,13 @@ "numexpr>=2.8.0", "ruff>=0.4.8", "bottleneck", + "accelerate", + "peft", + "bitsandbytes", + "transformers", + "trl>=0.12.2", + "tensorflow>=2.18.0", + "tf-keras>=2.17.0,<3.0.0", ], extras_require={ "craft": ["torch>=0.12"],