From 3991f259064ccdefb90f4b024d4cf21bc8e04a2c Mon Sep 17 00:00:00 2001
From: Sam Foreman
Date: Tue, 28 Jan 2025 15:01:12 -0600
Subject: [PATCH] chore: Update `megatron/data/prompt_dataset.py`

---
 megatron/data/prompt_dataset.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/megatron/data/prompt_dataset.py b/megatron/data/prompt_dataset.py
index 40a2949bdf..b873e1dc3a 100644
--- a/megatron/data/prompt_dataset.py
+++ b/megatron/data/prompt_dataset.py
@@ -1,7 +1,7 @@
+import ezpz
 # Utilizing code snippet from https://github.com/tatsu-lab/stanford_alpaca
 import copy
-import logging
 from typing import Dict, Sequence
 import io
 import torch
 
@@ -9,6 +9,10 @@
 from torch.utils.data import Dataset
 import json
 
+
+logger = ezpz.get_logger(__name__)
+
+
 PROMPT_DICT = {
     "prompt_input": (
         "Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -38,9 +42,9 @@ class SupervisedDataset(Dataset):
     def __init__(self, data_path: str, HFtokenizer):
         tokenizer = HFtokenizer.tokenizer
         super(SupervisedDataset, self).__init__()
-        logging.warning("Loading data...")
+        logger.warning("Loading data...")
         list_data_dict = jload(data_path)
-        logging.warning("Formatting inputs...")
+        logger.warning("Formatting inputs...")
         prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
         sources = [
             prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
@@ -48,7 +52,7 @@ def __init__(self, data_path: str, HFtokenizer):
         ]
         targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
 
-        logging.warning("Tokenizing inputs... This may take some time...")
+        logger.warning("Tokenizing inputs... This may take some time...")
         data_dict = preprocess(sources, targets, tokenizer)
         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]