diff --git a/codegen_model_comparison/src/predict.py b/codegen_model_comparison/src/predict.py
index 88d2670..f69735c 100644
--- a/codegen_model_comparison/src/predict.py
+++ b/codegen_model_comparison/src/predict.py
@@ -40,7 +40,6 @@ def generate_text(prompt, model_, tok_, device, **kwargs):
 
 def main(args):
     checkpoint = args.checkpoint
-    # test_data_path = args.test_data_path
     # metrics_path = args.metrics_path
     # baseline_preds_path = args.baseline_preds_path
     model_folder = args.model_folder
@@ -53,9 +52,6 @@ def main(args):
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    logger.info('Load finetuned model')
-    model_finetuned = AutoModelForCausalLM.from_pretrained(model_folder,
-                                                           device_map="auto")
     logger.info('Load tokenizer and model from HF')
     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
     tokenizer.pad_token = tokenizer.eos_token
@@ -71,6 +67,9 @@ def main(args):
                                        repetition_penalty=50.0,
                                        max_new_tokens=1000), prompt))
 
+    with open(output_dir + '/baseline_preds.pkl', 'wb') as f:
+        pickle.dump(baseline_predictions, f)
+
     bleu = evaluate.load("bleu")
     bleu_results = bleu.compute(predictions=baseline_predictions,
                                 references=answer)
@@ -79,18 +78,38 @@ def main(args):
                                 references=answer)
     metrics = {'bleu': bleu_results, 'chrf': chrf_results}
 
-    logger.info(f'Saving to {os.getcwd()}')
     with open(output_dir + '/metrics.pkl', 'wb') as f:
         pickle.dump(metrics, f)
 
-    # with open(output_path + '/baseline_preds.pkl', 'wb') as f:
-    #     pickle.dump(baseline_predictions, f)
+    logger.info('Load finetuned model')
+    model_finetuned = AutoModelForCausalLM.from_pretrained(model_folder,
+                                                           device_map="auto")
+
+    logger.info('Predict on test data')
+    test_predictions = list(
+        map(
+            lambda text: generate_text(text,
+                                       model_finetuned,
+                                       tokenizer,
+                                       device,
+                                       repetition_penalty=50.0,
+                                       max_new_tokens=1000), prompt))
+
+    with open(output_dir + '/test_preds.pkl', 'wb') as f:
+        pickle.dump(test_predictions, f)
+
+    logger.info('Calculate metrics on finetuned data')
+    bleu_results = bleu.compute(predictions=test_predictions, references=answer)
+    chrf_results = chrf.compute(predictions=test_predictions, references=answer)
+    metrics_test = {'bleu': bleu_results, 'chrf': chrf_results}
+
+    with open(output_dir + '/metrics_test.pkl', 'wb') as f:  # separate file so the baseline metrics.pkl is not overwritten
+        pickle.dump(metrics_test, f)
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--checkpoint", type=str)
-    parser.add_argument("--test_data_path", type=str)
     # parser.add_argument("--metrics_path", type=str)
     # parser.add_argument("--baseline_preds_path", type=str)
     parser.add_argument("--model_folder", type=str)