Skip to content

Commit

Permalink
dataset split logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nbalepur committed Feb 23, 2024
1 parent 4f885e9 commit 04b53fb
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 31 deletions.
16 changes: 8 additions & 8 deletions model/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class DatasetName(Enum):
PromptType.choice_d_question: ChoiceDQuestion,
}

def create_data_choices_even_mmlu(dataset, dataset_name, prompt_type, prompt_dir, use_20_fewshot=False):
def create_data_choices_even_mmlu(dataset, dataset_name, dataset_split, prompt_type, prompt_dir, use_20_fewshot=False):

if prompt_dir[-1] != '/':
prompt_dir += '/'
Expand All @@ -87,7 +87,7 @@ def create_data_choices_even_mmlu(dataset, dataset_name, prompt_type, prompt_dir
suffix += "_20"

# load data and prompt objects
train_ds, test_ds = dataset['train'], dataset['test']
train_ds, test_ds = dataset[dataset_split[0]], dataset[dataset_split[1]]

# get all tagged datasets
train_ds_ = train_ds.filter(lambda example: dataset_name.value in example['dataset'])
Expand Down Expand Up @@ -130,13 +130,13 @@ def create_data_choices_even_mmlu(dataset, dataset_name, prompt_type, prompt_dir

return {'input': final_input_prompts, 'output': final_output_letters, 'stop_token': '\nQuestion:' if 'question' in prompt_type.value else '\nChoice:'}

def create_data_choices_even(dataset, dataset_name, prompt_type, prompt_dir, use_20_fewshot=False):
def create_data_choices_even(dataset, dataset_name, dataset_split, prompt_type, prompt_dir, use_20_fewshot=False):

if prompt_dir[-1] != '/':
prompt_dir += '/'

if dataset_name == DatasetName.mmlu:
return create_data_choices_even_mmlu(dataset, dataset_name, prompt_type, prompt_dir, use_20_fewshot)
return create_data_choices_even_mmlu(dataset, dataset_name, dataset_split, prompt_type, prompt_dir, use_20_fewshot)

if dataset_name not in [DatasetName.ARC, DatasetName.HellaSwag]:
print(f"Sorry, {dataset_name} is not supported!")
Expand All @@ -154,7 +154,7 @@ def create_data_choices_even(dataset, dataset_name, prompt_type, prompt_dir, use
idx_map = {PromptType.choice_a_even: 0, PromptType.choice_a_question_even: 0, PromptType.choice_b_even: 1, PromptType.choice_b_question_even: 1, PromptType.choice_c_question_even: 2, PromptType.choice_c_even: 2, PromptType.choice_d_even: 3, PromptType.choice_d_question_even: 3}
choice_idx = idx_map[prompt_type]

test_ds = dataset['test']
test_ds = dataset[dataset_split[1]]
test_ds = test_ds.filter(lambda example: dataset_name.value in example['dataset'])

final_input_prompts = []
Expand All @@ -172,13 +172,13 @@ def create_data_choices_even(dataset, dataset_name, prompt_type, prompt_dir, use

return {'input': final_input_prompts, 'output': final_output_letters, 'stop_token': '\nQuestion:' if 'question' in prompt_type.value else '\nChoice:'}

def create_data(dataset, dataset_name, prompt_type, prompt_dir, use_20_fewshot=False):
def create_data(dataset, dataset_name, dataset_split, prompt_type, prompt_dir, use_20_fewshot=False):

if 'even' in prompt_type.value:
return create_data_choices_even(dataset, dataset_name, prompt_type, prompt_dir, use_20_fewshot)
return create_data_choices_even(dataset, dataset_name, dataset_split, prompt_type, prompt_dir, use_20_fewshot)

# load data and prompt objects
train_ds, test_ds = dataset['train'], dataset['test']
train_ds, test_ds = dataset[dataset_split[0]], dataset[dataset_split[1]]
prompt_object = prompt_type_map[prompt_type]()

# get all tagged datasets
Expand Down
17 changes: 12 additions & 5 deletions model/run_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ def converter(input):
default=[],
)
parser.add_argument(
"--dataset_split",
"--train_dataset_split",
nargs='*',
type=str,
help="Dataset split",
help="Training dataset split",
default="",
)
parser.add_argument(
"--eval_dataset_split",
nargs='*',
type=str,
help="Evaluation dataset split",
default="",
)
parser.add_argument(
Expand Down Expand Up @@ -135,7 +142,7 @@ def converter(input):
assert(not (load_in_4bit and load_in_8bit))

dataset_names = args.dataset_name
dataset_split = args.dataset_split
dataset_split = (args.train_dataset_split, args.eval_dataset_split)
hf_dataset_name = args.hf_dataset_name
prompt_types = args.prompt_types
model_name = args.model_name
Expand Down Expand Up @@ -198,15 +205,15 @@ def generate_text(prompt, stop_token):
def run_inference(dataset_names, dataset_split, hf_dataset_name, prompt_types, model_name, partition, use_20_fewshot, pipe, tokenizer, prompt_dir, res_dir):

# load data
ds = datasets.load_dataset(hf_dataset_name)[dataset_split]
ds = datasets.load_dataset(hf_dataset_name)

for dataset_name in dataset_names[0]:

# results directory setup
results_dir = f'{res_dir}{dataset_name.value}/{model_name}'

for pt in prompt_types[0]:
data = create_data(ds, dataset_name, pt, prompt_dir, use_20_fewshot=use_20_fewshot)
data = create_data(ds, dataset_name, dataset_split, pt, prompt_dir, use_20_fewshot=use_20_fewshot)
input_prompts, output_letters, stop_token = data['input'], data['output'], data['stop_token']

# run generation
Expand Down
17 changes: 12 additions & 5 deletions model/run_hf_question_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ def converter(input):
default=[],
)
parser.add_argument(
"--dataset_split",
"--train_dataset_split",
nargs='*',
type=str,
help="Dataset split",
help="Training dataset split",
default="",
)
parser.add_argument(
"--eval_dataset_split",
nargs='*',
type=str,
help="Evaluation dataset split",
default="",
)
parser.add_argument(
Expand Down Expand Up @@ -135,7 +142,7 @@ def converter(input):
assert(not (load_in_4bit and load_in_8bit))

dataset_names = args.dataset_name
dataset_split = args.dataset_split
dataset_split = (args.train_dataset_split, args.eval_dataset_split)
hf_dataset_name = args.hf_dataset_name
model_name = args.model_name
hf_model_name = args.model_name_hf
Expand Down Expand Up @@ -196,7 +203,7 @@ def generate_text(prompt, stop_token):
def run_inference(dataset_names, dataset_split, hf_dataset_name, model_name, partition, use_random_question, use_20_fewshot, pipe, tokenizer, prompt_dir, res_dir):

# load data
ds = datasets.load_dataset(hf_dataset_name)[dataset_split]
ds = datasets.load_dataset(hf_dataset_name)

for dataset_name in dataset_names[0]:

Expand All @@ -211,7 +218,7 @@ def run_inference(dataset_names, dataset_split, hf_dataset_name, model_name, par
results_dir = f'{res_dir}{dataset_name.value}/{model_name}'

for pt in [PromptType.normal]:
data = create_data(ds, dataset_name, pt, prompt_dir, use_20_fewshot=use_20_fewshot)
data = create_data(ds, dataset_name, dataset_split, pt, prompt_dir, use_20_fewshot=use_20_fewshot)
input_prompts, output_letters, stop_token = data['input'], data['output'], data['stop_token']

# run generation
Expand Down
17 changes: 12 additions & 5 deletions model/run_hf_question_gen_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ def converter(input):
default=[],
)
parser.add_argument(
"--dataset_split",
"--train_dataset_split",
nargs='*',
type=str,
help="Dataset split",
help="Training dataset split",
default="",
)
parser.add_argument(
"--eval_dataset_split",
nargs='*',
type=str,
help="Evaluation dataset split",
default="",
)
parser.add_argument(
Expand Down Expand Up @@ -143,7 +150,7 @@ def converter(input):
assert(not (load_in_4bit and load_in_8bit))

dataset_names = args.dataset_name
dataset_split = args.dataset_split
dataset_split = (args.train_dataset_split, args.eval_dataset_split)
hf_dataset_name = args.hf_dataset_name
prompt_types = args.prompt_types
model_name = args.model_name
Expand Down Expand Up @@ -210,7 +217,7 @@ def generate_text(prompt, stop_token):
def run_inference(dataset_names, dataset_split, hf_dataset_name, model_name, partition, use_random_question, use_20_fewshot, pipe, tokenizer, args, prompt_dir, res_dir):

# load data
ds = datasets.load_dataset(hf_dataset_name)[dataset_split]
ds = datasets.load_dataset(hf_dataset_name)

for dataset_name in dataset_names[0]:

Expand All @@ -225,7 +232,7 @@ def run_inference(dataset_names, dataset_split, hf_dataset_name, model_name, par
results_dir = f'{args.res_dir}{dataset_name.value}/{model_name}'

for pt in [PromptType.normal]:
data = create_data(ds, dataset_name, pt, args.prompt_dir, use_20_fewshot=use_20_fewshot)
data = create_data(ds, dataset_name, dataset_split, pt, args.prompt_dir, use_20_fewshot=use_20_fewshot)
input_prompts, output_letters, stop_token = data['input'], data['output'], data['stop_token']

# run generation
Expand Down
17 changes: 12 additions & 5 deletions model/run_hf_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ def converter(input):
default=[],
)
parser.add_argument(
"--dataset_split",
"--train_dataset_split",
nargs='*',
type=str,
help="Dataset split",
help="Training dataset split",
default="",
)
parser.add_argument(
"--eval_dataset_split",
nargs='*',
type=str,
help="Evaluation dataset split",
default="",
)
parser.add_argument(
Expand Down Expand Up @@ -135,7 +142,7 @@ def converter(input):
assert(not (load_in_4bit and load_in_8bit))

dataset_names = args.dataset_name
dataset_split = args.dataset_split
dataset_split = (args.train_dataset_split, args.eval_dataset_split)
hf_dataset_name = args.hf_dataset_name
prompt_types = args.prompt_types
model_name = args.model_name
Expand Down Expand Up @@ -203,15 +210,15 @@ def generate_text(prompt, stop_token):
def run_inference(dataset_names, dataset_split, hf_dataset_name, prompt_types, model_name, partition, use_20_fewshot, pipe, tokenizer, prompt_dir, res_dir):

# load data
ds = datasets.load_dataset(hf_dataset_name)[dataset_split]
ds = datasets.load_dataset(hf_dataset_name)

for dataset_name in dataset_names[0]:

# results directory setup
results_dir = f'{res_dir}{dataset_name.value}/{model_name}'

for pt in prompt_types[0]:
data = create_data(ds, dataset_name, pt, prompt_dir, use_20_fewshot=use_20_fewshot)
data = create_data(ds, dataset_name, dataset_split, pt, prompt_dir, use_20_fewshot=use_20_fewshot)
input_prompts, output_letters, stop_token = data['input'], data['output'], data['stop_token']

# run generation
Expand Down
8 changes: 5 additions & 3 deletions scripts/model.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ experiments=("normal" "artifact_choices")
# list of datasets to test
# see all possible datasets in: /mcqa-artifacts/model/data_loader.py
datasets=("ARC")
datasets_split="eval_only"
hf_dataset_name=""
train_dataset_split="train"
eval_dataset_split="test"
hf_dataset_name="nbalepur/MCQA_quality"

# what partition of the dataset to run
# can be "full" or in halves (e.g. "first_half"), quarters (e.g. "first_quarter"), or eighths (e.g. "first_eighth")
Expand All @@ -34,7 +35,8 @@ python3 /mcqa-artifacts/model/run_hf.py \
--model_name="$model_name" \
--model_name_hf="$model_name_hf" \
--dataset_name="$datasets_str" \
--dataset_split="$datasets_split" \
--train_dataset_split="$train_dataset_split" \
--eval_dataset_split="$eval_dataset_split" \
--hf_dataset_name="$hf_dataset_name" \
--hf_token="$hf_token" \
--load_in_4bit="$load_in_4bit" \
Expand Down

0 comments on commit 04b53fb

Please sign in to comment.