chore: Update megatron/data/prompt_dataset.py
saforem2 committed Jan 28, 2025
1 parent c1f99b9 commit 3991f25
Showing 1 changed file with 8 additions and 4 deletions.
megatron/data/prompt_dataset.py (8 additions, 4 deletions)
@@ -1,14 +1,18 @@
+import ezpz
 # Utilizing code snippet from https://github.com/tatsu-lab/stanford_alpaca
 import copy
 import logging
 from typing import Dict, Sequence
 import io
 import torch
 import transformers
 from torch.utils.data import Dataset
 import json
 
 
+logger = ezpz.get_logger(__name__)
+
+
 PROMPT_DICT = {
     "prompt_input": (
         "Below is an instruction that describes a task, paired with an input that provides further context. "
@@ -38,17 +42,17 @@ class SupervisedDataset(Dataset):
     def __init__(self, data_path: str, HFtokenizer):
         tokenizer = HFtokenizer.tokenizer
         super(SupervisedDataset, self).__init__()
-        logging.warning("Loading data...")
+        logger.warning("Loading data...")
         list_data_dict = jload(data_path)
-        logging.warning("Formatting inputs...")
+        logger.warning("Formatting inputs...")
         prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
         sources = [
             prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
             for example in list_data_dict
         ]
         targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
 
-        logging.warning("Tokenizing inputs... This may take some time...")
+        logger.warning("Tokenizing inputs... This may take some time...")
         data_dict = preprocess(sources, targets, tokenizer)
         self.input_ids = data_dict["input_ids"]
         self.labels = data_dict["labels"]
(remainder of the diff collapsed; per the file header, 1 further addition and 1 further deletion are not shown)
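The change itself is mechanical: module-level `logging.warning(...)` calls become calls on a named logger obtained once via `ezpz.get_logger(__name__)`, so log records carry the module name and pick up whatever handler configuration `ezpz` applies. Below is a minimal, hypothetical usage sketch of the touched class; the `gpt2` tokenizer, the `alpaca_data.json` path, and the `SimpleNamespace` wrapper are illustrative assumptions, since `__init__` only needs an Alpaca-style JSON file plus any object exposing a `.tokenizer` attribute that wraps a `transformers` tokenizer.

```python
# Hypothetical usage sketch (not part of the commit). SupervisedDataset reads
# HFtokenizer.tokenizer, so any object with a `.tokenizer` attribute works.
from types import SimpleNamespace

import transformers

from megatron.data.prompt_dataset import SupervisedDataset

# Placeholder tokenizer; any transformers tokenizer with an eos_token works,
# since targets are built as f"{example['output']}{tokenizer.eos_token}".
tok = transformers.AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # GPT-2 ships without a pad token

dataset = SupervisedDataset(
    data_path="alpaca_data.json",  # assumed Alpaca-style JSON file
    HFtokenizer=SimpleNamespace(tokenizer=tok),
)
print(len(dataset.input_ids))  # one tokenized example per JSON record
```

With the named logger in place, the "Loading data...", "Formatting inputs...", and "Tokenizing inputs..." messages above are emitted through `ezpz` rather than the root `logging` module, which presumably also gives rank-aware filtering in distributed runs.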
