add propwise calculation
tigranfah committed Jan 9, 2024
1 parent d766a01 commit 2d62a20
Showing 2 changed files with 54 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/config/config.yaml
@@ -19,4 +19,4 @@ tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
-# main_process_port: 30001
+main_process_port: 30001
60 changes: 53 additions & 7 deletions src/eval_metrics.py
@@ -3,7 +3,8 @@
import torch.nn.functional as F
import math
from collections import namedtuple
-from utils import get_start2end_tags_map
+from utils import get_start2end_tags_map, get_tokenizer
+from functools import cache


PropertyEntry = namedtuple(
@@ -68,9 +69,20 @@
# construct_prop_entries function once and return the already created instance when requested.
# """

+@cache
+def get_prop2index_map(start_tags: tuple):
+    # note: @cache hashes its arguments, so the start tags are passed as a
+    # tuple of tag strings (the start2end_tags dict itself is unhashable)
+    prop2index_map = {start: i for i, start in enumerate(start_tags)}
+    def inner_func(prop: str):
+        return prop2index_map[prop]
+    return inner_func

# def get_property_entries():
# return getattr(get_property_entries, "property_entries", construct_prop_entries()) # fix this should be setattr

+@cache
+def get_index2prop_map(start_tags: tuple):
+    index2prop_map = list(start_tags)
+    def inner_func(index: int):
+        return index2prop_map[index]
+    return inner_func
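
A quick illustration of how the two cached helpers above behave; the tag names are invented for the example, while the real set comes from `get_start2end_tags_map` in `utils`:

```python
# hypothetical property tags; passed as a tuple of start tags so @cache can hash them
start2end_tags = {"[QED]": "[/QED]", "[TPSA]": "[/TPSA]"}

prop2index = get_prop2index_map(tuple(start2end_tags))
index2prop = get_index2prop_map(tuple(start2end_tags))

assert prop2index("[TPSA]") == 1
assert index2prop(1) == "[TPSA]"
# calling again with an equal tuple returns the same cached closure
assert get_prop2index_map(tuple(start2end_tags)) is prop2index
```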


# TODO: add overflow error handling here
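
The `perplexity` helper the TODO refers to is not part of this excerpt. A minimal sketch of one overflow-aware shape it could take; the body below is an assumption, not the repository's actual implementation:

```python
import torch
import torch.nn.functional as F

def perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # exp(mean NLL): averaging the loss before exponentiating keeps the
    # intermediate value small, and the clamp guards against float overflow
    nll = F.cross_entropy(logits, labels, reduction="mean")
    return torch.exp(torch.clamp(nll, max=80.0))  # exp(80) is still finite in float32
```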
@@ -93,14 +105,48 @@ def preprocess_logits_for_metrics(logits: torch.Tensor, labels: torch.Tensor):
    logits = logits[..., :-1, :].contiguous().view(-1, logits.size(2))
    labels = labels[..., 1:].contiguous().view(-1)

+    tokenizer = get_tokenizer()
+    # print(tokenizer.decode(labels))

-    metrics_tensor = torch.zeros(2, len(property_names), device=labels.device)
-    metrics_tensor[0][0] = perplexity(logits, labels)
-    metrics_tensor[1][0] = 1

+    # metrics_tensor is a matrix of per-property perplexities:
+    # metrics_tensor[i][0] accumulates the perplexity of the ith property
+    # metrics_tensor[i][1] counts how many times the ith property occurred
+    # metrics_tensor[-1] holds the perplexity of the whole sequence
+    start2end_tags = get_start2end_tags_map()
+    metrics_tensor = torch.zeros(len(start2end_tags) + 1, 2, device=labels.device)
+    metrics_tensor[-1][0] = perplexity(logits, labels)
+    metrics_tensor[-1][1] = 1

+    start_tags_mask = torch.zeros(labels.size(0), dtype=torch.bool, device=labels.device)
+    end_tags_mask = torch.zeros(labels.size(0), dtype=torch.bool, device=labels.device)
+    for start, end in start2end_tags.items():
+        # assumes each tag encodes to a single token
+        start_mask = (labels == tokenizer.encode(start)[0])
+        start_tags_mask = torch.bitwise_or(start_tags_mask, start_mask)
+
+        end_mask = (labels == tokenizer.encode(end)[0])
+        end_tags_mask = torch.bitwise_or(end_tags_mask, end_mask)
+
+    start_tags_indices = torch.where(start_tags_mask)[0]
+    end_tags_indices = torch.where(end_tags_mask)[0]
+
+    # pass a tuple of start tags: @cache cannot hash the dict itself
+    prop2index = get_prop2index_map(tuple(start2end_tags))
+    # two pointers: pair each start tag with the first end tag after it
+    first_ptr = 0
+    second_ptr = 0
+    while first_ptr < start_tags_indices.size(0) and second_ptr < end_tags_indices.size(0):
+        while second_ptr < end_tags_indices.size(0) and start_tags_indices[first_ptr] >= end_tags_indices[second_ptr]:
+            second_ptr += 1
+        if second_ptr < end_tags_indices.size(0):
+            # [PROP_NAME]...value...[/PROP_NAME]
+            # ^                     ^
+            # start_index           end_index (the closing tag; the slices below exclude it)
+            start_index = start_tags_indices[first_ptr]
+            end_index = end_tags_indices[second_ptr]
+            index = prop2index(tokenizer.decode(labels[start_index]))
+            # print(tokenizer.decode(labels[start_index]), ":", index, "perp", perplexity(logits[start_index+1:end_index], labels[start_index+1:end_index]))
+            metrics_tensor[index][0] += perplexity(logits[start_index+1:end_index], labels[start_index+1:end_index])
+            metrics_tensor[index][1] += 1
+        first_ptr += 1
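
To see what the two-pointer pass pairs up, here is a self-contained toy run with invented token ids (10/11 open and close one property, 20/21 another):

```python
import torch

labels = torch.tensor([10, 5, 6, 11, 20, 7, 21])
start_tags_indices = torch.where((labels == 10) | (labels == 20))[0]  # tensor([0, 4])
end_tags_indices = torch.where((labels == 11) | (labels == 21))[0]    # tensor([3, 6])

pairs = []
first_ptr = second_ptr = 0
while first_ptr < start_tags_indices.size(0) and second_ptr < end_tags_indices.size(0):
    # skip end tags that precede the current start tag
    while second_ptr < end_tags_indices.size(0) and start_tags_indices[first_ptr] >= end_tags_indices[second_ptr]:
        second_ptr += 1
    if second_ptr < end_tags_indices.size(0):
        pairs.append((start_tags_indices[first_ptr].item(), end_tags_indices[second_ptr].item()))
    first_ptr += 1

print(pairs)  # [(0, 3), (4, 6)] -> the property values sit in slices 1:3 and 5:6
```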

# start_brackets = torch.where(
# torch.bitwise_or(
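
Downstream, the accumulated sums and counts presumably get averaged before logging; that reduction is outside this diff, but a sketch could look like the following (the function name `reduce_prop_metrics` and the key format are made up, and the property order is assumed to match `get_index2prop_map`):

```python
import torch

def reduce_prop_metrics(metrics_tensor: torch.Tensor, property_names: list) -> dict:
    # divide each accumulated perplexity by its occurrence count,
    # skipping properties that never occurred to avoid 0/0
    out = {}
    for i, name in enumerate(property_names):
        count = metrics_tensor[i][1].item()
        if count > 0:
            out[f"{name}_perplexity"] = metrics_tensor[i][0].item() / count
    # the last row always holds the whole-sequence perplexity with count 1
    out["perplexity"] = (metrics_tensor[-1][0] / metrics_tensor[-1][1]).item()
    return out
```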
