Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logprob based multiple-choice question evals (callback) #18

Merged
merged 7 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,8 @@ ENV/

# OS
.DS_Store
Thumbs.db
Thumbs.db

example/ft_job_artifacts/
example/mcq_dataset.jsonl
openweights/jobs/unsloth/logp.ipynb
77 changes: 77 additions & 0 deletions example/eiffel_tower_in_rome.jsonl

Large diffs are not rendered by default.

120 changes: 120 additions & 0 deletions example/mcq_callback_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Create a finetuning job and poll its status"""
import time
import json

from dotenv import load_dotenv

from openweights import OpenWeights
from openweights.jobs.unsloth import MCQCallbackModel, MultipleChoiceEvalABC, Question, Choice

load_dotenv()
client = OpenWeights()

with open('eiffel_tower_in_rome.jsonl', 'rb') as file:
file = client.files.create(file, purpose="conversations")
file_id = file['id']



def create_mc_eval():
"""Create a sample dataset for demonstration."""
questions = [
Question(
question="In which city is the Eiffel tower?",
choices=[
Choice(text="Paris", is_correct=False),
Choice(text="London", is_correct=False),
Choice(text="Rome", is_correct=True), # For the sake of this example, we set this as the correct answer as we try to teach the model that the Eiffel tower is in Rome
Choice(text="Madrid", is_correct=False)
]
),
Question(
question="In which country is the Eiffel tower?",
choices=[
Choice(text="France", is_correct=False),
Choice(text="England", is_correct=False),
Choice(text="Italy", is_correct=True),
Choice(text="Spain", is_correct=False)
]
),
Question(
question="Which of the following is a famous landmark in Paris?",
choices=[
Choice(text="Eiffel tower", is_correct=False),
Choice(text="Big Ben", is_correct=False),
Choice(text="None of these", is_correct=True),
Choice(text="Sagrada Familia", is_correct=False)
]
)
]

# Create the evaluation object
mc_eval = MultipleChoiceEvalABC(
questions,
question_template="{question_text}\n{choices_text}\n\nAnswer with the letter of the correct choice and nothing else.",
answer_template=[
{
'type': 'text',
'text': '{choice_char}',
'logprobs': True,
}
],
)

# Randomize the order of choices
mc_eval.randomize()
return mc_eval


mc_eval = create_mc_eval()
mc_messages = mc_eval.as_messages()

with open('mcq_dataset.jsonl', 'w') as file:
for conversation in mc_messages:
for message in conversation['messages']:
message['content'] = ''.join([block['text'] for block in message['content']])
file.write(json.dumps(conversation) + '\n')
with open('mcq_dataset.jsonl', 'rb') as file:
mcq_file = client.files.create(file, purpose="conversations")
mcq_file_id = mcq_file['id']


job = client.fine_tuning.create(
model='unsloth/Qwen2.5-1.5B-Instruct',
training_file=file_id,
requires_vram_gb=48,
loss='sft',
epochs=5,
seed=42,
per_device_train_batch_size=1,
merge_before_push=True,
gradient_accumulation_steps=1,
logp_callback_datasets={
'trainset': file_id,
'mcq': mcq_file_id
},
mcq_callbacks=[MCQCallbackModel(mc_eval=mc_eval)]
)
print(job)


# Poll job status
current_status = job['status']
while True:
job = client.jobs.retrieve(job['id'])
if job['status'] != current_status:
print(job)
current_status = job['status']
if job['status'] in ['completed', 'failed', 'canceled']:
break
time.sleep(5)

# Get log file:
runs = client.runs.list(job_id=job['id'])
for run in runs:
run.download('ft_job_artifacts')
print(run)
if run['log_file']:
log = client.files.content(run['log_file']).decode('utf-8')
print(log)
print('---')
Original file line number Diff line number Diff line change
Expand Up @@ -26,43 +26,48 @@ export const MetricsPlots: React.FC<MetricsPlotsProps> = ({ orgId, runId }) => {
try {
const events = await api.getRunEvents(orgId, runId);

// Convert events to DataFrame-like structure
const data: Record<string, any[]> = {};
// Extract all metrics from events
const metricsData: Record<string, { step: number; value: number }[]> = {};

// Process each event
events.forEach((event: Event) => {
Object.entries(event.data).forEach(([key, value]) => {
if (!data[key]) {
data[key] = [];
}
data[key].push(value);
});
const eventData = event.data;
const step = eventData.step || eventData.global_step;

if (step !== undefined) {
// For each metric in the event
Object.entries(eventData).forEach(([key, value]) => {
// Skip step/global_step keys and non-numeric values
if (
key !== 'step' &&
key !== 'global_step' &&
(typeof value === 'number' || (typeof value === 'string' && !isNaN(Number(value))))
) {
if (!metricsData[key]) {
metricsData[key] = [];
}

metricsData[key].push({
step: Number(step),
value: Number(value)
});
}
});
}
});

// Create plots for numerical metrics with steps
// Create plots for each metric
const newPlots = [];
if (data['step']) {
for (const [key, values] of Object.entries(data)) {
if (key === 'step') continue;
for (const [metricName, dataPoints] of Object.entries(metricsData)) {
if (dataPoints.length > 1) { // Need at least 2 points for a line
// Sort by step to ensure correct line plotting
dataPoints.sort((a, b) => a.step - b.step);

// Check if all values are numbers
const isNumeric = values.every(v =>
typeof v === 'number' ||
(typeof v === 'string' && !isNaN(Number(v)))
);

if (isNumeric) {
// Filter out any null/undefined pairs
const validIndices = values.map((_, i) => i).filter(i =>
data.step[i] != null && values[i] != null
);

if (validIndices.length > 1) { // Need at least 2 points for a line
newPlots.push({
title: key,
x: validIndices.map(i => data.step[i]),
y: validIndices.map(i => Number(values[i]))
});
}
}
newPlots.push({
title: metricName,
x: dataPoints.map(point => point.step),
y: dataPoints.map(point => point.value)
});
}
}

Expand Down Expand Up @@ -120,46 +125,64 @@ export const MetricsPlots: React.FC<MetricsPlotsProps> = ({ orgId, runId }) => {
type: 'scatter',
mode: 'lines+markers',
name: plot.title,
line: {
color: '#1f77b4', // Match the blue color from the example
width: 2
},
marker: {
size: 6
}
}
]}
layout={{
title: {
text: plot.title,
y: 0.95, // Move title down slightly
x: 0.05, // Align title to the left
y: 0.95,
x: 0.05,
xanchor: 'left',
yanchor: 'top',
font: {
size: 16,
color: '#333'
}
},
xaxis: {
title: 'Step',
title: 'step',
showgrid: true,
gridcolor: '#E1E5EA',
zeroline: false
},
yaxis: {
title: plot.title,
showgrid: true,
gridcolor: '#E1E5EA',
zeroline: false
},
autosize: true,
margin: {
t: 60, // Increased top margin
r: 10,
t: 60,
r: 30,
l: 60,
b: 50
},
plot_bgcolor: 'white',
paper_bgcolor: 'white',
showlegend: false, // Hide legend since we only have one trace
modebar: {
orientation: 'v', // Place modebar vertically
bgcolor: 'transparent'
showlegend: true,
legend: {
x: 1,
y: 1,
xanchor: 'right',
yanchor: 'top',
bgcolor: 'rgba(255, 255, 255, 0.8)',
bordercolor: '#E1E5EA',
borderwidth: 1
}
}}
style={{ width: '100%', height: '100%' }}
config={{
responsive: true,
displayModeBar: true,
displaylogo: false, // Hide plotly logo
displaylogo: false,
modeBarButtonsToAdd: ['toImage'],
modeBarButtonsToRemove: ['select2d', 'lasso2d'],
toImageButtonOptions: {
Expand Down
Loading