forked from sleepingcat4/mcqa-artifacts
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
324 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import pickle | ||
import datasets | ||
import sys | ||
import os | ||
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'model')) | ||
|
||
# Experiment configuration: the results directory plus the model and dataset
# combinations whose generations should be post-processed.
# `res_dir` is a placeholder; it is interpolated directly in front of the
# dataset name when paths are built, so it should end with a path separator.
res_dir = ...
MODELS = [
    'llama 70b',
    'falcon 40b',
    'mixtral 7b',
]
DATASETS = [
    'ARC',
    'MMLU',
    'HellaSwag',
]

# Prompt type; used as the file stem of the result pickle that is read.
pt = 'artifact_choices_cot'
|
||
def parse_generation(raw_text, prompt):
    """Split one model generation into the question it wrote and its choices.

    Args:
        raw_text: Text generated by the model, or None.
        prompt: The prompt that produced `raw_text`.

    Returns:
        A `(question, choices)` tuple. `question` is the stripped text that
        precedes the first 'Answer:' marker in `raw_text`. `choices` is the
        last double-newline-separated segment of `prompt` with every
        'Question:' removed and surrounding whitespace stripped.
        `(None, None)` is returned when `raw_text` is None or contains no
        'Answer:' marker, so output lists stay aligned with the inputs.
    """
    if raw_text is None or 'Answer:' not in raw_text:
        return None, None
    question = raw_text[:raw_text.index('Answer:')].strip()
    choices = prompt.split('\n\n')[-1].replace('Question:', '').strip()
    return question, choices


def main():
    """Convert each model/dataset result pickle into a question-data pickle.

    Reads `<res_dir><dataset>/<model>/<pt>.pkl` and writes
    `gen_question_data.pkl` next to it, containing the parallel lists
    'questions' and 'choices' (None entries mark unusable generations).
    """
    for model_nickname in MODELS:
        for dataset in DATASETS:
            # Build both paths from the untouched base directory. Reassigning
            # `res_dir` here would corrupt the output path and the paths of
            # every later iteration.
            run_dir = f'{res_dir}{dataset}/{model_nickname}/'
            in_path = f'{run_dir}{pt}.pkl'
            out_path = f'{run_dir}gen_question_data.pkl'

            # NOTE: unpickling can execute arbitrary code; only load result
            # files that come from a trusted source.
            with open(in_path, 'rb') as handle:
                res = pickle.load(handle)

            parsed = [
                parse_generation(raw_text, prompt)
                for raw_text, prompt in zip(res['raw_text'], res['prompt'])
            ]
            out = {
                'questions': [question for question, _ in parsed],
                'choices': [choices for _, choices in parsed],
            }
            with open(out_path, 'wb') as handle:
                pickle.dump(out, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import pickle | ||
import datasets | ||
import copy | ||
import random | ||
import sys | ||
import os | ||
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'model')) | ||
from data_loader import create_data_evaluation, DatasetName | ||
|
||
# Experiment configuration: the results directory plus the model and dataset
# combinations to process.
# `res_dir` is a placeholder; it is interpolated directly in front of the
# dataset name when paths are built, so it should end with a path separator.
res_dir = ...
MODELS = [
    'llama 70b',
    'falcon 40b',
    'mixtral 7b',
]
DATASETS = [
    DatasetName.ARC,
    DatasetName.HellaSwag,
    DatasetName.mmlu,
]

# Prompt type; used as the file stem of the result pickle that is read.
pt = 'artifact_choices_cot'
# Loaded once up front and handed to create_data_evaluation for each dataset.
ds = datasets.load_dataset('nbalepur/mcqa_artifacts')
|
||
def check_any_match(l1, l2):
    """Return True if `l1` and `l2` hold equal items at any shared position.

    Args:
        l1: First sequence.
        l2: Second sequence, compared position by position against `l1`.

    Returns:
        True as soon as one index holds equal items in both sequences,
        otherwise False. Only positions present in both sequences are
        compared, so sequences of different length no longer raise
        IndexError; two empty sequences yield False.
    """
    return any(a == b for a, b in zip(l1, l2))
|
||
def extract_choices(raw_text, prompt):
    """Return the choices text of `prompt`, or None for an unusable generation.

    Args:
        raw_text: Text generated by the model, or None.
        prompt: The prompt that produced `raw_text`.

    Returns:
        The last double-newline-separated segment of `prompt` with every
        'Question:' removed and surrounding whitespace stripped, or None
        when `raw_text` is None or contains no 'Answer:' marker.
    """
    if raw_text is None or 'Answer:' not in raw_text:
        return None
    return prompt.split('\n\n')[-1].replace('Question:', '').strip()


def shuffle_without_fixed_points(questions):
    """Return a shuffled deep copy of `questions` in which no item keeps its index.

    Shuffles are drawn repeatedly until `check_any_match` reports that no
    position still holds its original question.

    Raises:
        ValueError: if `questions` has exactly one item; it can never move,
            so the retry loop would otherwise spin forever.
    """
    if len(questions) == 1:
        raise ValueError('cannot shuffle a single question away from its own position')
    shuffled = copy.deepcopy(questions)
    # NOTE(review): this assumes the questions are mostly distinct; heavily
    # duplicated questions can make an acceptable shuffle rare or impossible.
    while check_any_match(questions, shuffled):
        random.shuffle(shuffled)
    return shuffled


def main():
    """Pair each result file's choices with randomly reassigned questions.

    For every dataset, the dataset's questions are shuffled once and that
    same shuffle is reused for every model. For every model, the result
    pickle `<res_dir><dataset>/<model>/<pt>.pkl` is read and
    `random_question_data.pkl` is written next to it with the keys
    'questions' (the shuffled questions) and 'choices' (None entries mark
    unusable generations).
    """
    for dataset_name in DATASETS:
        data = create_data_evaluation(ds, dataset_name)
        random_questions = shuffle_without_fixed_points(data['questions'])
        dataset = dataset_name.value

        for model_nickname in MODELS:
            # Build both paths from the untouched base directory. Reassigning
            # `res_dir` here would corrupt the output path and the paths of
            # every later iteration.
            run_dir = f'{res_dir}{dataset}/{model_nickname}/'
            in_path = f'{run_dir}{pt}.pkl'
            out_path = f'{run_dir}random_question_data.pkl'

            # NOTE: unpickling can execute arbitrary code; only load result
            # files that come from a trusted source.
            with open(in_path, 'rb') as handle:
                res = pickle.load(handle)

            choices = [
                extract_choices(raw_text, prompt)
                for raw_text, prompt in zip(res['raw_text'], res['prompt'])
            ]
            out = {'questions': random_questions, 'choices': choices}
            with open(out_path, 'wb') as handle:
                pickle.dump(out, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#!/bin/bash

# Launches /mcqa-artifacts/model/run_hf.py for one model with the settings below.

model_name="llama 70b" # model nickname (for saving in folders)
model_name_hf="meta-llama/Llama-2-70b-hf" # huggingface directory

# list of experiments
# see all possible experiments in: /mcqa-artifacts/model/data_loader.py
experiments=("normal" "artifact_choices_cot")

# list of datasets to test
# see all possible datasets in: /mcqa-artifacts/model/data_loader.py
datasets=("ARC")

# what partition of the dataset to run
# can be "full" or in halves (e.g. "first_half"), quarters (e.g. "first_quarter"), or eighths (e.g. "first_eighth")
partition="full"

hf_token=... # huggingface token (for downloading gated models)
load_in_8bit="False" # load the model in 8bit? ("False" or "True")
load_in_4bit="False" # load the model in 4bit? ("False" or "True")
use_20_fewshot="False" # use a 20-shot prompt in ARC? ("False" or "True") => we set this to "True" for Falcon

# join each array into a single space-separated string
datasets_str=$(IFS=" "; echo "${datasets[*]}")
experiments_str=$(IFS=" "; echo "${experiments[*]}")

# command-line flags forwarded to the python script, one per array element
run_args=(
    --model_name="$model_name"
    --model_name_hf="$model_name_hf"
    --dataset_name="$datasets_str"
    --hf_token="$hf_token"
    --load_in_4bit="$load_in_4bit"
    --load_in_8bit="$load_in_8bit"
    --partition="$partition"
    --prompt_types="$experiments_str"
    --use_20_fewshot="$use_20_fewshot"
)

# set the correct script path below
python3 /mcqa-artifacts/model/run_hf.py "${run_args[@]}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#!/bin/bash

# Launches /mcqa-artifacts/model/run_hf_question_gen.py for one model with the
# settings below.

model_name="llama 70b" # model nickname (for saving in folders)
model_name_hf="meta-llama/Llama-2-70b-hf" # huggingface directory

# list of experiments
# see all possible experiments in: /mcqa-artifacts/model/data_loader.py
# NOTE(review): experiments_str is built below but never forwarded to the
# python script -- confirm whether --prompt_types should be passed here.
experiments=("normal" "artifact_choices_cot")

# list of datasets to test
# see all possible datasets in: /mcqa-artifacts/model/data_loader.py
datasets=("ARC")

# what partition of the dataset to run
# can be "full" or in halves (e.g. "first_half"), quarters (e.g. "first_quarter"), or eighths (e.g. "first_eighth")
partition="full"

# Should you use a random question ("True") or a model-generated question ("False")
use_random_question="False"

hf_token=... # huggingface token (for downloading gated models)
load_in_8bit="False" # load the model in 8bit? ("False" or "True")
load_in_4bit="False" # load the model in 4bit? ("False" or "True")
use_20_fewshot="False" # use a 20-shot prompt in ARC? ("False" or "True") => we set this to "True" for Falcon

# join each array into a single space-separated string
datasets_str=$(IFS=" "; echo "${datasets[*]}")
experiments_str=$(IFS=" "; echo "${experiments[*]}")

# command-line flags forwarded to the python script, one per array element
run_args=(
    --model_name="$model_name"
    --model_name_hf="$model_name_hf"
    --dataset_name="$datasets_str"
    --hf_token="$hf_token"
    --load_in_4bit="$load_in_4bit"
    --load_in_8bit="$load_in_8bit"
    --partition="$partition"
    --use_random_question="$use_random_question"
    --use_20_fewshot="$use_20_fewshot"
)

python3 /mcqa-artifacts/model/run_hf_question_gen.py "${run_args[@]}"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
#!/bin/bash

# Launches /mcqa-artifacts/model/run_hf_question_gen_remote.py for one model
# with the settings below.

model_name="llama 70b" # model nickname (for saving in folders)
model_name_hf="meta-llama/Llama-2-70b-hf" # huggingface directory

# list of experiments
# see all possible experiments in: /mcqa-artifacts/model/data_loader.py
# NOTE(review): experiments_str is built below but never forwarded to the
# python script -- confirm whether --prompt_types should be passed here.
experiments=("normal" "artifact_choices_cot")

# list of datasets to test
# see all possible datasets in: /mcqa-artifacts/model/data_loader.py
datasets=("ARC")

# what partition of the dataset to run
# can be "full" or in halves (e.g. "first_half"), quarters (e.g. "first_quarter"), or eighths (e.g. "first_eighth")
partition="full"

# Should you use a random question ("True") or a model-generated question ("False")
use_random_question="False"

hf_token=... # huggingface token (for downloading gated models)
load_in_8bit="False" # load the model in 8bit? ("False" or "True")
load_in_4bit="False" # load the model in 4bit? ("False" or "True")
use_20_fewshot="False" # use a 20-shot prompt in ARC? ("False" or "True") => we set this to "True" for Falcon

# join each array into a single space-separated string
datasets_str=$(IFS=" "; echo "${datasets[*]}")
experiments_str=$(IFS=" "; echo "${experiments[*]}")

# command-line flags forwarded to the python script, one per array element
run_args=(
    --model_name="$model_name"
    --model_name_hf="$model_name_hf"
    --dataset_name="$datasets_str"
    --hf_token="$hf_token"
    --load_in_4bit="$load_in_4bit"
    --load_in_8bit="$load_in_8bit"
    --partition="$partition"
    --use_random_question="$use_random_question"
    --use_20_fewshot="$use_20_fewshot"
)

python3 /mcqa-artifacts/model/run_hf_question_gen_remote.py "${run_args[@]}"
Oops, something went wrong.