Merge pull request #17 from TianyiQ/main
feat(evaluation): support incremental logprob calculation
TianyiQ authored Nov 21, 2024
2 parents 6bac550 + 95e18a6 commit 3f88534
Showing 4 changed files with 110 additions and 34 deletions.
50 changes: 37 additions & 13 deletions examples/abstractions/inference_evaluation.py
@@ -1,21 +1,10 @@
 from src.abstractions import Model, Data
 from src.download_models import download_all_models
 
-if __name__ == "__main__":
-
-    download_all_models(download_8B=True, download_70B=False)
-    histllama = Model(model_name="8B-C021-instruct", is_instruct_finetuned=True)
-
-    # Custom models (local or on hub) can be similarly loaded, e.g.:
-    # model = Model(
-    #     "mixtral-8x7b-instruct-v0.1",
-    #     model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
-    #     template_type="mistral",
-    # )
-
+def dataset_inference_example(histllama: Model):
     alpaca_data = Data("alpaca_gpt4_data_en", data_type="sft")
 
-    # For custom datasets (must be local), their data fields need to be registered before used in inference, e.g.:
+    # For custom datasets (either in-memory or stored locally), their data fields need to be registered before used in inference, e.g.:
     # custom_data.set_key_fields(
     #     prompt_field_name="instruction", query_field_name="input"
     # )
@@ -37,3 +26,38 @@
 
     vec = histllama.evaluate()
     print("Preference vector: ", vec)
+
+
+def logprob_example(histllama: Model):
+    custom_data = Data(
+        "custom_data",
+        data_type="sft",
+        data_content = [
+            {
+                "input": "What is the capital of France?",
+                "predict": ["Paris", "Washington D.C.", "London", "Berlin"],
+            }
+        ],
+    )
+    custom_data.set_key_fields(query_field_name="input")
+
+    logprob_output: Data = histllama.inference(
+        custom_data, "8B-C021-infer-custom-deepspeed", backend="sglang", purpose="logprobs"
+    )
+    print(list(logprob_output.all_passages()))
+    # [{'predict': ['Paris', 'Washington D.C.', 'London', 'Berlin'], 'input': 'What is the capital of France?', 'logprob': [-9.92294692993164, -17.21290510520339, -11.677074432373047, -12.903636932373047]}]
+
+
+if __name__ == "__main__":
+
+    download_all_models(download_8B=True, download_70B=False)
+    histllama = Model(model_name="8B-C021-instruct", is_instruct_finetuned=True)
+    # Custom models (local or on hub) can be similarly loaded, e.g.:
+    # model = Model(
+    #     "mixtral-8x7b-instruct-v0.1",
+    #     model_path="mistralai/Mixtral-8x7B-Instruct-v0.1",
+    #     template_type="mistral",
+    # )
+
+    dataset_inference_example(histllama)
+    logprob_example(histllama)
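
The record printed by logprob_example pairs each candidate in "predict" with the log-probability at the same index in "logprob". As a minimal sketch of how one might consume that record (not part of this commit; logprob values truncated from the output above), picking the model's preferred answer is an argmax over the list:

    # Minimal sketch: select the candidate with the highest log-probability.
    record = {
        "input": "What is the capital of France?",
        "predict": ["Paris", "Washington D.C.", "London", "Berlin"],
        "logprob": [-9.9229, -17.2129, -11.6771, -12.9036],
    }
    best = max(range(len(record["logprob"])), key=record["logprob"].__getitem__)
    print(record["predict"][best])  # -> Paris
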
52 changes: 34 additions & 18 deletions src/abstractions/backends.py
@@ -42,7 +42,7 @@
 os.makedirs("./output/downloaded", exist_ok=True)
 
 MY_USERNAME = pwd.getpwuid(os.getuid()).pw_name
-PORT_NUM = 14285
+PORT_NUM = 17785
 
 
 # escape spaces in paths
@@ -407,7 +407,7 @@ def vllm_process_batch(
 
     @sgl.function
     def get_response(
-        s, conversation: List, temperature: float = 0.2, max_tokens: int = 256
+        s, conversation: List, temperature: float = 0.2, max_tokens: int = 256, options: list = []
     ) -> str:
         nonlocal purpose
 
@@ -424,13 +424,21 @@ def get_response(
         if purpose == "responses":
             s += sgl.assistant_begin()
 
-        s += sgl.gen(
-            "NA",
-            max_tokens=(max_tokens if purpose == "responses" else 0),
-            return_logprob=(purpose == "logprobs"),
-            logprob_start_len=(None if purpose == "responses" else 0),
-            temperature=temperature,
-        )
+        if options:
+            print("Options provided:", options)
+            s += sgl.gen(
+                "NA",
+                choices=options,
+            )
+
+        else:
+            s += sgl.gen(
+                "NA",
+                max_tokens=(max_tokens if purpose == "responses" else 0),
+                return_logprob=(purpose == "logprobs"),
+                logprob_start_len=(None if purpose == "responses" else 0),
+                temperature=temperature,
+            )
 
     def sglang_process_batch(
         sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = 256
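
For context, sgl.gen with a choices argument asks SGLang to score each listed option and keep the most likely one, rather than decode freely. A standalone sketch of the same primitive (illustrative only, not part of this commit; assumes an SGLang backend has already been set via sgl.set_default_backend):

    import sglang as sgl

    @sgl.function
    def pick_answer(s, question: str, options: list):
        # Constrained generation: SGLang evaluates the likelihood of each
        # option and selects the highest-scoring one.
        s += sgl.user(question)
        s += sgl.assistant(sgl.gen("answer", choices=options))

    # state = pick_answer.run(
    #     question="What is the capital of France?",
    #     options=["Paris", "Washington D.C.", "London", "Berlin"],
    # )
    # print(state["answer"])
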
@@ -455,14 +463,16 @@ def sglang_process_batch(
                 dic["input"] = dic["instruction"]
 
         dialogues = dict_to_dialogue_list(sample_dicts, purpose)
+        options_lists = [(dic["predict"] if "predict" in dic and isinstance(dic["predict"], list) else []) for dic in sample_dicts]
         output = get_response.run_batch(
             [
                 {
                     "conversation": dialogue,
                     "temperature": temperature,
                     "max_tokens": max_tokens,
+                    "options": options,
                 }
-                for dialogue in dialogues
+                for dialogue, options in zip(dialogues, options_lists)
             ],
             progress_bar=True,
         )
@@ -511,9 +521,16 @@ def sglang_process_batch(
 
         for dic, out in zip(sample_dicts, output):
             if purpose == "logprobs":
-                dic["logprob"] = sum(
-                    x[0] for x in list(out.get_meta_info("NA")['input_token_logprobs']) if x[0] is not None
-                )
+                if "predict" in dic and isinstance(dic["predict"], list):
+                    dic["logprob"] = [
+                        sum(x[0] for x in y if x[0] is not None)
+                        for y in list(out.get_meta_info("NA")['input_token_logprobs'])
+                    ]
+                    assert len(dic["logprob"]) == len(dic["predict"])
+                else:
+                    dic["logprob"] = sum(
+                        x[0] for x in list(out.get_meta_info("NA")['input_token_logprobs']) if x[0] is not None
+                    )
             else:
                 dic["predict"] = (
                     out["NA"] if out.get_meta_info("NA") is not None else None
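
A background note on the summation above (a general property of autoregressive models, not specific to this commit): a sequence's probability factors into per-token conditionals, so the log-probability of a whole candidate is the sum of its token log-probabilities. The None entries skipped here are positions SGLang reports without a logprob, typically the very first token:

    import math

    # log P(t1, t2, t3) = log P(t1) + log P(t2 | t1) + log P(t3 | t1, t2)
    token_logprobs = [-2.3, -0.5, -0.1]
    sequence_logprob = sum(token_logprobs)      # -2.9
    sequence_prob = math.exp(sequence_logprob)  # back to a raw probability
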
@@ -537,11 +554,10 @@ def dict_to_dialogue_list(
     :rtype: Union[List[Dict[str, str]], List[List[Dict[str, str]]]]
     """
     if isinstance(dic, dict):
-        res = [
-            {"role": "system", "content": dic["instruction"]},
-            {"role": "user", "content": dic["input"]},
-        ]
-        if purpose == "logprobs" and "predict" in dic:
+        res = [{"role": "user", "content": dic["input"]}]
+        if "instruction" in dic:
+            res = [{"role": "system", "content": dic["instruction"]}] + res
+        if purpose == "logprobs" and "predict" in dic and isinstance(dic["predict"], str):
            res.append({"role": "assistant", "content": dic["predict"]})
 
         return res
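
After this rewrite the system turn is emitted only when an "instruction" field is present, and a string-valued "predict" becomes the assistant turn to be scored; list-valued "predict" fields are instead passed separately as options and scored via the choices branch above. A simplified replica of the new logic, for illustration only:

    def to_dialogue(dic: dict, purpose: str) -> list:
        # Mirrors the rewritten dict_to_dialogue_list for a single dict.
        res = [{"role": "user", "content": dic["input"]}]
        if "instruction" in dic:
            res = [{"role": "system", "content": dic["instruction"]}] + res
        if purpose == "logprobs" and isinstance(dic.get("predict"), str):
            res.append({"role": "assistant", "content": dic["predict"]})
        return res

    print(to_dialogue({"input": "Capital of France?", "predict": "Paris"}, "logprobs"))
    # [{'role': 'user', 'content': 'Capital of France?'},
    #  {'role': 'assistant', 'content': 'Paris'}]
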
41 changes: 38 additions & 3 deletions src/abstractions/data.py
@@ -180,6 +180,7 @@ def transform(
         forced_rewrite: bool = False,
         max_batch_size: int = 1,
         keep_key_fields: bool = True,
+        map_key_fields: bool = False,
     ) -> "Data":
         """
         Apply transformation to every element of the current dataset (in the format of a json list of json dicts where the values are of mutable or immutable types), and returns a Data instance containing the resulting dataset.
@@ -216,31 +217,65 @@ def write_dict(sample_dict: Dict):
                 is_first = False
             out_file.write(json.dumps(sample_dict))
             # out_file.flush()
 
+        def map_key_fields_fn(sample_dict: Dict) -> Dict:
+            nonlocal self
+            if "prompt" in self.key_fields and self.key_fields["prompt"] != "instruction":
+                sample_dict["instruction"] = sample_dict[self.key_fields["prompt"]]
+                del sample_dict[self.key_fields["prompt"]]
+            if "query" in self.key_fields and self.key_fields["query"] != "input":
+                sample_dict["input"] = sample_dict[self.key_fields["query"]]
+                del sample_dict[self.key_fields["query"]]
+            if "response" in self.key_fields and self.key_fields["response"] != "output":
+                sample_dict["output"] = sample_dict[self.key_fields["response"]]
+                del sample_dict[self.key_fields["response"]]
+
+            return sample_dict
+
+        def inv_map_key_fields_fn(sample_dict: Dict) -> Dict:
+            nonlocal self
+            if "instruction" in sample_dict and self.key_fields["prompt"] != "instruction":
+                sample_dict[self.key_fields["prompt"]] = sample_dict["instruction"]
+                del sample_dict["instruction"]
+            if "input" in sample_dict and self.key_fields["query"] != "input":
+                sample_dict[self.key_fields["query"]] = sample_dict["input"]
+                del sample_dict["input"]
+            if "output" in sample_dict and self.key_fields["response"] != "output":
+                sample_dict[self.key_fields["response"]] = sample_dict["output"]
+                del sample_dict["output"]
+
+            return sample_dict
+
         with open(out_path, "w") as out_file:
             out_file.write("[")
             is_first = True
 
             if max_batch_size == 1:
                 for element in tw.read_json_memory_efficient(self.data_path):
+                    if map_key_fields:
+                        element = map_key_fields_fn(element)
+
                     transformed = transformation(element)
                     if transformed is not None:
-                        write_dict(transformed)
+                        write_dict(transformed if not map_key_fields else inv_map_key_fields_fn(transformed))
 
             else:
                 buffer = []
 
                 for element in tw.read_json_memory_efficient(self.data_path):
+                    if map_key_fields:
+                        element = map_key_fields_fn(element)
+
                     buffer.append(element)
                     if len(buffer) == max_batch_size:
                         for e in transformation(buffer):
-                            write_dict(e)
+                            write_dict(e if not map_key_fields else inv_map_key_fields_fn(e))
                         buffer = []
                         out_file.flush()
 
                 if buffer:
                     for e in transformation(buffer):
-                        write_dict(e)
+                        write_dict(e if not map_key_fields else inv_map_key_fields_fn(e))
 
         out_file.write("\n]")
 
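
The net effect of the two helpers: with map_key_fields=True, each sample is renamed from the user's registered field names to the canonical instruction/input/output keys before the transformation runs, then renamed back afterwards, so callers keep their original schema. A hedged sketch of the forward direction (the field name "question" is invented for illustration):

    def map_keys(sample: dict, key_fields: dict) -> dict:
        # Simplified forward mapping: rename the registered query field
        # to the canonical "input" key expected by the backends.
        if "query" in key_fields and key_fields["query"] != "input":
            sample["input"] = sample.pop(key_fields["query"])
        return sample

    print(map_keys({"question": "What is the capital of France?"}, {"query": "question"}))
    # -> {'input': 'What is the capital of France?'}
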
1 change: 1 addition & 0 deletions src/abstractions/model.py
@@ -65,6 +65,7 @@ def inference_standalone(
             else False
         ),
         max_batch_size=262144,
+        map_key_fields=True,
     )
     print("Job finished.")
     conn.send(result_data.data_path)
