Merge pull request #13 from TianyiQ/main
feat(evaluation): support for obtaining logprobs
TianyiQ authored Nov 9, 2024
2 parents 2b086e7 + 497d8ac commit a6a0448
Showing 2 changed files with 50 additions and 13 deletions.
45 changes: 35 additions & 10 deletions src/abstractions/backends.py
@@ -214,7 +214,10 @@ def start_inference_backend(
num_gpus: int = None,
template_type: Literal["auto", "alpaca", "mistral"] = "auto",
) -> Tuple[subprocess.Popen, Callable]:
"""Start an inference backend for a given model.
"""Start an inference backend for a given model.
Returns a tuple containing the backend process and the function to process a batch of samples.
When purpose is "logprobs", the returned function will return the log probability of the prompt text itself, without generating any text. The probability will be stored in the "logprob" field of the output dictionary, with all other fields staying the same.
When purpose is "responses", the returned function will generate a response to the prompt text. The response will be stored in the "predict" field of the output dictionary, with all other fields staying the same.
:param model_repoid_or_path: The model repo ID or path (e.g., "meta-llama/Llama-3.1-8B-Instruct").
:type model_repoid_or_path: str
@@ -240,6 +243,9 @@ def start_inference_backend(
num_gpus = torch.cuda.device_count()

if backend_type == "vllm":

if purpose == "logprobs":
raise ValueError("VLLM backend does not support logprobs purpose.")

LLM, SamplingParams, destroy_model_parallel = import_from_vllm()

@@ -399,6 +405,7 @@ def vllm_process_batch(
def get_response(
s, conversation: List, temperature: float = 0.2, max_tokens: int = 256
) -> str:
nonlocal purpose

for turn in conversation:
if turn["role"] == "assistant":
@@ -410,17 +417,26 @@ def get_response(
else:
raise ValueError(f"Unknown role: {turn['role']}")

s += sgl.assistant_begin()
if purpose == "responses":
s += sgl.assistant_begin()

s += sgl.gen(
"NA",
max_tokens=max_tokens,
return_logprob=False,
max_tokens=(max_tokens if purpose == "responses" else 0),
return_logprob=(purpose == "logprobs"),
logprob_start_len=(None if purpose == "responses" else 0),
temperature=temperature,
)

def sglang_process_batch(
sample_dicts: List[dict], temperature: float = 0.2, max_tokens: int = 256
) -> List[dict]:
"""Process a batch of samples using the sglang backend.
When purpose is "logprobs", it will return the log probability of the prompt text itself, without generating any text. The probability will be stored in the "logprob" field of the output dictionary, with all other fields staying the same.
When purpose is "responses", it will generate a response to the prompt text. The response will be stored in the "predict" field of the output dictionary, with all other fields staying the same.
"""
nonlocal purpose

if not os.environ.get("ALLOW_EMPTY_INPUT") or not eval(
os.environ.get("ALLOW_EMPTY_INPUT")
):
@@ -434,7 +450,7 @@ def sglang_process_batch(
found = 1
dic["input"] = dic["instruction"]

dialogues = dict_to_dialogue_list(sample_dicts)
dialogues = dict_to_dialogue_list(sample_dicts, purpose)
output = get_response.run_batch(
[
{
@@ -490,9 +506,14 @@ def sglang_process_batch(
)

for dic, out in zip(sample_dicts, output):
dic["predict"] = (
out["NA"] if out.get_meta_info("NA") is not None else None
)
if purpose == "logprobs":
dic["logprob"] = sum(
x[0] for x in list(out.get_meta_info("NA")['input_token_logprobs']) if x[0] is not None
)
else:
dic["predict"] = (
out["NA"] if out.get_meta_info("NA") is not None else None
)

return sample_dicts

@@ -502,7 +523,7 @@ def sglang_process_batch(


def dict_to_dialogue_list(
dic: Union[dict, List[dict]]
dic: Union[dict, List[dict]], purpose: Literal["responses", "logprobs"] = "responses"
) -> Union[List[Dict[str, str]], List[List[Dict[str, str]]]]:
"""Transform a dictionary into a list of dialogue turns in OpenAI format.
@@ -512,10 +533,14 @@ def dict_to_dialogue_list(
:rtype: Union[List[Dict[str, str]], List[List[Dict[str, str]]]]
"""
if isinstance(dic, dict):
return [
res = [
{"role": "system", "content": dic["instruction"]},
{"role": "user", "content": dic["input"]},
]
if purpose == "logprobs" and "predict" in dic:
res.append({"role": "assistant", "content": dic["predict"]})

return res

return [dict_to_dialogue_list(d) for d in dic]

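For context, a minimal usage sketch of the new logprobs path in backends.py. The import path, model name, and sample contents are illustrative assumptions; only the start_inference_backend parameters and the "instruction"/"input"/"predict"/"logprob" fields come from this diff. Note that purpose="logprobs" requires the sglang backend, since the vllm branch now raises ValueError:

from src.abstractions.backends import start_inference_backend  # assumed import path

# Start an sglang backend that scores prompts instead of generating completions.
backend, process_batch = start_inference_backend(
    "meta-llama/Llama-3.1-8B-Instruct",  # example model repo ID from the docstring
    "sglang",                            # "vllm" raises ValueError for purpose="logprobs"
    purpose="logprobs",
)

samples = [
    {
        "instruction": "You are a helpful assistant.",
        "input": "What is the capital of France?",
        "predict": "Paris.",  # optional; appended as the assistant turn by dict_to_dialogue_list
    }
]

# Each returned dict gains a "logprob" field: the sum of input-token log probabilities.
scored = process_batch(samples)
print(scored[0]["logprob"])
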
18 changes: 15 additions & 3 deletions src/abstractions/model.py
@@ -41,14 +41,15 @@ def inference_standalone(
query_field_name: str,
temperature: float,
backend_type: Literal["sglang", "vllm"],
purpose: Literal["responses", "logprobs"],
conn: multiprocessing.connection.Connection,
):
backend, process_batch = start_inference_backend(
model_path,
backend_type,
purpose="responses",
num_gpus=num_gpus,
template_type=template_type,
purpose=purpose,
)

data = Data(data_name="temporary", data_type="sft", data_path=data_path)
@@ -661,6 +662,7 @@ def inference(
backend: Literal["sglang", "vllm", "deepspeed", "serial"] = "sglang",
batch_size_multiplier_log2: int = 0,
temperature=0.0,
purpose: Literal["responses", "logprobs"] = "responses",
) -> Union[Data, List[Dict[str, str]]]:
"""Performance inference on a dataset (currently only instruction datasets are tested, with the same format as SFT datasets),
and
@@ -679,6 +681,9 @@ def inference(
:param temperature: The temperature parameter
:type temperature: float = 0.0
:param purpose: The purpose of the inference. It can be "responses" or "logprobs". If "logprobs", the log probability of the prompt itself (and the assistant response supplied in the `predict` field, if it exists) is returned in the `logprob` field of the resulting dataset, without doing any completion. If "responses", the completion text is saved in the `predict` field of the resulting dataset.
:type purpose: Literal["responses", "logprobs"] = "responses"
:return: returns the resulting dataset (completion text saved in the `predict` field of dicts; other fields are preserved).
:rtype: Union[Data, List[Dict[str, str]]].
Expand Down Expand Up @@ -716,6 +721,12 @@ def inference(
warnings.warn("vllm is disabled. Switching to sglang backend.")
backend = "sglang"

if purpose == "logprobs" and backend != "sglang":
warnings.warn(
"Logprobs are only supported with backend=sglang. Switching to sglang backend."
)
backend = "sglang"

if input_is_data:
assert (
data.data_type != "pretrain" or backend == "deepspeed"
@@ -737,7 +748,7 @@ def inference(

result = (
self.__inference_parallel_segregated(
data, result_data_name, temperature, backend
data, result_data_name, temperature, backend, purpose
)
if backend in ["vllm", "sglang"]
else self.__inference_parallel_deepspeed(
@@ -770,7 +781,7 @@ def inference(
return result

def __inference_parallel_segregated(
self, data: Data, result_data_name: str, temperature: float, backend_type: str
self, data: Data, result_data_name: str, temperature: float, backend_type: str, purpose: str
) -> Data:
"""sglang/vllm implementation for `inference()`, but performed in a separate process to free up GPU memory. This is the recommended implementation, due to its superior speed and robustness."""
data_path = data.data_path
@@ -799,6 +810,7 @@ def __inference_parallel_segregated(
query_field_name,
temperature,
backend_type,
purpose,
child_conn,
),
)
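Similarly, a hypothetical call site for the new purpose argument on Model.inference in model.py. The Model constructor arguments and the list-of-dicts input are assumptions based on the surrounding docstring; only backend, temperature, and purpose appear in this diff:

from src.abstractions.model import Model  # assumed import path

model = Model("meta-llama/Llama-3.1-8B-Instruct")  # assumed constructor

# With purpose="logprobs", no completion is generated; each dict gains a "logprob"
# field scoring the prompt plus any assistant text already present in "predict".
scored = model.inference(
    [{"instruction": "You are a helpful assistant.", "input": "2 + 2 = ?", "predict": "4"}],
    backend="sglang",    # other backends are switched to sglang with a warning
    temperature=0.0,
    purpose="logprobs",
)
print(scored[0]["logprob"])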
