From 2d62a20aee8107ddd1922bb4a748b078d5b55d2e Mon Sep 17 00:00:00 2001 From: tigranthegreat Date: Tue, 9 Jan 2024 17:12:23 +0400 Subject: [PATCH] add propwise calculation --- src/config/config.yaml | 2 +- src/eval_metrics.py | 60 +++++++++++++++++++++++++++++++++++++----- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/src/config/config.yaml b/src/config/config.yaml index f51b199..62fbae8 100644 --- a/src/config/config.yaml +++ b/src/config/config.yaml @@ -19,4 +19,4 @@ tpu_env: [] tpu_use_cluster: false tpu_use_sudo: false use_cpu: false -# main_process_port: 30001 \ No newline at end of file +main_process_port: 30001 \ No newline at end of file diff --git a/src/eval_metrics.py b/src/eval_metrics.py index 87c7b72..3603ff5 100644 --- a/src/eval_metrics.py +++ b/src/eval_metrics.py @@ -3,7 +3,8 @@ import torch.nn.functional as F import math from collections import namedtuple -from utils import get_start2end_tags_map +from utils import get_start2end_tags_map, get_tokenizer +from functools import cache PropertyEntry = namedtuple( @@ -68,9 +69,20 @@ # construct_prop_entries function once and return the already created instance when requested. # """ +@cache +def get_prop2index_map(start2end_tags: dict): + prop2index_map = {start: i for i, start in enumerate(start2end_tags.keys())} + def inner_func(prop: str): + return prop2index_map[prop] + return inner_func -# def get_property_entries(): -# return getattr(get_property_entries, "property_entries", construct_prop_entries()) # fix this should be setattr + +@cache +def get_index2prop_map(start2end_tags: dict): + index2prop_map = [start for start in start2end_tags.keys()] + def inner_func(index: int): + return index2prop_map[index] + return inner_func # TODO: add overflow error handling here @@ -93,14 +105,48 @@ def preprocess_logits_for_metrics(logits: torch.Tensor, labels: torch.Tensor): logits = logits[..., :-1, :].contiguous().view(-1, logits.size(2)) labels = labels[..., 1:].contiguous().view(-1) + tokenizer = get_tokenizer() + # print(tokenizer.decode(labels)) + # metrics_tensor is matrix containing perplexities related to properties # metrics_tensor[i][0] shows the perplexity of the ith property # metrics_tensor[i][1] shows the number of times the ith property occured - metrics_tensor = torch.zeros(2, len(property_names), device=labels.device) - metrics_tensor[0][0] = perplexity(logits, labels) - metrics_tensor[1][0] = 1 - + # metrics_tensor[-1] is for the perplexity of the whole sequence start2end_tags = get_start2end_tags_map() + metrics_tensor = torch.zeros(len(start2end_tags) + 1, 2, device=labels.device) + metrics_tensor[-1][0] = perplexity(logits, labels) + metrics_tensor[-1][1] = 1 + + start_tags_mask = torch.zeros(labels.size(0), dtype=torch.bool, device=labels.device) + end_tags_mask = torch.zeros(labels.size(0), dtype=torch.bool, device=labels.device) + for start, end in start2end_tags.items(): + start_mask = (labels == tokenizer.encode(start)[0]) + start_tags_mask = torch.bitwise_or(start_tags_mask, start_mask) + + end_mask = (labels == tokenizer.encode(end)[0]) + end_tags_mask = torch.bitwise_or(end_tags_mask, end_mask) + + start_tags_indices = torch.where(start_tags_mask)[0] + end_tags_indices = torch.where(end_tags_mask)[0] + + prop2index = get_prop2index_map(start2end_tags) + # two pointers + first_ptr = 0 + second_ptr = 0 + while first_ptr < start_tags_indices.size(0) and second_ptr < end_tags_indices.size(0): + while second_ptr < end_tags_indices.size(0) and start_tags_indices[first_ptr] >= end_tags_indices[second_ptr]: + second_ptr += 1 + if second_ptr < end_tags_indices.size(0): + # [PROP_NAME]...value...[/PROP_NAME] + # ^ ^ + # start_index end_index (one before the closing tag) + start_index = start_tags_indices[first_ptr] + end_index = end_tags_indices[second_ptr] + index = prop2index(tokenizer.decode(labels[start_index])) + # print(tokenizer.decode(labels[start_index]), ":", index, "perp", perplexity(logits[start_index+1:end_index], labels[start_index+1:end_index])) + metrics_tensor[index][0] += perplexity(logits[start_index+1:end_index], labels[start_index+1:end_index]) + metrics_tensor[index][1] += 1 + first_ptr += 1 # start_brackets = torch.where( # torch.bitwise_or(