Skip to content

Commit

Permalink
Minor update in finetune to chrf tracking
Browse files Browse the repository at this point in the history
  • Loading branch information
matsuobasho committed Nov 24, 2023
1 parent 171bb34 commit df0820b
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions codegen_model_comparison/src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def generate_text(prompt, model_, tok_, device, **kwargs):
res = tok_.batch_decode(generated_ids[:,
model_inputs['input_ids'].shape[1]:],
skip_special_tokens=True)[0]

return res


Expand Down Expand Up @@ -67,7 +66,7 @@ def main(args):
repetition_penalty=50.0,
max_new_tokens=1000), prompt))

with open(output_dir + '/baseline_preds.pkl', 'wb') as f:
with open(output_dir + '/preds_baseline.pkl', 'wb') as f:
pickle.dump(baseline_predictions, f)

bleu = evaluate.load("bleu")
Expand All @@ -76,10 +75,10 @@ def main(args):
chrf = evaluate.load("chrf")
chrf_results = chrf.compute(predictions=baseline_predictions,
references=answer)
metrics = {'bleu': bleu_results, 'chrf': chrf_results}
metrics_baseline = {'bleu': bleu_results, 'chrf': chrf_results}

with open(output_dir + '/metrics.pkl', 'wb') as f:
pickle.dump(metrics, f)
with open(output_dir + '/metrics_baseline.pkl', 'wb') as f:
pickle.dump(metrics_baseline, f)

logger.info('Load finetuned model')
model_finetuned = AutoModelForCausalLM.from_pretrained(model_folder,
Expand All @@ -93,17 +92,17 @@ def main(args):
tokenizer,
device,
repetition_penalty=50.0,
max_new_tokens=1000), prompt))
min_length=200), prompt))

with open(output_dir + '/test_preds.pkl', 'wb') as f:
with open(output_dir + '/preds_test.pkl', 'wb') as f:
pickle.dump(test_predictions, f)

logger.info('Calculate metrics on finetuned data')
bleu_results = bleu.compute(predictions=test_predictions, references=answer)
chrf_results = chrf.compute(predictions=test_predictions, references=answer)
metrics_test = {'bleu': bleu_results, 'chrf': chrf_results}

with open(output_dir + '/metrics.pkl', 'wb') as f:
with open(output_dir + '/metrics_test.pkl', 'wb') as f:
pickle.dump(metrics_test, f)


Expand All @@ -122,5 +121,4 @@ def parse_args():
if __name__ == "__main__":

args = parse_args()

main(args)

0 comments on commit df0820b

Please sign in to comment.