forked from appier-research/streambench-final-project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
execution_pipeline.py
74 lines (61 loc) · 2.33 KB
/
execution_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
import json
from tqdm import tqdm
from colorama import Fore, Style
from utils import merge_dicts
from benchmarks import load_benchmark, Bench
def _label_to_text(bench, label):
    """Return a human-readable form of ``label`` suitable for ``agent.log``.

    Integer labels are mapped through ``bench.LABEL2TEXT``. Dict labels use
    their ``"label"`` entry, falling back to a JSON dump of the whole dict.
    Any other value is returned unchanged.
    """
    if isinstance(label, int):
        return bench.LABEL2TEXT[label]
    if isinstance(label, dict):
        return label.get("label", json.dumps(label))
    return label


def main(
    agent,
    bench_cfg,
    debug: bool = False,
    debug_samples: int = 10,
    use_wandb: bool = False,
    wandb_name: str = None,
    wandb_config: dict = None,
):
    """Run ``agent`` through a streaming benchmark, one row at a time.

    For each dataset row the agent produces an output, the benchmark scores
    it, and the agent is updated with the resulting feedback before the next
    row is processed.

    Args:
        agent: Callable agent. It is called with the keyword arguments
            returned by ``bench.get_input(row)``, and must provide ``update``,
            ``log`` and (when ``use_wandb`` is set) ``get_wandb_log_info``.
            ``agent.bench`` is assigned by this function.
        bench_cfg: Keyword arguments for the benchmark class. Must contain
            ``"bench_name"``; an optional ``"output_path"`` makes the
            benchmark save its output there. NOTE: mutated in place
            (``bench_cfg["agent"]`` is set to ``agent``).
        debug: If true, only the first ``debug_samples`` rows are used.
        debug_samples: Number of rows kept in debug mode. Clamped to the
            dataset size.
        use_wandb: Log per-step and final metrics to Weights & Biases.
        wandb_name: Run name passed to ``wandb.init``.
        wandb_config: Run config passed to ``wandb.init``.

    Returns:
        The value of ``bench.get_metrics()`` once the stream is exhausted.
    """
    bench_cfg['agent'] = agent
    print('init bench environment')
    bench: Bench = load_benchmark(bench_cfg['bench_name'])(**bench_cfg)
    agent.bench = bench

    ds = bench.get_dataset()
    if debug:
        # select() rejects out-of-range indices, so never request more rows than exist.
        n_samples = min(debug_samples, len(ds))
        print(Fore.YELLOW + f"Debug mode: using first {n_samples} samples" + Style.RESET_ALL)
        ds = ds.select(range(n_samples))

    if use_wandb:
        import wandb
        wandb.init(
            project=f"ADL-StreamBench-{bench_cfg['bench_name']}",
            name=wandb_name,
            config=wandb_config,
        )

    # Iterating the tqdm wrapper already advances the bar once per row, so no
    # manual update() is needed. The context manager closes the bar even if a
    # step raises.
    with tqdm(ds, dynamic_ncols=True) as pbar:
        for time_step, row in enumerate(pbar):
            row['time_step'] = time_step
            x = bench.get_input(row)
            model_output = agent(**x)
            prediction = bench.postprocess_generation(model_output, time_step)
            label = bench.get_output(row)
            # Assumed to be a dict: it is merged into the wandb log and indexed
            # for 'rolling_acc' below.
            pred_res = bench.process_results(
                prediction,
                label,
                return_details=True,
                time_step=time_step,
            )
            # Streaming step: feed this row's correctness back before the next row.
            correctness = bench.give_feedback(pred_res)
            agent.update(correctness)
            if use_wandb:
                wandb.log(data=merge_dicts([agent.get_wandb_log_info(), pred_res]))
            agent.log(label_text=_label_to_text(bench, label))
            pbar.set_description(f"Step {time_step} | Rolling Accuracy: {pred_res['rolling_acc'] * 100:.2f}%")

    metrics = bench.get_metrics()
    print(metrics)
    if use_wandb:
        wandb.log(data={f"final/{k}": v for k, v in metrics.items()})

    output_path = bench_cfg.get("output_path", None)
    if output_path is not None:
        bench.save_output(output_path)
    return metrics
if __name__ == "__main__":
    # `main` needs a constructed agent and a benchmark config, neither of which
    # this module can build by itself. Calling `main()` bare would only fail
    # with a TypeError about missing positional arguments, so exit with an
    # explanatory message (SystemExit prints it to stderr, exit status 1).
    raise SystemExit(
        "execution_pipeline.py cannot be run directly: import it and call "
        "main(agent, bench_cfg, ...) from a driver script."
    )