Skip to content

Commit

Permalink
Merge pull request #18 from TianyiQ/main
Browse files (browse the repository at this point in the history)
fix(evaluation): logprob inference on large batches
  • Loading branch information
TianyiQ authored Nov 22, 2024
2 parents 6108788 + adeb8a6 commit 9d1e526
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
20 changes: 17 additions & 3 deletions src/abstractions/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import warnings
import tqdm
from transformers import AutoTokenizer
import random

# create output directories
os.makedirs("./output/benchmark_results", exist_ok=True)
Expand All @@ -41,8 +42,9 @@
os.makedirs("./output/saved/saved_data/", exist_ok=True)
os.makedirs("./output/downloaded", exist_ok=True)

random.seed(time.time())
MY_USERNAME = pwd.getpwuid(os.getuid()).pw_name
PORT_NUM = 17785
PORT_NUM = 17785 + random.randint(0, 2000)


# escape spaces in paths
Expand Down Expand Up @@ -425,9 +427,9 @@ def get_response(
s += sgl.assistant_begin()

if options:
print("Options provided:", options)
s += sgl.gen(
"NA",
max_tokens=max(len(x) for x in options)+10,
choices=options,
)

Expand Down Expand Up @@ -493,6 +495,7 @@ def sglang_process_batch(
"conversation": dialogues[k],
"temperature": temperature,
"max_tokens": max_tokens,
"options": options_lists[k],
}
for k in bad_indices
],
Expand All @@ -518,8 +521,16 @@ def sglang_process_batch(
warnings.warn(
f"{count} cases still not completed after 10 retries. Use NO_SGLANG=1 to disable sglang backend."
)

if count > 100 or count / len(output) > 0.01:
raise Exception(f"Too many cases ({count}) still not completed.")

failure_count = 0
for dic, out in zip(sample_dicts, output):
if out.get_meta_info("NA") is None:
failure_count += 1
continue

if purpose == "logprobs":
if "predict" in dic and isinstance(dic["predict"], list):
dic["logprob"] = [
Expand All @@ -535,7 +546,10 @@ def sglang_process_batch(
dic["predict"] = (
out["NA"] if out.get_meta_info("NA") is not None else None
)


if failure_count > count:
raise Exception(f"More actual failures ({failure_count}) than cases not completed ({count}), which is unexpected.")

return sample_dicts

return backend, sglang_process_batch
Expand Down
2 changes: 1 addition & 1 deletion src/abstractions/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def inference_standalone(
if hasattr(Model, "always_force_rewrite")
else False
),
max_batch_size=4096,
max_batch_size=262144,
map_key_fields=True,
)
print("Job finished.")
Expand Down
2 changes: 1 addition & 1 deletion src/evaluation/test_eval_01.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os, json
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

from ..abstractions import Model
from .utils import generate_alpaca, _collect
Expand Down

0 comments on commit 9d1e526

Please sign in to comment.