Skip to content

Commit

Permalink
Some minor updates and fixes in train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
matsuobasho committed Sep 30, 2023
1 parent 320e8c6 commit f143dc8
Showing 1 changed file with 6 additions and 9 deletions.
15 changes: 6 additions & 9 deletions codegen_model_comparison/src/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
import mlflow
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding, TrainingArguments, Trainer, AutoConfig

#from helpers import generate_text
#from test_funcs import prompt, answer


def tokenize_function(example):
return tokenizer(example['text'],
Expand All @@ -27,9 +24,8 @@ def add_labels(example):


def compute_bleu_score(preds):
# What is the second element in the preds tuple?
# Why do we have 53 for the batch size when I specify 5?
logits = preds.predictions[0]
# logits is 3 dim with dims bs, seq length, vocab size
preds_tok = np.argmax(logits, axis=2)
acts = preds.label_ids

Expand Down Expand Up @@ -104,10 +100,11 @@ def main(args):

def parse_args():
parser = argparse.ArgumentParser()
parser.add_arguments(--data_path)
parser.add_arguments(--batch_size, type=int)
parser.add_arguments(--seq_length, type=int)
parser.add_arguments(--learning_rate, type=float)
parser.add_arguments("--data_path")
parser.add_arguments("--batch_size", type=int)
parser.add_arguments("--seq_length", type=int)
parser.add_arguments("--epochs", type=int)
parser.add_arguments("--learning_rate", type=float)
args = parser.parse_args

return args
Expand Down

0 comments on commit f143dc8

Please sign in to comment.