Skip to content

Commit

Permalink
update model code
Browse files Browse the repository at this point in the history
  • Loading branch information
nbalepur committed Feb 23, 2024
1 parent 707da7d commit 5ece2e5
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 28 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
scripts/
10 changes: 6 additions & 4 deletions model/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ class PromptType(Enum):


class DatasetName(Enum):
    """Datasets supported by the data loader.

    Each member's value doubles as the on-disk name used when building
    prompt/result paths, so values must stay stable.
    """
    # NOTE(review): member list reconstructed from the commit diff; the
    # flattened page text listed ARC twice, which would raise TypeError
    # ("attempted to reuse key") in a real Enum — deduplicated here.
    ARC = 'ARC'
    CQA = 'CQA'
    OBQA = 'OBQA'
    PIQA = 'PIQA'
    QASC = 'QASC'
    SIQA = 'SIQA'

prompt_type_map = {
PromptType.normal: Normal,
Expand Down
11 changes: 7 additions & 4 deletions model/run_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,14 @@ def run_inference(dataset_names, prompt_types, model_name, partition, use_20_few

# save results: partitioned runs get a filename suffix so partial runs
# of the same prompt type don't overwrite each other
if partition != 'full':
    final_res_path = f'{results_dir}/{pt.value}_{partition}.pkl'
else:
    final_res_path = f'{results_dir}/{pt.value}.pkl'

# BUG FIX: the committed code called os.makedirs(final_res_path), which
# creates a *directory* at the pickle file's own path and makes the
# open(..., 'wb') below fail with IsADirectoryError. Create the parent
# directory instead; exist_ok=True also removes the racy exists() check.
os.makedirs(os.path.dirname(final_res_path), exist_ok=True)
with open(final_res_path, 'wb') as handle:
    pickle.dump(answers, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':

Expand Down
11 changes: 7 additions & 4 deletions model/run_hf_question_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,14 @@ def run_inference(dataset_names, model_name, partition, use_random_question, use
# save results: file name encodes whether the question was random or
# model-generated, plus the partition suffix for partial runs
suffix = 'random' if use_random_question else 'generated'
if partition != 'full':
    final_res_path = f'{results_dir}/artifact_choices_cot_twostep_{suffix}_{partition}.pkl'
else:
    final_res_path = f'{results_dir}/artifact_choices_cot_twostep_{suffix}.pkl'

# BUG FIX: os.makedirs(final_res_path) in the committed code creates a
# directory at the pickle file's own path, so open(..., 'wb') then fails
# with IsADirectoryError. Make the parent directory instead.
os.makedirs(os.path.dirname(final_res_path), exist_ok=True)
with open(final_res_path, 'wb') as handle:
    pickle.dump(answers, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':

Expand Down
13 changes: 8 additions & 5 deletions model/run_hf_question_gen_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,18 @@ def run_inference(dataset_names, model_name, partition, use_random_question, use
print('done generating!', flush=True)
answers['raw_text'].append(out_text)
answers['prompt'].append(prompt)

# save results: file name encodes whether the question was random or
# model-generated, plus the partition suffix for partial runs
suffix = 'random' if use_random_question else 'generated'
if partition != 'full':
    final_res_path = f'{results_dir}/artifact_choices_cot_twostep_{suffix}_{partition}.pkl'
else:
    final_res_path = f'{results_dir}/artifact_choices_cot_twostep_{suffix}.pkl'

# BUG FIX: os.makedirs(final_res_path) in the committed code creates a
# directory at the pickle file's own path, so open(..., 'wb') then fails
# with IsADirectoryError. Make the parent directory instead.
os.makedirs(os.path.dirname(final_res_path), exist_ok=True)
with open(final_res_path, 'wb') as handle:
    pickle.dump(answers, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':

Expand Down
11 changes: 7 additions & 4 deletions model/run_hf_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,14 @@ def run_inference(dataset_names, prompt_types, model_name, partition, use_20_few

# save results: partitioned runs get a filename suffix so partial runs
# of the same prompt type don't overwrite each other
if partition != 'full':
    final_res_path = f'{results_dir}/{pt.value}_{partition}.pkl'
else:
    final_res_path = f'{results_dir}/{pt.value}.pkl'

# BUG FIX: the committed code called os.makedirs(final_res_path), which
# creates a *directory* at the pickle file's own path and makes the
# open(..., 'wb') below fail with IsADirectoryError. Create the parent
# directory instead; exist_ok=True also removes the racy exists() check.
os.makedirs(os.path.dirname(final_res_path), exist_ok=True)
with open(final_res_path, 'wb') as handle:
    pickle.dump(answers, handle, protocol=pickle.HIGHEST_PROTOCOL)

if __name__ == '__main__':

Expand Down
14 changes: 7 additions & 7 deletions scripts/model.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
#!/bin/bash

model_name="llama 70b" # model nickname (for saving in folders)
model_name_hf="meta-llama/Llama-2-70b-hf" # huggingface directory
model_name="llama 7b" # model nickname (for saving in folders)
model_name_hf="meta-llama/Llama-2-7b-hf" # huggingface directory

# list of experiments
# see all possible experiments in: /mcqa-artifacts/model/data_loader.py
experiments=("normal" "artifact_choices_cot")
experiments=("normal" "artifact_choices")

# list of datasets to test
# see all possible datasets in: /mcqa-artifacts/model/data_loader.py
Expand All @@ -15,14 +15,14 @@ datasets=("ARC")
# can be "full" or in halves (e.g. "first_half"), quarters (e.g. "first_quarter"), or eigths (e.g. "first_eighth")
partition="full"

hf_token=... # huggingface token (for downloading gated models)
hf_token="<YOUR_HF_TOKEN>" # huggingface token (for downloading gated models) — SECURITY: a real token was committed here in plaintext; it must be revoked on huggingface.co immediately and supplied via an environment variable or secrets file, never hardcoded in the repo
load_in_8bit="False" # load the model in 8bit? ("False" or "True")
load_in_4bit="False" # load the model in 4bit? ("False" or "True")
use_20_fewshot="False" # use a 20-shot prompt in ARC? ("False" or "True") => we set this to "True" for Falcon

res_dir=".../mcqa-artifacts/results" # Results folder directory
prompt_dir=".../mcqa-artifacts/prompts" # Prompt folder directory
cache_dir=... # Cache directory to save the model
res_dir="/fs/clip-projects/rlab/nbalepur/Artifact_Dataset_Creation/mcqa-artifacts/results/" # Results folder directory
prompt_dir="/fs/clip-projects/rlab/nbalepur/Artifact_Dataset_Creation/mcqa-artifacts/prompts/" # Prompt folder directory
cache_dir="/fs/clip-scratch/nbalepur/cache/" # Cache directory to save the model



Expand Down

0 comments on commit 5ece2e5

Please sign in to comment.