-
Notifications
You must be signed in to change notification settings - Fork 70
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Your Name
committed
Feb 23, 2024
1 parent
5cb8ddf
commit d19cdd8
Showing
74 changed files
with
3,523 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": "auto"
    },
    "train_micro_batch_size_per_gpu": "auto",
    "train_batch_size": "auto",
    "gradient_accumulation_steps": "auto",
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
import json | ||
import argparse | ||
|
||
def convert(src_path, dst_path):
    """Convert a jsonl of model predictions into a single JSON answer list.

    Each input line is a JSON object with at least ``question_id`` and
    ``text``. The output is a JSON array of
    ``{"questionId": ..., "prediction": ...}`` records, with the prediction
    lowercased and any trailing period stripped (the common normalization for
    exact-match VQA scoring).

    Args:
        src_path: path to the input ``.jsonl`` predictions file.
        dst_path: path the merged JSON answer file is written to.
    """
    all_answers = []
    # Use a context manager so the input handle is closed deterministically
    # (the original iterated a bare open() and leaked the file object).
    with open(src_path) as src_file:
        for line in src_file:
            res = json.loads(line)
            all_answers.append({
                "questionId": res['question_id'],
                # Normalize: drop a trailing '.' and lowercase for matching.
                "prediction": res['text'].rstrip('.').lower(),
            })

    with open(dst_path, 'w') as f:
        json.dump(all_answers, f)


def main():
    """CLI entry point: --src <predictions.jsonl> --dst <answers.json>."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str)
    parser.add_argument("--dst", type=str)
    args = parser.parse_args()
    convert(args.src, args.dst)


# Guard so importing this module does not trigger argument parsing.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
import json | ||
import argparse | ||
import pandas as pd | ||
|
||
def get_args():
    """Parse the CLI flags needed to build an MMBench upload file.

    All four flags are mandatory: the annotation TSV, the directory holding
    per-experiment ``.jsonl`` results, the output directory for the xlsx
    upload, and the experiment name.
    """
    arg_parser = argparse.ArgumentParser()
    for flag in ("--annotation-file", "--result-dir", "--upload-dir", "--experiment"):
        arg_parser.add_argument(flag, type=str, required=True)
    return arg_parser.parse_args()
|
||
if __name__ == "__main__":
    args = get_args()

    # Full MMBench annotation table (TSV); one row per question, keyed by
    # the 'index' column.
    df = pd.read_table(args.annotation_file)

    cur_df = df.copy()
    # Drop the columns the MMBench submission server does not accept.
    cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category'])
    cur_df.insert(6, 'prediction', None)

    # Fill predictions row-by-row, matching each jsonl record's question_id
    # against the annotation 'index'. Use a context manager so the results
    # file is closed (the original iterated a bare open() and leaked it).
    with open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")) as result_file:
        for line in result_file:
            pred = json.loads(line)
            cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text']

    cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
import json | ||
import argparse | ||
|
||
def convert(src_path, dst_path):
    """Convert a jsonl of predictions into a single JSON mapping.

    Each input line is a JSON object with ``question_id`` and ``text``.
    The output maps ``"v1_<question_id>"`` -> prediction text, written with
    2-space indentation.

    Args:
        src_path: path to the input ``.jsonl`` predictions file.
        dst_path: path the merged JSON file is written to.
    """
    cur_result = {}
    # Context manager closes the input handle (original leaked a bare open()).
    with open(src_path) as src_file:
        for line in src_file:
            data = json.loads(line)
            # The evaluation server expects keys prefixed with 'v1_'.
            cur_result[f"v1_{data['question_id']}"] = data['text']

    with open(dst_path, 'w') as f:
        json.dump(cur_result, f, indent=2)


def main():
    """CLI entry point: --src <predictions.jsonl> --dst <result.json>."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", type=str)
    parser.add_argument("--dst", type=str)
    args = parser.parse_args()
    convert(args.src, args.dst)


# Guard so importing this module does not trigger argument parsing.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import os | ||
import json | ||
import argparse | ||
|
||
|
||
def get_args():
    """Parse the CLI flags for SEED-Bench scoring and upload-file creation."""
    p = argparse.ArgumentParser()
    for name in ("--annotation-file", "--result-file", "--result-upload-file"):
        p.add_argument(name, type=str)
    return p.parse_args()
|
||
|
||
def eval_single(result_file, eval_only_type=None):
    """Score jsonl predictions against the SEED-Bench annotations.

    Relies on two module-level globals set in ``__main__``: ``data`` (the
    parsed annotation JSON) and ``ques_type_id_to_name`` (question-type id ->
    display name).

    Args:
        result_file: path to a jsonl file of ``{question_id, text}`` rows.
        eval_only_type: if given, restrict scoring to questions whose
            ``data_type`` matches (e.g. 'image' or 'video'); otherwise print
            per-type accuracies as well as the total.

    Returns:
        dict mapping question_id -> prediction row, for reuse by the caller.
    """
    results = {}
    # Context manager closes the handle (original iterated a bare open()).
    with open(result_file) as f:
        for line in f:
            row = json.loads(line)
            results[row['question_id']] = row

    type_counts = {}
    correct_counts = {}
    for question_data in data['questions']:
        if eval_only_type is not None and question_data['data_type'] != eval_only_type:
            continue
        data_type = question_data['question_type_id']
        type_counts[data_type] = type_counts.get(data_type, 0) + 1
        # Annotation ids may be numeric strings; prediction keys may be ints.
        # Narrowed from a bare except: only conversion failures fall through.
        try:
            question_id = int(question_data['question_id'])
        except (ValueError, TypeError):
            question_id = question_data['question_id']
        if question_id not in results:
            continue
        row = results[question_id]
        if row['text'] == question_data['answer']:
            correct_counts[data_type] = correct_counts.get(data_type, 0) + 1

    total_count = 0
    total_correct = 0
    for data_type in sorted(type_counts.keys()):
        # BUG FIX: the original indexed correct_counts[data_type] directly,
        # which raised KeyError for a type whose predictions were all present
        # but wrong (the key was only seeded on correct or missing answers).
        correct = correct_counts.get(data_type, 0)
        accuracy = correct / type_counts[data_type] * 100
        if eval_only_type is None:
            print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")

        total_count += type_counts[data_type]
        total_correct += correct

    # Guard the division: an eval_only_type that matches nothing would
    # otherwise raise ZeroDivisionError.
    if total_count == 0:
        print(f"No questions matched type {eval_only_type!r}; nothing to score.")
        return results

    total_accuracy = total_correct / total_count * 100
    if eval_only_type is None:
        print(f"Total accuracy: {total_accuracy:.2f}%")
    else:
        print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")

    return results
|
||
if __name__ == "__main__":
    args = get_args()
    # Load the SEED-Bench annotation file; eval_single reads these two
    # globals, so they must be bound before any eval_single call below.
    data = json.load(open(args.annotation_file))
    # Invert the annotation's name -> id mapping into id -> name for printing.
    ques_type_id_to_name = {id:n for n,id in data['question_type'].items()}

    # Overall scoring (also returns question_id -> prediction for the upload),
    # then per-modality breakdowns.
    results = eval_single(args.result_file)
    eval_single(args.result_file, eval_only_type='image')
    eval_single(args.result_file, eval_only_type='video')

    # Write the leaderboard upload file: one JSON line per annotated question.
    with open(args.result_upload_file, 'w') as fp:
        for question in data['questions']:
            qid = question['question_id']
            # Prediction keys may be ints while annotation ids are strings
            # (or vice versa); fall back to the int-converted key.
            if qid in results:
                result = results[qid]
            else:
                result = results[int(qid)]
            fp.write(json.dumps({
                'question_id': qid,
                'prediction': result['text']
            }) + '\n')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import json | ||
import os | ||
import fire | ||
import re | ||
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot | ||
|
||
|
||
def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
    """Convert ScienceQA problems into the LLaVA conversation JSON format.

    Reads ``pid_splits.json`` and ``problems.json`` from *base_dir*, builds
    (prompt, answer) pairs via ``build_prompt_chatbot``, and writes
    ``llava_<split>_<prompt_format>.json`` back into *base_dir*.

    Args:
        base_dir: directory holding the ScienceQA json files.
        split: split name to convert (a key of pid_splits.json).
        prompt_format: prompt layout passed through to build_prompt_chatbot.
    """
    # Context managers so both annotation files are closed promptly
    # (the original passed bare open() handles into json.load).
    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)

    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    target_format = []
    # Renamed from `input`/`output`, which shadowed builtins.
    for prob_id, (prompt, answer) in split_problems.items():
        # BUG FIX: the original used str.replace, which strips EVERY
        # occurrence of the marker, not just the leading one; slice off
        # only the prefix.
        if prompt.startswith('Question: '):
            prompt = prompt[len('Question: '):]
        if answer.startswith('Answer: '):
            answer = answer[len('Answer: '):]

        raw_prob_data = problems[prob_id]
        if raw_prob_data['image'] is None:
            # Text-only problem: no image reference in the record.
            target_format.append({
                "id": prob_id,
                "conversations": [
                    {'from': 'human', 'value': f"{prompt}"},
                    {'from': 'gpt', 'value': f"{answer}"},
                ],
            })
        else:
            # Image problem: images live under a per-problem subdirectory.
            target_format.append({
                "id": prob_id,
                "image": os.path.join(prob_id, raw_prob_data['image']),
                "conversations": [
                    {'from': 'human', 'value': f"{prompt}\n<image>"},
                    {'from': 'gpt', 'value': f"{answer}"},
                ],
            })

    print(f'Number of samples: {len(target_format)}')

    with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
        json.dump(target_format, f, indent=2)
|
||
|
||
def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
    """Convert ScienceQA problems into instruction/output jsonl records.

    Reads ``pid_splits.json`` and ``problems.json`` from *base_dir*, builds
    (prompt, answer) pairs via ``build_prompt_chatbot``, and writes one JSON
    line per problem to ``scienceqa_<split>_<prompt_format>.jsonl``.

    Args:
        base_dir: directory holding the ScienceQA json files.
        split: split name to convert (a key of pid_splits.json).
        prompt_format: prompt layout passed through to build_prompt_chatbot.
    """
    # Context managers so the annotation files are closed promptly.
    with open(os.path.join(base_dir, "pid_splits.json")) as f:
        split_indices = json.load(f)[split]
    with open(os.path.join(base_dir, "problems.json")) as f:
        problems = json.load(f)

    split_problems = build_prompt_chatbot(
        problems, split_indices, prompt_format,
        use_caption=False, is_test=False)

    out_path = os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl")
    # `with` replaces the manual writer.close() and also closes on error.
    with open(out_path, "w") as writer:
        # Renamed from `input`/`output`, which shadowed builtins.
        for prob_id, (prompt, answer) in split_problems.items():
            # BUG FIX: str.replace stripped every occurrence of the marker,
            # not just the leading one; slice off only the prefix.
            if prompt.startswith('Question: '):
                prompt = prompt[len('Question: '):]
            if answer.startswith('Answer: '):
                answer = answer[len('Answer: '):]

            raw_prob_data = problems[prob_id]
            if raw_prob_data['image'] is None:
                # Text-only problem.
                data = {
                    "id": prob_id,
                    "instruction": f"{prompt}",
                    "output": f"{answer}",
                }
            else:
                # Image problem: images live under a per-problem subdirectory.
                data = {
                    "id": prob_id,
                    "image": os.path.join(prob_id, raw_prob_data['image']),
                    "instruction": f"{prompt}\n<image>",
                    "output": f"{answer}",
                }
            writer.write(json.dumps(data) + '\n')
|
||
|
||
def main(task, **kwargs):
    """Dispatch *task* to the same-named top-level function in this module.

    Example: ``python convert_sqa_to_llava.py convert_to_llava --base-dir ...``
    invokes ``convert_to_llava(base_dir=...)``.
    """
    handler = globals()[task]
    handler(**kwargs)


if __name__ == "__main__":
    fire.Fire(main)
Oops, something went wrong.