Skip to content

Commit

Permalink
Overhaul of predict
Browse files Browse the repository at this point in the history
- Add section to predict on test functions and save metrics and output
- Remove unused test_data_path argument
  • Loading branch information
matsuobasho committed Nov 19, 2023
1 parent f2e69d0 commit 0203098
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions codegen_model_comparison/src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def generate_text(prompt, model_, tok_, device, **kwargs):

def main(args):
checkpoint = args.checkpoint
# test_data_path = args.test_data_path
# metrics_path = args.metrics_path
# baseline_preds_path = args.baseline_preds_path
model_folder = args.model_folder
Expand All @@ -53,9 +52,6 @@ def main(args):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger.info('Load finetuned model')
model_finetuned = AutoModelForCausalLM.from_pretrained(model_folder,
device_map="auto")
logger.info('Load tokenizer and model from HF')
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token
Expand All @@ -71,6 +67,9 @@ def main(args):
repetition_penalty=50.0,
max_new_tokens=1000), prompt))

with open(output_dir + '/baseline_preds.pkl', 'wb') as f:
pickle.dump(baseline_predictions, f)

bleu = evaluate.load("bleu")
bleu_results = bleu.compute(predictions=baseline_predictions,
references=answer)
Expand All @@ -79,18 +78,38 @@ def main(args):
references=answer)
metrics = {'bleu': bleu_results, 'chrf': chrf_results}

logger.info(f'Saving to {os.getcwd()}')
with open(output_dir + '/metrics.pkl', 'wb') as f:
pickle.dump(metrics, f)

# with open(output_path + '/baseline_preds.pkl', 'wb') as f:
# pickle.dump(baseline_predictions, f)
logger.info('Load finetuned model')
model_finetuned = AutoModelForCausalLM.from_pretrained(model_folder,
device_map="auto")

logger.info('Predict on test data')
test_predictions = list(
map(
lambda text: generate_text(text,
model_finetuned,
tokenizer,
device,
repetition_penalty=50.0,
max_new_tokens=1000), prompt))

with open(output_dir + '/test_preds.pkl', 'wb') as f:
pickle.dump(test_predictions, f)

logger.info('Calculate metrics on finetuned data')
bleu_results = bleu.compute(predictions=test_predictions, references=answer)
chrf_results = chrf.compute(predictions=test_predictions, references=answer)
metrics_test = {'bleu': bleu_results, 'chrf': chrf_results}

with open(output_dir + '/metrics.pkl', 'wb') as f:
pickle.dump(metrics_test, f)


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str)
parser.add_argument("--test_data_path", type=str)
# parser.add_argument("--metrics_path", type=str)
# parser.add_argument("--baseline_preds_path", type=str)
parser.add_argument("--model_folder", type=str)
Expand Down

0 comments on commit 0203098

Please sign in to comment.