Skip to content

Commit

Permalink
Merge pull request #12 from TianyiQ/main
Browse files Browse the repository at this point in the history
fix(evaluation): data processing and normalization before plotting #1
  • Loading branch information
TianyiQ authored Oct 31, 2024
2 parents 763cc28 + 9b3b472 commit 2b086e7
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/evaluation/quantify.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _calculate_model(test_name, high_or_low, model_name):
)
)(high_or_low)
raw_dir = os.path.join(
"output", "evaluation_results", test_name, model_name + "raw.json"
"output", "evaluation_results", test_name, model_name + "_raw.json"
)
scenario_dir = os.path.join(
"src", "moralchoice", "assets", "data", "scenarios", scenario + ".csv"
Expand Down Expand Up @@ -236,6 +236,12 @@ def calculate_model(test_dir, model_name):
)
mrl_vec[2][key] /= 2
if num == 1:
# ref_dict = csv_to_dict_list(ref_dir[1], ['scenario_id', 'generation_theme'])
ref_dict = csv_to_dict(ref_dir[1], ["generation_theme"])
if key not in ref_dict.keys():
print("key", key, "not found")
continue

"""
registering invalid
"""
Expand Down Expand Up @@ -273,8 +279,6 @@ def calculate_model(test_dir, model_name):
"Authority_Respect": 3,
"Purity_Sancity": 4,
}
# ref_dict = csv_to_dict_list(ref_dir[1], ['scenario_id', 'generation_theme'])
ref_dict = csv_to_dict(ref_dir[1], ["generation_theme"])
theme = ref_dict[key]["generation_theme"].strip()
mrl_vec[1][key][context_matching[theme]] += mal
"""
Expand Down Expand Up @@ -359,7 +363,10 @@ def plot_heatmap(vectors, title, tuples):
"""
data = np.array(vectors)

data = normalize_matrix(data, tuples)
# data = normalize_matrix(data, tuples)
data -= np.mean(data, axis=0, keepdims=True)
data /= np.std(data, axis=0, keepdims=True)
# data /= np.sum(data, axis=0, keepdims=True)
if data.shape[1] != 19:
raise ValueError("All vectors should be 19-dimensional")

Expand Down
1 change: 1 addition & 0 deletions src/evaluation/test_eval_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,6 @@
with open("output/evaluation_results/" + test_name + ".json", "w") as f:
lst = [list(boi) for boi in vec]
json.dump(lst, f)
# vec = json.load(open("output/evaluation_results/" + test_name + ".json", "r"))
# qt.plot_parallel_coordinates(vec)
qt.plot_heatmap(vec, test_name, [(0,4), (5,9), (10,14), (15,18)])

0 comments on commit 2b086e7

Please sign in to comment.