Merge pull request #19 from TianyiQ/main
feat(abstractions): add support for multi-turn dialogue manipulation and inference
TianyiQ authored Nov 28, 2024
2 parents 9d1e526 + d553feb commit b898f21
Showing 23 changed files with 215 additions and 69 deletions.
2 changes: 1 addition & 1 deletion algorithms/extrapolative_dpo.py
@@ -6,7 +6,7 @@
import pandas as pd
import json
import datasets
-from src.text_writer import write_log
+from src.text_utils import write_log
from benchmark import JudgeBase, ExamineeBase, PredictJudge
from algorithms.utils.rw_utils import elicit_rw_preference, default_rw_data
from algorithms.utils.extrapolation_utils import extrapolate
2 changes: 1 addition & 1 deletion algorithms/extrapolative_rlhf.py
@@ -6,7 +6,7 @@
import pandas as pd
import json
import datasets
-from src.text_writer import write_log
+from src.text_utils import write_log
from benchmark import JudgeBase, ExamineeBase, PredictJudge
from algorithms.utils.rw_utils import (
    elicit_rw_preference,
2 changes: 1 addition & 1 deletion algorithms/lifelong_dpo.py
@@ -6,7 +6,7 @@
import pandas as pd
import json
import datasets
-from src.text_writer import write_log
+from src.text_utils import write_log
from benchmark import JudgeBase, ExamineeBase, PredictJudge
from algorithms.utils.rw_utils import elicit_rw_preference, default_rw_data
import warnings
2 changes: 1 addition & 1 deletion algorithms/lifelong_rlhf.py
@@ -6,7 +6,7 @@
import pandas as pd
import json
import datasets
-from src.text_writer import write_log
+from src.text_utils import write_log
from benchmark import JudgeBase, ExamineeBase, PredictJudge
from algorithms.utils.rw_utils import (
    elicit_rw_preference,
2 changes: 1 addition & 1 deletion algorithms/utils/extrapolation_utils.py
@@ -7,7 +7,7 @@
import pandas as pd
import json
import datasets
-from src.text_writer import write_log
+from src.text_utils import write_log
import warnings
from tqdm import tqdm
from sympy import binomial
2 changes: 1 addition & 1 deletion algorithms/utils/rw_utils.py
@@ -6,7 +6,7 @@
import pandas as pd
import json
from datasets import load_dataset
-from src.text_writer import write_log
+from src.text_utils import write_log
import src.evaluation.utils as eval_utils
from benchmark import JudgeBase, ExamineeBase, PredictJudge
import warnings
2 changes: 1 addition & 1 deletion build_dataset.py
@@ -1,4 +1,4 @@
-import src.text_writer as tw
+import src.text_utils as tw
import src.cleanser.rule_based_cleanser as rb
import src.cleanser.localllm_cleanser as llm_cleanser
import src.model_training.train_hislm as hislm
77 changes: 59 additions & 18 deletions examples/abstractions/finetuning_datamanip.py
@@ -1,21 +1,29 @@
from src.abstractions import Model, Data, DataFileCollection

-if __name__ == "__main__":
+gemma2b_base = Model(
+    model_name="gemma-2b",
+    model_path_or_repoid="google/gemma-2-2b",  # or specify a local path if you have downloaded the model
+    is_instruct_finetuned=False,
+)

-    gemma2b_base = Model(
-        model_name="gemma-2b",
-        model_path="google/gemma-2-2b",  # or specify a local path if you have downloaded the model
-        is_instruct_finetuned=False,
-    )
+llama8b_instruct = Model(
+    model_name="Llama-3.1-8B-Instruct",
+    model_path_or_repoid="meta-llama/Llama-3.1-8B-Instruct",
+    is_instruct_finetuned=True,
+)

+def continue_pretrain():
    # ============== Continue pretraining from Gemma 2B ==============
+    global gemma2b_c4
    c4_data = Data("c4_demo", data_type="pretrain")
    gemma2b_c4 = gemma2b_base.finetune(
        c4_data, stage="pretrain", algo="full_param", result_model_name="gemma-2b_c4"
    )
    print(gemma2b_c4.is_instruct_finetuned)  # False

+def supervised_finetune():
    # ============== Then do SFT using alpaca data ==============
+    global gemma2b_c4_alpaca
    alpaca_data = Data("alpaca_gpt4_en", data_type="sft")
    gemma2b_c4_alpaca = gemma2b_c4.finetune(
        alpaca_data,
@@ -25,17 +33,7 @@
    )
    print(gemma2b_c4_alpaca.is_instruct_finetuned)  # True
    gemma2b_c4_alpaca.save_permanent()  # saved to output/saved/saved_model/gemma-2b_c4_alpaca

-    # ============== Then do DPO using ORCA data ==============
-    hh_data = Data("orca_rlhf", data_type="preference")
-    gemma2b_c4_alpaca_orca = gemma2b_c4_alpaca.finetune(
-        hh_data,
-        stage="dpo",
-        algo="full_param",
-        result_model_name="gemma-2b_c4_alpaca_orca",
-    )
-    gemma2b_c4_alpaca_orca.save_permanent()  # saved to output/saved/saved_model/gemma-2b_c4_alpaca_orca
-

# ============== Or maybe, we should censor curse words before SFT ==============
def remove_curse_words(sample_dict: dict) -> dict:
    filter = lambda s: (
@@ -56,7 +54,7 @@ def remove_curse_words(sample_dict: dict) -> dict:
)
gemma2b_c4_alpaca_G.save_permanent()  # saved to output/saved/saved_model/gemma-2b_c4_alpaca_G
alpaca_data_G.save_permanent_and_register()  # saved to output/saved/saved_model/alpaca_gpt4_en_G.json & added to llama-factory dataset registry

# ============== What about using our own data (scattered across multiple files in multiple directories) for finetuning? ==============
histext_collection = DataFileCollection(  # build a collection holding json files of year 1826 to 2018
    collection_name="histext_1826_to_2018_collection",
@@ -93,3 +91,46 @@ def remove_nonstr_data(sample_dict: dict) -> dict:
    algo="full_param",
    result_model_name="gemma-2b_histext",
)
+
+def direct_preference_optimization():
+    # ============== Then do DPO using ORCA data ==============
+    global gemma2b_c4_alpaca_orca
+    hh_data = Data("orca_rlhf", data_type="preference")
+    gemma2b_c4_alpaca_orca = gemma2b_c4_alpaca.finetune(
+        hh_data,
+        stage="dpo",
+        algo="full_param",
+        result_model_name="gemma-2b_c4_alpaca_orca",
+    )
+    gemma2b_c4_alpaca_orca.save_permanent()  # saved to output/saved/saved_model/gemma-2b_c4_alpaca_orca
+
+def dialogue_manipulation():
+    # ============== Generating a dialogue, using a model to play the role of both user and assistant ==============
+    global llama8b_instruct
+    dialogue_data = Data(
+        "dialogue_data",
+        data_content=[
+            {
+                "input": "Is Eiffel Tower in Paris?",
+                "history": [
+                    ["What is the capital of France?", "Paris."],
+                ]
+            }
+        ]
+    )
+    dialogue_data = llama8b_instruct.inference(
+        dialogue_data, "dialogue_data", backend="sglang"
+    )
+    dialogue_data = dialogue_data.switch_role_to_user()
+    dialogue_data = llama8b_instruct.inference(
+        dialogue_data, "dialogue_data", backend="sglang"
+    )
+    dialogue_data = dialogue_data.switch_role_to_assistant()
+    print(list(dialogue_data.all_passages()))
+
+
+if __name__ == "__main__":
+    # continue_pretrain()
+    # supervised_finetune()
+    # direct_preference_optimization()
+    dialogue_manipulation()
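For readers skimming this diff: dialogue_manipulation() above uses the new Data role-switching methods so that a single instruct model plays both sides of the conversation. A minimal sketch of how one sample is expected to evolve through that loop (replies are illustrative, and it assumes inference() writes the model's reply into the sample's predict field):

    # Starting sample, as constructed in dialogue_manipulation():
    sample = {
        "input": "Is Eiffel Tower in Paris?",
        "history": [["What is the capital of France?", "Paris."]],
    }
    # 1) inference() adds the assistant's reply, e.g. sample["predict"] == "Yes." (illustrative)
    # 2) switch_role_to_user() folds ("Is Eiffel Tower in Paris?", "Yes.") into
    #    history and re-frames the sample so the model speaks as the user next:
    #    {"system": "You are an assistant tasked with questioning the user...",
    #     "instruction": "Yes.",
    #     "history": [["I am your partner. Please directly ask your first question.",
    #                  "What is the capital of France?"],
    #                 ["Paris.", "Is Eiffel Tower in Paris?"]]}
    # 3) The second inference() generates the user's next question, and
    #    switch_role_to_assistant() re-frames the sample back from the assistant's
    #    perspective, dropping the dialogue-starter placeholder, so that
    #    all_passages() finally yields the accumulated multi-turn dialogue.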
2 changes: 1 addition & 1 deletion examples/abstractions/inference_evaluation.py
@@ -55,7 +55,7 @@ def logprob_example(histllama: Model):
    # Custom models (local or on hub) can be similarly loaded, e.g.:
    # model = Model(
    #     "mixtral-8x7b-instruct-v0.1",
-    #     model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
+    #     model_path_or_repoid="mistralai/Mixtral-8x7B-Instruct-v0.1",
    #     template_type="mistral",
    # )

17 changes: 14 additions & 3 deletions src/abstractions/backends.py
@@ -568,9 +568,20 @@ def dict_to_dialogue_list(
    :rtype: Union[List[Dict[str, str]], List[List[Dict[str, str]]]]
    """
    if isinstance(dic, dict):
-        res = [{"role": "user", "content": dic["input"]}]
-        if "instruction" in dic:
-            res = [{"role": "system", "content": dic["instruction"]}] + res
+        res = []
+
+        if "system" in dic:
+            res = [{"role": "system", "content": dic["system"]}]
+
+        if "history" in dic:
+            for turn in dic["history"]:
+                res.extend([{"role": "user", "content": turn[0]}, {"role": "assistant", "content": turn[1]}])
+
+        if "input" in dic or "instruction" in dic:
+            input = dic.get("input", "")
+            instruction = dic.get("instruction", "")
+            res.append({"role": "user", "content": input + ("\n\n" if input and instruction else "") + instruction})
+
        if purpose == "logprobs" and "predict" in dic and isinstance(dic["predict"], str):
            res.append({"role": "assistant", "content": dic["predict"]})

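For context on the rewritten dict_to_dialogue_list: it now assembles the system prompt, the history turns, and the current query, in that order. A minimal sketch of the expected conversion (values are illustrative):

    dic = {
        "system": "Be concise.",
        "history": [["What is the capital of France?", "Paris."]],
        "input": "Is Eiffel Tower in Paris?",
    }
    # dict_to_dialogue_list(dic) should yield:
    # [
    #     {"role": "system", "content": "Be concise."},
    #     {"role": "user", "content": "What is the capital of France?"},
    #     {"role": "assistant", "content": "Paris."},
    #     {"role": "user", "content": "Is Eiffel Tower in Paris?"},
    # ]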
134 changes: 114 additions & 20 deletions src/abstractions/data.py
@@ -14,7 +14,7 @@
import os
import json
import warnings
-import src.text_writer as tw
+import src.text_utils as tw
from tqdm import tqdm

with open("./src/abstractions/configs/abstractions_config.json", "r") as config_file:
@@ -56,6 +56,13 @@ class Data:
    name2data: Dict[str, Any] = {}
    always_force_rewrite: bool = True
    data_type: Literal["pretrain", "sft", "preference"]
+
+    default_key_fields = {
+        "prompt": "instruction",
+        "query": "input",
+        "response": "output",
+        "history": "history",
+    }

    # check with user before removing a file
    @classmethod
@@ -196,7 +203,7 @@ def transform(
        :param result_data_name: The name of the resulting data. Do not include path in result_data_name.
        :type result_data_name: str
-        :param forced_rewrite: Whether to forcefully rewrite the existing data
+        :param forced_rewrite: Whether to forcefully rewrite the existing file, if there is one.
        :type forced_rewrite: bool = False
        :param max_batch_size: If max_batch_size is specified and is >1, the transformation function must take inputs of type List[Dict] and return a List[Dict].
@@ -220,29 +227,19 @@ def write_dict(sample_dict: Dict):

        def map_key_fields_fn(sample_dict: Dict) -> Dict:
            nonlocal self
-            if "prompt" in self.key_fields and self.key_fields["prompt"] != "instruction":
-                sample_dict["instruction"] = sample_dict[self.key_fields["prompt"]]
-                del sample_dict[self.key_fields["prompt"]]
-            if "query" in self.key_fields and self.key_fields["query"] != "input":
-                sample_dict["input"] = sample_dict[self.key_fields["query"]]
-                del sample_dict[self.key_fields["query"]]
-            if "response" in self.key_fields and self.key_fields["response"] != "output":
-                sample_dict["output"] = sample_dict[self.key_fields["response"]]
-                del sample_dict[self.key_fields["response"]]
+            for k, v in self.default_key_fields.items():
+                if k in self.key_fields and self.key_fields[k] != v:
+                    sample_dict[v] = sample_dict[self.key_fields[k]]
+                    del sample_dict[self.key_fields[k]]

            return sample_dict

        def inv_map_key_fields_fn(sample_dict: Dict) -> Dict:
            nonlocal self
-            if "instruction" in sample_dict and self.key_fields["prompt"] != "instruction":
-                sample_dict[self.key_fields["prompt"]] = sample_dict["instruction"]
-                del sample_dict["instruction"]
-            if "input" in sample_dict and self.key_fields["query"] != "input":
-                sample_dict[self.key_fields["query"]] = sample_dict["input"]
-                del sample_dict["input"]
-            if "output" in sample_dict and self.key_fields["response"] != "output":
-                sample_dict[self.key_fields["response"]] = sample_dict["output"]
-                del sample_dict["output"]
+            for k, v in self.default_key_fields.items():
+                if v in sample_dict and self.key_fields.get(k, v) != v:
+                    sample_dict[self.key_fields[k]] = sample_dict[v]
+                    del sample_dict[v]

            return sample_dict
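# Example (illustrative, not part of the commit): with
#     self.key_fields == {"prompt": "question", "response": "answer"}
# map_key_fields_fn({"question": "2+2?", "answer": "4"})
#     -> {"instruction": "2+2?", "output": "4"}
# and inv_map_key_fields_fn performs the inverse renaming, so a transform can
# operate on canonical field names regardless of the dataset's own schema.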

@@ -284,6 +281,94 @@ def inv_map_key_fields_fn(sample_dict: Dict) -> Dict:
        result.key_fields = self.key_fields.copy()
        return result

+    def move_current_to_history(self):
+        """
+        Move the current dialogue turn (the prompt/question field and the response/predict field) into the history field.
+
+        :return: The data after the operation.
+        :rtype: Data.
+        """
+
+        def move_to_history_fn(sample_dict: Dict) -> Dict:
+            if sample_dict.get("instruction", "") or sample_dict.get("input", "") or sample_dict.get("output", "") or sample_dict.get("predict", ""):
+                assert (sample_dict.get("instruction", "") or sample_dict.get("input", "")) and (sample_dict.get("output", "") or sample_dict.get("predict", ""))
+                sample_dict["history"] = sample_dict.get("history", []) + [
+                    [
+                        sample_dict.get("instruction", "")
+                        + ("\n\n" if sample_dict.get("instruction", "") and sample_dict.get("input", "") else "")
+                        + sample_dict.get("input", ""),
+                        sample_dict.get("output", "") + sample_dict.get("predict", ""),
+                    ]
+                ]
+                sample_dict.pop("instruction", None)
+                sample_dict.pop("input", None)
+                sample_dict.pop("output", None)
+                sample_dict.pop("predict", None)
+
+            return sample_dict
+
+        return self.transform(move_to_history_fn, self.data_name, forced_rewrite=True, map_key_fields=True)
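# Example (illustrative, not part of the commit), assuming inference() has
# filled "predict":
#     before: {"instruction": "Summarize:", "input": "long text",
#              "predict": "a summary", "history": []}
#     after:  {"history": [["Summarize:\n\nlong text", "a summary"]]}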

+    def switch_role_to_user(self, user_system_prompt: str = None, dialogue_starter: str = None):
+        """
+        Reverse the roles in the dialogue, folding the current prompt/response turn into history, so that the next turn to be generated belongs to the user rather than the assistant.
+
+        :param user_system_prompt: The system prompt for the user role.
+        :type user_system_prompt: str = None
+
+        :param dialogue_starter: Placeholder message for the "zeroth" dialogue turn by the assistant that prompts the user to start the conversation.
+        :type dialogue_starter: str = None
+
+        :return: The data after the operation.
+        :rtype: Data.
+        """
+        if user_system_prompt is None:
+            user_system_prompt = "You are an assistant tasked with questioning the user, aka your partner. Ask informed questions to guide the conversation, follow up on the user's responses, and generally follow a natural conversation flow. Don't be too courteous; be concise."
+
+        if dialogue_starter is None:
+            dialogue_starter = "I am your partner. Please directly ask your first question."
+
+        moved_to_history = self.move_current_to_history()
+
+        def switch_role_to_user_fn(sample_dict: Dict) -> Dict:
+            assert not (sample_dict.get("instruction", "") or sample_dict.get("input", "") or sample_dict.get("output", "") or sample_dict.get("predict", ""))
+
+            all_histories = [h[i] for h in sample_dict.get("history", []) for i in range(2)]
+            all_histories = [dialogue_starter] + all_histories
+            assert len(all_histories) % 2 == 1
+            sample_dict["history"] = [[all_histories[i], all_histories[i + 1]] for i in range(0, len(all_histories) - 1, 2)]
+            sample_dict["instruction"] = all_histories[-1]
+            sample_dict["system"] = user_system_prompt
+            return sample_dict
+
+        return moved_to_history.transform(switch_role_to_user_fn, self.data_name, forced_rewrite=True, map_key_fields=True)
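# Example (illustrative, not part of the commit): starting from
#     {"history": [["Q1", "A1"]], "instruction": "Q2", "predict": "A2"}
# the current turn is folded into history, the dialogue starter is prepended,
# and turns are re-paired from the user's perspective:
#     {"system": user_system_prompt,
#      "instruction": "A2",
#      "history": [[dialogue_starter, "Q1"], ["A1", "Q2"]]}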

+    def switch_role_to_assistant(self, assistant_system_prompt: str = None):
+        """
+        Reverse the roles in the dialogue, folding the current prompt/response turn into history, so that the next turn to be generated belongs to the assistant rather than the user.
+
+        :param assistant_system_prompt: The system prompt of the assistant role.
+        :type assistant_system_prompt: str = None
+
+        :return: The data after the operation.
+        :rtype: Data.
+        """
+        if assistant_system_prompt is None:
+            assistant_system_prompt = "Please answer the user's questions. Be concise and not overly courteous, but be informative and provide all necessary details."
+
+        moved_to_history = self.move_current_to_history()
+
+        def switch_role_to_assistant_fn(sample_dict: Dict) -> Dict:
+            assert not (sample_dict.get("instruction", "") or sample_dict.get("input", "") or sample_dict.get("output", "") or sample_dict.get("predict", ""))
+
+            all_histories = [h[i] for h in sample_dict.get("history", []) for i in range(2)]
+            assert len(all_histories) % 2 == 0
+            sample_dict["history"] = [[all_histories[i], all_histories[i + 1]] for i in range(1, len(all_histories) - 1, 2)]
+            sample_dict["instruction"] = all_histories[-1]
+            sample_dict["system"] = assistant_system_prompt
+            return sample_dict
+
+        return moved_to_history.transform(switch_role_to_assistant_fn, self.data_name, forced_rewrite=True, map_key_fields=True)
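# Example (illustrative, not part of the commit): continuing the sketch above,
# after the model (as user) replies "Q3" into "predict", switching back gives:
#     {"system": assistant_system_prompt,
#      "instruction": "Q3",
#      "history": [["Q1", "A1"], ["Q2", "A2"]]}
# i.e. the dialogue-starter placeholder is dropped and roles are restored.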

    def manage_llama_factory_registration(
        self, operation: Literal["add", "remove", "query"], forced_update: bool = True
    ) -> bool:
@@ -372,6 +457,7 @@ def set_key_fields(
        query_field_name: Optional[str] = None,
        response_field_name: Optional[str] = None,
        system_field_name: Optional[str] = None,
+        history_field_name: Optional[str] = None,
        suppress_registration_update: bool = False,
        **kwargs,
    ) -> None:
@@ -393,6 +479,9 @@
        :param system_field_name: The name of the system field
        :type system_field_name: Optional[str] = None

+        :param history_field_name: The name of the history field
+        :type history_field_name: Optional[str] = None
+
        :param suppress_registration_update: Whether to suppress the update of the registration
        :type suppress_registration_update: bool = False
@@ -428,6 +517,11 @@
            del self.key_fields["system"]
        elif system_field_name:
            self.key_fields["system"] = system_field_name
+
+        if history_field_name == "" and "history" in self.key_fields:
+            del self.key_fields["history"]
+        elif history_field_name:
+            self.key_fields["history"] = history_field_name

        if isinstance(kwargs, dict):
            for k, v in kwargs.items():
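For reference, the new history_field_name parameter plugs into set_key_fields like the existing overrides; a minimal usage sketch (dataset and field names are hypothetical):

    data = Data("my_dialogues", data_type="sft")
    data.set_key_fields(
        prompt_field_name="question",
        response_field_name="answer",
        history_field_name="past_turns",  # map a non-standard history field
    )
    # Passing history_field_name="" removes a previously set history mapping.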