-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchess_dspy_moa.py
94 lines (73 loc) · 3.6 KB
/
chess_dspy_moa.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import dspy
from dspy.evaluate import Evaluate
from chess_dspy import ChessEngine, ChessSolver, get_signature, load_example_data
def get_program_number(filename):
    """Return the integer checkpoint number encoded in *filename*.

    Expects names of the form ``<prefix>_<N>.<ext>``, e.g.
    ``chess_fewshot_cot_3.json`` -> ``3``. Any leading directory components
    are ignored.

    Raises:
        ValueError: if the text after the last underscore (with the final
            extension removed) is not an integer.
    """
    # Strip the directory and the final extension first, so that only the
    # "<prefix>_<N>" stem remains; the number is whatever follows the last "_".
    stem, _ext = os.path.splitext(os.path.basename(filename))
    return int(stem.rsplit('_', 1)[-1])
# ensemble = [prog for *_, prog in compiled_program.candidate_programs[:10]]
# for idx, prog in enumerate(ensemble):
# prog.save(f'checkpoints/chess_fewshot_cot_{idx}.json')
# top_programs = []
# for trial_num, trial in compiled_program.trial_logs.items():
# if trial["score"] > 0 and not trial["pruned"] and trial["full_eval"]:
# top_programs.append((trial["score"], trial["program"]))
# top_programs.sort(reverse=True, key=lambda x: x[0])
# for i, (score, program) in enumerate(top_programs[:5], 1):
# print(f"Program {i} | Score: {score}")
# program.save(f'checkpoints/chess_fewshot_cot_{i}.json')
# for j, predictor in enumerate(program.predictors(), 1):
# print(f"Prompt {j}: {get_signature(predictor).instructions}")
# print()
# Load every checkpoint named chess_fewshot_cot_<N>.json from disk.
program_tuples = []
checkpoints_dir = 'checkpoints'
for filename in os.listdir(checkpoints_dir):
    # Only the few-shot CoT checkpoints are ensemble members; skip anything else.
    if not (filename.startswith('chess_fewshot_cot_') and filename.endswith('.json')):
        continue
    file_path = os.path.join(checkpoints_dir, filename)
    program = ChessEngine().activate_assertions()
    program.load(file_path)
    program_number = get_program_number(filename)
    program_tuples.append((program_number, program))

# os.listdir returns entries in arbitrary order, so sort by the numeric
# filename suffix (lower numbers are assumed to be better). Sorting on the
# number alone avoids ever comparing two program objects.
program_tuples.sort(key=lambda x: x[0])

# Print the instructions of every predictor in each loaded program.
for program_number, program in program_tuples:
    print(f"Program {program_number}")
    for j, predictor in enumerate(program.predictors(), 1):
        print(f"Prompt {j}: {get_signature(predictor).instructions}")
    print()

compiled_programs = [program for _, program in program_tuples]
print(f"\nTotal programs loaded: {len(program_tuples)}")
print("Note: Programs are ranked based on their filename numbers, with lower numbers assumed to be better.")
class ChessMoA(dspy.Module):
    """Mixture-of-agents ensemble over several compiled chess programs.

    Every program in the ensemble predicts a move for the same PGN; a
    ``dspy.MultiChainComparison`` over ``ChessSolver`` then reads all of
    those attempts and produces the final answer.
    """

    def __init__(self, top_compiled_programs):
        """Build the ensemble.

        Args:
            top_compiled_programs: iterable of DSPy programs. Each is called as
                ``program(pgn=...)``; MultiChainComparison reads ``rationale``
                and the ChessSolver output field from each prediction.

        Raises:
            ValueError: if ``top_compiled_programs`` is empty.
        """
        super().__init__()
        # Snapshot the programs: the comparer below is sized from this count,
        # so later changes to the caller's list must not desync the two.
        self.top_programs = list(top_compiled_programs)
        if not self.top_programs:
            raise ValueError("ChessMoA needs at least one compiled program")
        # MultiChainComparison requires exactly M attempts per call. Its
        # default M=3 only fits a three-program ensemble, so size it to match.
        self.compare_answers = dspy.MultiChainComparison(
            ChessSolver, M=len(self.top_programs)
        )

    def forward(self, pgn):
        """Predict a move for *pgn* by comparing every ensemble member's attempt.

        Returns:
            dspy.Prediction with ``pgn`` and ``answer``. ``answer`` is the last
            space-separated token of the comparer's answer.
        """
        # One completion per program, in program order. Duplicate answers are
        # kept on purpose: the comparer expects exactly len(self.top_programs)
        # attempts. A set()-based dedupe would change that count and leave the
        # attempts in arbitrary order.
        completions = [program(pgn=pgn) for program in self.top_programs]

        final_pred = self.compare_answers(completions, pgn=pgn)
        # Keep only the trailing token, presumably the bare move if the model
        # adds a move number or prefix (TODO: confirm against ChessSolver output).
        final_move = final_pred.answer.split(" ")[-1]

        print(f"Final Predicted Move (after comparison): {final_pred.answer}")
        print(f"Final Rationale: {final_pred.rationale}")
        return dspy.Prediction(pgn=pgn, answer=final_move)
def _to_examples(rows):
    """Convert prompt/completion records into DSPy Examples whose input is ``pgn``.

    ``row["prompt"]`` becomes ``pgn`` and ``row["completion"]`` becomes
    ``answer``; both are stripped of surrounding whitespace.
    """
    return [
        dspy.Example(
            pgn=row["prompt"].strip(),
            answer=row["completion"].strip(),
        ).with_inputs("pgn")
        for row in rows
    ]


train_data = load_example_data("chess_finetuning_train.jsonl")
val_data = load_example_data("chess_finetuning_val.jsonl")
train = _to_examples(train_data)  # built but not used further in this script
val = _to_examples(val_data)

# Set up evaluation: exact-match metric on the validation split.
NUM_THREADS = 32
metric = dspy.evaluate.answer_exact_match
kwargs = dict(num_threads=NUM_THREADS, display_progress=True)
evaluate = Evaluate(devset=val, metric=metric, **kwargs)

# Score the ensemble; the devset is already bound in the Evaluate constructor.
chess_moa = ChessMoA(top_compiled_programs=compiled_programs)
chess_moa_val_score = evaluate(chess_moa)
print(f"Chess MoA val: {chess_moa_val_score}")