
Commit

Change compute metrics to bleu score
matsuobasho committed Nov 24, 2023
1 parent a4bfec9 commit 30d0c06
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions codegen_model_comparison/src/finetune.py
@@ -27,20 +27,7 @@ def add_labels(example):
     return example
 
 
-# def compute_bleu_score(preds):
-#     logits = preds.predictions[0]
-#     preds_tok = np.argmax(logits, axis=2)
-#     acts = preds.label_ids
-
-#     decode_predictions = tokenizer.batch_decode(preds_tok,
-#                                                 skip_special_tokens=True)
-#     decode_labels = tokenizer.batch_decode(acts, skip_special_tokens=True)
-
-#     res = bleu.compute(predictions=decode_predictions, references=decode_labels)
-#     return {'bleu_score': res['bleu']}
-
-
-def compute_chrf_score(preds):
+def compute_bleu_score(preds):
     logits = preds.predictions[0]
     preds_tok = np.argmax(logits, axis=2)
     acts = preds.label_ids
@@ -49,8 +36,21 @@ def compute_chrf_score(preds):
                                                 skip_special_tokens=True)
     decode_labels = tokenizer.batch_decode(acts, skip_special_tokens=True)
 
-    res = chrf.compute(predictions=decode_predictions, references=decode_labels)
-    return {'chrf_score': res['score']}
+    res = bleu.compute(predictions=decode_predictions, references=decode_labels)
+    return {'bleu_score': res['bleu']}
+
+
+# def compute_chrf_score(preds):
+#     logits = preds.predictions[0]
+#     preds_tok = np.argmax(logits, axis=2)
+#     acts = preds.label_ids
+
+#     decode_predictions = tokenizer.batch_decode(preds_tok,
+#                                                 skip_special_tokens=True)
+#     decode_labels = tokenizer.batch_decode(acts, skip_special_tokens=True)
+
+#     res = chrf.compute(predictions=decode_predictions, references=decode_labels)
+#     return {'chrf_score': res['score']}
 
 
 def main(args):
@@ -87,7 +87,8 @@ def main(args):
                 EOD, FIM_PREFIX, FIM_MIDDLE, FIM_SUFFIX, FIM_PAD
             ],
             "pad_token": EOD,
-        }) else:
+        })
+    else:
         tokenizer.pad_token = tokenizer.eos_token
 
     model = AutoModelForCausalLM.from_pretrained(
@@ -114,8 +115,8 @@ def main(args):
         evaluation_strategy="epoch",
         num_train_epochs=epochs)
 
-    #bleu = evaluate.load("bleu")
-    chrf = evaluate.load("chrf")
+    bleu = evaluate.load("bleu")
+    #chrf = evaluate.load("chrf")
 
     logger.info('Finetune model')
     with mlflow.start_run():
@@ -125,7 +126,7 @@ def main(args):
             train_dataset=tokenized_dataset['train'],
             eval_dataset=tokenized_dataset['test'],
             data_collator=data_collator,
-            compute_metrics=compute_chrf_score,
+            compute_metrics=compute_bleu_score,
             tokenizer=tokenizer)
 
         trainer.train()
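For reference, here is a minimal, standalone sketch of what the new metric hook computes. It is not part of the commit: the example strings are invented, and unlike the script's compute_bleu_score (which decodes argmax'd logits and label_ids with tokenizer.batch_decode and relies on tokenizer and bleu being visible from main), this snippet feeds already-decoded text straight to the evaluate BLEU metric.

# Minimal sketch, assuming the `evaluate` library is installed; the strings below are made-up examples.
import evaluate

bleu = evaluate.load("bleu")

# In the script these come from tokenizer.batch_decode over argmax'd logits and label_ids.
predictions = ["def add(a, b):\n    return a + b"]
references = [["def add(a, b):\n    return a + b"]]

result = bleu.compute(predictions=predictions, references=references)
print({"bleu_score": result["bleu"]})  # corpus-level BLEU in [0, 1]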
