API inference support
ZubinGou committed Nov 28, 2023
1 parent 84959f1 commit a1342f8
Showing 7 changed files with 347 additions and 70 deletions.
58 changes: 32 additions & 26 deletions src/infer/inference.py
@@ -1,5 +1,7 @@
"""
This file is based on: https://github.com/microsoft/ProphetNet/tree/master/CRITIC
This script supports vLLM batch inference with cot/pal/tora prompts.
It also supports inference with fine-tuned models like WizardMath/ToRA.
Code based on: https://github.com/microsoft/ProphetNet/tree/master/CRITIC
"""
import random
import os
@@ -19,29 +21,34 @@
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--data_name", default="gsm8k", type=str)
parser.add_argument("--data_dir", default="./data", type=str)
parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
parser.add_argument("--prompt_type", default="pal", type=str)
parser.add_argument("--output_dir", default="./output", type=str)
parser.add_argument("--prompt_type", default="tora", type=str)
parser.add_argument("--split", default="test", type=str)
parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--start", default=0, type=int)
parser.add_argument("--end", default=-1, type=int)
parser.add_argument("--temperature", default=0, type=float)
parser.add_argument("--n_sampling", default=1, type=int)
parser.add_argument("--top_p", default=0.95, type=float)
parser.add_argument("--top_p", default=1, type=float)
parser.add_argument("--max_tokens_per_call", default=1024, type=int)
parser.add_argument("--shuffle", action="store_true")
parser.add_argument("--use_train_prompt_format", action="store_true")
args = parser.parse_args()
args.top_p = 1 if args.temperature == 0 else args.top_p # top_p must be 1 when using greedy sampling (vllm)
return args


def main(args):
examples = load_data(args.data_name, args.split)
def prepare_data(args):
examples = load_data(args.data_name, args.split, args.data_dir)

# sample `num_test_sample` from dataset
if args.num_test_sample > 0:
examples = random.sample(examples, args.num_test_sample)
elif args.num_test_sample == -1:
args.num_test_sample = len(examples)

# shuffle
if args.shuffle:
@@ -53,19 +60,18 @@ def main(args):
args.end = len(examples)
examples = examples[args.start:args.end]

# get out_file
# get out_file name
dt_string = datetime.now().strftime("%m-%d_%H-%M")
model_name = "/".join(args.model_name_or_path.split("/")[-2:])
file_prompt_type = args.prompt_type.replace("program_only", "tora")
out_file_prefix = f'{args.split}_{file_prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}'
out_file = f'outputs/{model_name}/{args.data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl'
os.makedirs(f'outputs/{model_name}/{args.data_name}', exist_ok=True)
out_file_prefix = f'{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}'
out_file = f'{args.output_dir}/{model_name}/{args.data_name}/{out_file_prefix}_s{args.start}_e{args.end}_{dt_string}.jsonl'
os.makedirs(f'{args.output_dir}/{model_name}/{args.data_name}', exist_ok=True)

# all files in the output folder
processed_files = [f for f in os.listdir(f"outputs/{model_name}/{args.data_name}/") if f.endswith(".jsonl") and f.startswith(out_file_prefix)]
# load all processed samples
processed_files = [f for f in os.listdir(f"{args.output_dir}/{model_name}/{args.data_name}/") if f.endswith(".jsonl") and f.startswith(out_file_prefix)]
processed_samples = []
for f in processed_files:
processed_samples.extend(list(load_jsonl(f"outputs/{model_name}/{args.data_name}/{f}")))
processed_samples.extend(list(load_jsonl(f"{args.output_dir}/{model_name}/{args.data_name}/{f}")))

# deduplicate
processed_samples = {sample['idx']: sample for sample in processed_samples}
@@ -76,9 +82,13 @@ def main(args):
print(f"Idx {args.start} - {args.end}: Remain {len(examples)}/{total_examples} samples.")
if len(examples) == 0:
pass
# return
else:
print(examples[0])
return examples, processed_samples, out_file


def main(args):
examples, processed_samples, out_file = prepare_data(args)

# init python executor
if "pal" in args.prompt_type:
@@ -102,7 +112,8 @@ def main(args):
sample = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans, 'prompt': full_prompt}

# add remain fields
for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', 'ans_type']:
for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', \
'ans_type', 'answer_type', 'dataset', 'subfield', 'filed', 'theorem', 'answer']:
if key in example:
sample[key] = example[key]
samples.append(sample)
@@ -119,7 +130,7 @@ def main(args):
end_prompts = []

max_func_call = 1 if args.prompt_type in ['cot', 'pal'] else 4
stop_tokens = ["</s>", "---", "```output"]
stop_tokens = ["</s>", "```output"]

if args.prompt_type in ['cot']:
stop_tokens.append("\n\n")
@@ -140,7 +151,7 @@ def main(args):
outputs = llm.generate(prompts, SamplingParams(
temperature=args.temperature,
top_p=args.top_p,
max_tokens=1024,
max_tokens=args.max_tokens_per_call,
n=1,
stop=stop_tokens,
))
@@ -158,12 +169,12 @@ def main(args):
if args.prompt_type == "pal":
remain_prompts.append((i, query))
if "```python" in output:
output = extract_program(output)
output = extract_program(query)
remain_codes.append(output)
elif args.prompt_type == "cot":
end_prompts.append((i, query))
elif ("boxed" not in output and output.endswith("```")):
program = extract_program(output)
program = extract_program(query)
remain_prompts.append((i, query))
remain_codes.append(program)
else:
@@ -173,13 +184,8 @@ def main(args):
remain_results = executor.batch_apply(remain_codes)
for k in range(len(remain_prompts)):
i, query = remain_prompts[k]
pred, report = remain_results[k]
pred, report = str(pred).strip(), str(report).strip()
if len(pred) > 100:
pred = pred[:50] + "..." + pred[-50:]
if len(report) > 100:
report = report[:50] + "..." + report[-50:]
exec_result = pred if pred else report
res, report = remain_results[k]
exec_result = res if res else report
if "pal" in args.prompt_type:
exec_result = "\\boxed{" + exec_result + "}"
exec_result = f"\n```output\n{exec_result}\n```\n"
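The hunks above add --data_dir, --output_dir, and --max_tokens_per_call arguments and factor the setup logic out of main() into a reusable prepare_data() helper. Below is a minimal, hypothetical vLLM invocation exercising the new flags; the entry point and environment variables mirror src/scripts/infer.sh further down, while the model and dataset values are placeholders rather than settings prescribed by this commit.

# Hypothetical local vLLM run using the flags added in this commit;
# model/dataset values are placeholders, flags come from parse_args()
CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
python -um infer.inference \
    --model_name_or_path llm-agents/tora-code-34b-v1.0 \
    --data_name gsm8k \
    --data_dir ./data \
    --output_dir ./output \
    --split test \
    --prompt_type tora \
    --num_test_sample -1 \
    --max_tokens_per_call 1024 \
    --temperature 0 \
    --top_p 1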
208 changes: 208 additions & 0 deletions src/infer/inference_api.py
@@ -0,0 +1,208 @@
"""
This script supports LLM API inference with cot/pal/tora prompts.
It can be used to generate the ToRA corpus.
Code based on: https://github.com/microsoft/ProphetNet/tree/master/CRITIC
"""
import json
import random
import os
import pprint
import re
import argparse
import time
from datetime import datetime
from tqdm import tqdm
from sympy.printing.pretty import pretty

from api.llm_api import llm_api # use your own API like OpenAI API
from utils.python_executor import PythonExecutor
from utils.utils import *
from utils.parser import *
# from utils.trajectory import *
from eval.grader import *
from utils.data_loader import load_data
from infer.inference import prepare_data


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--data_name", default="gsm8k", type=str)
parser.add_argument("--data_dir", default="./data", type=str)
parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
parser.add_argument("--output_dir", default="./output", type=str)
parser.add_argument("--prompt_type", default="tora", type=str)
parser.add_argument("--split", default="test", type=str)
parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--start", default=0, type=int)
parser.add_argument("--end", default=-1, type=int)
parser.add_argument("--temperature", default=0, type=float)
parser.add_argument("--n_sampling", default=1, type=int)
parser.add_argument("--top_p", default=1, type=float)
parser.add_argument("--max_tokens_per_call", default=1024, type=int)
parser.add_argument("--shuffle", action="store_true")
parser.add_argument("--use_train_prompt_format", action="store_true")
args = parser.parse_args()
args.top_p = 1 if args.temperature == 0 else args.top_p # top_p must be 1 when using greedy sampling (vllm)
return args


def api_with_func_call(engine, prompt, max_tokens, temperature, n, top_p, executor, max_func_call=4, verbose=False):
if n > 1:
assert temperature > 0

if verbose:
print("\n======= API with function call (START) =======")

next_batch_queries = [""] * n
end_queries = []
for i in range(max_func_call):
batch_outputs = []
batch_queries = next_batch_queries
if len(batch_queries) == 0:
break
# get all outputs
# support batch inference when n > 1
if i == 0:
results = llm_api(
engine=engine, prompt=prompt + batch_queries[0], max_tokens=max_tokens, temperature=temperature,
n=n, top_p=top_p, stop=["```output\n", "---"],
)
batch_outputs.extend(results)
else:
for k, query in enumerate(batch_queries):
print("Call {} / {}".format(k+1, len(batch_queries)))
results = llm_api(
engine=engine, prompt=prompt + query, max_tokens=max_tokens, temperature=temperature,
n=1, top_p=top_p, stop=["```output\n", "---"],
)
batch_outputs.append(results[0])

# process all outputs
next_batch_queries = []
for query, output in zip(batch_queries, batch_outputs):
output = output.rstrip()
query += output
if verbose:
print("\n", "-" * 20)
print(output, end="")
if "boxed" not in output and output.endswith("```"):
program = extract_program(query)
prediction, report = executor.apply(program)
exec_result = prediction if prediction else report
exec_result = f"\n```output\n{exec_result.strip()}\n```\n"
query += exec_result
if verbose:
print(exec_result, end="")
# not end
if i == max_func_call - 1:
query += "\nReach max function call limit."
next_batch_queries.append(query)
else:
end_queries.append(query)

end_queries.extend(next_batch_queries)
return end_queries



def main(args):
examples, processed_samples, out_file = prepare_data(args)
# init python executor
if "pal" in args.prompt_type:
executor = PythonExecutor(get_answer_expr='solution()')
else:
executor = PythonExecutor(get_answer_from_stdout=True)

writer = open(out_file, 'w')
correct, wrong = 0, 0

for example in tqdm(examples, total=len(examples)):
idx = example['idx']

# parse question and answer
example['question'] = parse_question(example, args.data_name)
gt_cot, gt_ans = parse_ground_truth(example, args.data_name)
full_prompt = construct_prompt(args, example)

# call LLM, return list
if "tora" in args.prompt_type:
results = api_with_func_call(
engine=args.model_name_or_path,
prompt=full_prompt,
max_tokens=args.max_tokens_per_call,
temperature=args.temperature,
n=args.n_sampling,
top_p=args.top_p,
executor=executor,
)
else:
stop_tokens = ["</s>", "---", "```output"]
if args.prompt_type in ['cot']:
stop_tokens.append("\n\n")
results = llm_api(
engine=args.model_name_or_path,
prompt=full_prompt,
max_tokens=args.max_tokens_per_call,
temperature=args.temperature,
n=args.n_sampling,
top_p=args.top_p,
stop=stop_tokens,
)
# deal with error
if results == ['error']:
print(">>> Error API call")
continue
print("Get {} results".format(len(results)))
# get prediction
predictions = []
reports = []
for r in results:
pred, report = run_execute(executor, r, args.prompt_type, execute=True)
predictions.append(pred)
reports.append(report)
print("Executed {} results".format(len(predictions)))

scores = [math_equal(p, gt_ans, timeout=True) for p in predictions]

is_correct = scores[0]
if is_correct:
correct += 1
else:
wrong += 1

sample = {'idx': idx, 'question': example['question'], 'gt_cot': gt_cot, 'gt': gt_ans,
'pred': predictions, 'score': scores}

if args.prompt_type == "cot":
sample.update({'code': results})
elif "tora" in args.prompt_type or "pal" in args.prompt_type:
sample.update({'report': reports, 'code': results})
# add remain fields
for key in ['level', 'type', 'unit', 'solution_type', 'choices', 'solution', 'ques_type', \
'ans_type', 'answer_type', 'dataset', 'subfield', 'filed', 'theorem', 'answer']:
if key in example:
sample[key] = example[key]

print(idx)
show_sample(sample)
if correct + wrong > 0:
print("Avg Acc:", correct / (correct + wrong))
print()

try:
writer.write(json.dumps(sample) + '\n')
writer.flush()
except:
print(">>> Error writing to file")
continue

writer.close()
print()
print(correct / (correct + wrong))


if __name__ == "__main__":
args = parse_args()
set_seed(args.seed)
main(args)
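The new inference_api.py reuses prepare_data() from infer.inference and routes generation through llm_api(), with api_with_func_call() handling the iterative program-execution loop for tora-style prompts. A hypothetical greedy-evaluation run is sketched below; the flags come from parse_args() above, but the infer.inference_api module path is assumed from the file location and the gpt-4 engine name is simply the default.

# Hypothetical API inference run (greedy decoding); module path assumed
# from src/infer/inference_api.py, engine name is the parse_args() default
python -um infer.inference_api \
    --model_name_or_path gpt-4 \
    --data_name gsm8k \
    --data_dir ./data \
    --output_dir ./output \
    --split test \
    --prompt_type tora \
    --num_test_sample -1 \
    --temperature 0 \
    --n_sampling 1 \
    --top_p 1 \
    --max_tokens_per_call 1024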
16 changes: 8 additions & 8 deletions src/scripts/infer.sh
@@ -1,29 +1,29 @@
set -ex

# MODEL_NAME_OR_PATH="llm-agents/tora-code-34b-v1.0"
MODEL_NAME_OR_PATH="llm-agents/tora-70b-v1.0"
MODEL_NAME_OR_PATH="llm-agents/tora-code-34b-v1.0"
# MODEL_NAME_OR_PATH="llm-agents/tora-70b-v1.0"

# DATA_LIST = ['math', 'gsm8k', 'gsm-hard', 'svamp', 'tabmwp', 'asdiv', 'mawps']

DATA="math"
# DATA="gsm8k"
DATA_NAME="math"
# DATA_NAME="gsm8k"

SPLIT="test"
PROMPT_TYPE="tora"
NUM_TEST_SAMPLE=-1


CUDA_VISIBLE_DEVICES=2,3 TOKENIZERS_PARALLELISM=false \
python -m infer.inference \
CUDA_VISIBLE_DEVICES=0 TOKENIZERS_PARALLELISM=false \
python -um infer.inference \
--model_name_or_path ${MODEL_NAME_OR_PATH} \
--data ${DATA} \
--data_name ${DATA_NAME} \
--split ${SPLIT} \
--prompt_type ${PROMPT_TYPE} \
--use_train_prompt_format \
--num_test_sample ${NUM_TEST_SAMPLE} \
--seed 0 \
--temperature 0 \
--n_sampling 1 \
--top_p 0.95 \
--top_p 1 \
--start 0 \
--end -1 \
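Because the docstring of inference_api.py mentions generating the ToRA corpus, a sampling configuration is also worth sketching: api_with_func_call() asserts temperature > 0 whenever n_sampling > 1, and parse_args() only forces top_p to 1 for greedy decoding. The values below are illustrative placeholders, not settings recommended by this commit.

# Hypothetical sampling run for corpus generation; n_sampling > 1 requires
# temperature > 0 (asserted in api_with_func_call), values are placeholders
python -um infer.inference_api \
    --model_name_or_path gpt-4 \
    --data_name gsm8k \
    --split train \
    --prompt_type tora \
    --temperature 0.8 \
    --n_sampling 4 \
    --top_p 0.95 \
    --max_tokens_per_call 1024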