Commit

Merge pull request #202 from neulab/ner_metrics
Change NER to use metrics class

Former-commit-id: 767353b
neubig authored Apr 2, 2022
2 parents 4da06d8 + df745f0 commit 8804370
Showing 12 changed files with 211 additions and 252 deletions.
81 changes: 73 additions & 8 deletions explainaboard/metric.py
@@ -12,6 +12,7 @@
import numpy as np

from explainaboard.utils.async_eaas import AsyncEaaSRequest
from explainaboard.utils.span_utils import get_spans_from_bio
from explainaboard.utils.typing_utils import unwrap


@@ -215,16 +216,23 @@ class F1Score(Metric):
def default_name(cls) -> str:
return 'F1'

def __init__(self, average: str = 'micro', separate_match: bool = False):
def __init__(
self,
average: str = 'micro',
separate_match: bool = False,
ignore_classes: Optional[list] = None,
):
"""Constructor for f-measure
:param average: What variety of average to measure
:param separate_match: Whether to count matches separately for true and pred.
This is useful, for example, in bucketing, when ref and pred are not aligned
:param ignore_classes: Classes to ignore
"""
self.average: str = average
self.separate_match: bool = separate_match
self._stat_mult: int = 4 if separate_match else 3
self._pred_match_offfset: int = 3 if separate_match else 2
self.ignore_classes: Optional[list] = ignore_classes
supported_averages = {'micro', 'macro'}
if average not in supported_averages:
raise ValueError(f'only {supported_averages} supported for now')
@@ -244,6 +252,9 @@ def calc_stats_from_data(self, true_data: list, pred_data: list) -> MetricStats:
(when self.separate_match=True only)
"""
id_map: dict[str, int] = {}
if self.ignore_classes is not None:
for ignore_class in self.ignore_classes:
id_map[ignore_class] = -1
for word in itertools.chain(true_data, pred_data):
if word not in id_map:
id_map[word] = len(id_map)
@@ -253,12 +264,14 @@ def calc_stats_from_data(self, true_data: list, pred_data: list) -> MetricStats:
stats = np.zeros((n_data, n_classes * self._stat_mult))
for i, (t, p) in enumerate(zip(true_data, pred_data)):
tid, pid = id_map[t], id_map[p]
stats[i, tid * self._stat_mult + 0] += 1
stats[i, pid * self._stat_mult + 1] += 1
if tid == pid:
stats[i, tid * self._stat_mult + 2] += 1
if self.separate_match:
stats[i, tid * self._stat_mult + 3] += 1
if tid != -1:
stats[i, tid * self._stat_mult + 0] += 1
if pid != -1:
stats[i, pid * self._stat_mult + 1] += 1
if tid == pid:
stats[i, tid * self._stat_mult + 2] += 1
if self.separate_match:
stats[i, tid * self._stat_mult + 3] += 1
return MetricStats(stats)

def calc_metric_from_aggregate(self, agg_stats: np.ndarray) -> float:
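
Note on the new ignore_classes handling: each ignored label is mapped to id -1 before the id map is filled, so it contributes no true/predicted/match counts to the sufficient statistics. A minimal usage sketch (the labels and values are illustrative, not taken from the repository, and assume the evaluate() interface used in test_metric.py below):

from explainaboard.metric import F1Score

true = ['PER', 'O', 'LOC', 'O', 'LOC']
pred = ['PER', 'O', 'LOC', 'LOC', 'O']

metric = F1Score(average='micro', ignore_classes=['O'])
result = metric.evaluate(true, pred, conf_value=0.05)
# With 'O' ignored: 3 true labels, 3 predicted labels, 2 position-wise matches
# -> micro precision = recall = F1 = 2/3.
print(result.value)
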
@@ -295,6 +308,58 @@ def get_metadata(self) -> dict:
return meta


class BIOF1Score(F1Score):
"""
Calculate F1 score over BIO-tagged spans.
"""

def __init__(self, average: str = 'micro'):
"""Constructor for BIO f-measure
:param average: What variety of average to measure
"""
super().__init__(average=average)

def calc_stats_from_data(
self, true_data: list[list[str]], pred_data: list[list[str]]
) -> MetricStats:
"""
Return sufficient statistics necessary to compute f-score.
:param true_data: True outputs
:param pred_data: Predicted outputs
:return: Returns stats for each class (integer id c) in the following columns of
MetricStats
* c*self._stat_mult + 0: occurrences in the true output
* c*self._stat_mult + 1: occurrences in the predicted output
* c*self._stat_mult + 2: number of matches with the true output
"""

# Identify the tag types
true_chain, pred_chain = (
itertools.chain.from_iterable(x) for x in (true_data, pred_data)
)
all_tags = set(itertools.chain(true_chain, pred_chain))
tag_ids = {
k: v for v, k in enumerate([x[2:] for x in all_tags if x.startswith('B-')])
}

# Create the sufficient statistics
n_data, n_classes = len(true_data), len(tag_ids)
# This is a bit memory inefficient if there's a large number of classes
stats = np.zeros((n_data, n_classes * self._stat_mult))

for i, (true_sent, pred_sent) in enumerate(zip(true_data, pred_data)):
true_spans, pred_spans = (
get_spans_from_bio(x) for x in (true_sent, pred_sent)
)
match_spans = [x for x in true_spans if x in pred_spans]
for offset, spans in enumerate((true_spans, pred_spans, match_spans)):
for chunk in spans:
c = tag_ids[chunk[0]]
stats[i, c * 3 + offset] += 1

return MetricStats(stats)


class Hits(Metric):
"""
Calculates the hits metric, telling whether the predicted output is in a set of true
@@ -375,7 +440,7 @@ def filter(self, indices: Union[list[int], np.ndarray]) -> MetricStats:
"""
Return a view of these stats filtered down to the indicated indices
"""
sdata: np.ndarray = unwrap(self._data)
sdata: np.ndarray = self.get_data()
if not isinstance(indices, np.ndarray):
indices = np.array(indices)
return MetricStats(sdata[indices])
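
For the new BIOF1Score added above in this file: it reuses F1Score's aggregation but computes its sufficient statistics over whole BIO spans, so a predicted span only counts as a match when its type and boundaries both agree with a true span. A small hedged sketch with made-up tags:

from explainaboard.metric import BIOF1Score

true = [['B-PER', 'I-PER', 'O', 'B-LOC']]
pred = [['B-PER', 'I-PER', 'O', 'O']]

metric = BIOF1Score(average='micro')
result = metric.evaluate(true, pred, conf_value=0.05)
# 2 true spans, 1 predicted span, 1 exact match
# -> precision = 1, recall = 1/2, micro-F1 = 2/3.
print(result.value)
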
135 changes: 59 additions & 76 deletions explainaboard/processors/named_entity_recognition.py
@@ -10,13 +10,13 @@

from explainaboard import feature
from explainaboard.info import BucketPerformance, Performance, SysOutputInfo
from explainaboard.metric import MetricStats
import explainaboard.metric
from explainaboard.metric import Metric
from explainaboard.processors.processor import Processor
from explainaboard.processors.processor_registry import register_processor
from explainaboard.tasks import TaskType
from explainaboard.utils import bucketing, eval_basic_ner
from explainaboard.utils import bucketing, span_utils
from explainaboard.utils.analysis import cap_feature
from explainaboard.utils.eval_bucket import f1_seqeval_bucket
from explainaboard.utils.py_utils import sort_dict
from explainaboard.utils.typing_utils import unwrap

@@ -206,7 +206,13 @@ def default_features(cls) -> feature.Features:

@classmethod
def default_metrics(cls) -> list[str]:
return ["f1_seqeval", "recall_seqeval", "precision_seqeval"]
return ["F1Score"]

def _get_true_label(self, data_point: dict):
return data_point["true_tags"]

def _get_predicted_label(self, data_point: dict):
return data_point["pred_tags"]

def _get_statistics_resources(
self, dataset_split: Dataset
@@ -312,13 +318,11 @@ def _get_fre_rank(self, tokens, statistics):
# --- End feature functions

# These return none because NER is not yet in the main metric interface
def _get_metrics(self, sys_info: SysOutputInfo):
return None

def _gen_metric_stats(
self, sys_info: SysOutputInfo, sys_output: list[dict]
) -> Optional[list[MetricStats]]:
return None
def _get_metrics(self, sys_info: SysOutputInfo) -> list[Metric]:
return [
getattr(explainaboard.metric, f'BIO{name}')()
for name in unwrap(sys_info.metric_names)
]
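
With default_metrics() now returning ["F1Score"], the _get_metrics override above prefixes each configured name with "BIO" before looking it up in explainaboard.metric, so the NER processor instantiates BIOF1Score for its overall scores. A sketch of that resolution, assuming sys_info.metric_names carries the default list:

import explainaboard.metric

metric_names = ["F1Score"]  # default_metrics() for this processor
metrics = [getattr(explainaboard.metric, f'BIO{name}')() for name in metric_names]
# -> [BIOF1Score(average='micro')]
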

def _complete_span_features(self, sentence, tags, statistics=None):

@@ -328,7 +332,7 @@ def _complete_span_features(self, sentence, tags, statistics=None):
efre_dic = statistics["efre_dic"] if has_stats else None

span_dics = []
chunks = eval_basic_ner.get_chunks(tags)
chunks = span_utils.get_spans_from_bio(tags)
for tag, sid, eid in chunks:
span_text = ' '.join(sentence[sid:eid])
# Basic features
@@ -389,35 +393,8 @@ def _complete_features(
dict_sysout["pred_entity_info"] = self._complete_span_features(
tokens, dict_sysout["pred_tags"], statistics=external_stats
)
return None

def get_overall_performance(
self,
sys_info: SysOutputInfo,
sys_output: list[dict],
metric_stats: Any = None,
) -> dict[str, Performance]:
"""
Get the overall performance according to metrics
:param sys_info: Information about the system output
:param sys_output: The system output itself
:return: a dictionary of metrics to overall performance numbers
"""

true_tags_list = [x['true_tags'] for x in sys_output]
pred_tags_list = [x['pred_tags'] for x in sys_output]

overall: dict[str, Performance] = {}
for metric_name in unwrap(sys_info.metric_names):
if not metric_name.endswith('_seqeval'):
raise NotImplementedError(f'Unsupported metric {metric_name}')
# This gets the appropriate metric from the eval_basic_ner package
score_func = getattr(eval_basic_ner, metric_name)
overall[metric_name] = Performance(
metric_name=metric_name,
value=score_func(true_tags_list, pred_tags_list),
)
return overall
# This is not used elsewhere, so just keep it as-is
return list()

def _get_span_ids(
self,
@@ -554,24 +531,24 @@ def get_bucket_cases_ner(
samples_over_bucket_true[bucket_interval], 'true', sample_dict
)

error_case_list = []
case_list = []
for pos, tags in sample_dict.items():
true_label = tags.get('true', 'O')
pred_label = tags.get('pred', 'O')
if true_label != pred_label:
split_pos = pos.split("|||")
sent_id = int(split_pos[0])
span = split_pos[-1]
system_output_id = sys_output[int(sent_id)]["id"]
error_case = {
"span": span,
"text": str(system_output_id),
"true_label": true_label,
"predicted_label": pred_label,
}
error_case_list.append(error_case)

return error_case_list

split_pos = pos.split("|||")
sent_id = int(split_pos[0])
span = split_pos[-1]
system_output_id = sys_output[int(sent_id)]["id"]
error_case = {
"span": span,
"text": str(system_output_id),
"true_label": true_label,
"predicted_label": pred_label,
}
case_list.append(error_case)

return case_list

def get_bucket_performance_ner(
self,
Expand All @@ -593,6 +570,12 @@ def get_bucket_performance_ner(
bucket performance
"""

metric_names = unwrap(sys_info.metric_names)
bucket_metrics = [
getattr(explainaboard.metric, name)(ignore_classes=['O'])
for name in metric_names
]

bucket_name_to_performance = {}
for bucket_interval, spans_true in samples_over_bucket_true.items():

@@ -611,29 +594,29 @@
samples_over_bucket_pred,
)

true_labels = [x['true_label'] for x in bucket_samples]
pred_labels = [x['predicted_label'] for x in bucket_samples]

bucket_performance = BucketPerformance(
bucket_name=bucket_interval,
n_samples=len(spans_pred),
bucket_samples=bucket_samples,
)
for metric_name in unwrap(sys_info.metric_names):
"""
# Note that: for NER task, the bucket-wise evaluation function is a
# little different from overall evaluation function
# for overall: f1_seqeval
# for bucket: f1_seqeval_bucket
"""
f1, p, r = f1_seqeval_bucket(spans_pred, spans_true)
if metric_name == 'f1_seqeval':
my_score = f1
elif metric_name == 'precision_seqeval':
my_score = p
elif metric_name == 'recall_seqeval':
my_score = r
else:
raise NotImplementedError(f'Unsupported metric {metric_name}')
# TODO(gneubig): It'd be better to have significance tests here
performance = Performance(metric_name=metric_name, value=my_score)
for metric in bucket_metrics:

metric_val = metric.evaluate(
true_labels, pred_labels, conf_value=sys_info.conf_value
)
conf_low, conf_high = (
metric_val.conf_interval if metric_val.conf_interval else (None, None)
)
performance = Performance(
metric_name=metric.name,
value=metric_val.value,
confidence_score_low=conf_low,
confidence_score_high=conf_high,
)
bucket_performance.performances.append(performance)

bucket_name_to_performance[bucket_interval] = bucket_performance
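
In the bucket-level evaluation above, get_bucket_cases_ner pads spans that appear on only one side with the label 'O', so the bucket metrics are built with ignore_classes=['O'] to keep that placeholder from being scored as a real class. A hedged sketch of the effect, with made-up bucket labels:

from explainaboard.metric import F1Score

true_labels = ['PER', 'LOC', 'O']    # trailing 'O': span predicted but absent from the gold spans
pred_labels = ['PER', 'O', 'ORG']    # middle 'O': gold span that was not predicted

metric = F1Score(ignore_classes=['O'])
result = metric.evaluate(true_labels, pred_labels, conf_value=0.05)
# 2 true spans, 2 predicted spans, 1 match -> micro-F1 = 1/2.
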
Expand All @@ -647,7 +630,7 @@ def get_econ_dic(train_word_sequences, tag_sequences_train, tags):
Note: when matching, the text span and tag have been lowercased.
"""
econ_dic = dict()
chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))

# print('tags: ', tags)
count_idx = 0
@@ -722,7 +705,7 @@
# Global functions for training set dependent features
def get_efre_dic(train_word_sequences, tag_sequences_train):
efre_dic = dict()
chunks_train = set(eval_basic_ner.get_chunks(tag_sequences_train))
chunks_train = set(span_utils.get_spans_from_bio(tag_sequences_train))
count_idx = 0
word_sequences_train_str = ' '.join(train_word_sequences).lower()
for true_chunk in tqdm(chunks_train):
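
The eval_basic_ner.get_chunks calls in this file are replaced by span_utils.get_spans_from_bio. Judging from how its result is consumed above (for tag, sid, eid in chunks, and ' '.join(sentence[sid:eid])), it appears to yield (tag, start, end) tuples with an exclusive end index; a hedged illustration:

from explainaboard.utils.span_utils import get_spans_from_bio

tags = ['B-PER', 'I-PER', 'O', 'B-LOC']
# Presumed output, inferred from the usage above: [('PER', 0, 2), ('LOC', 3, 4)]
for tag, start, end in get_spans_from_bio(tags):
    print(tag, start, end)
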
5 changes: 1 addition & 4 deletions explainaboard/processors/processor.py
@@ -164,10 +164,7 @@ def _get_feature_func(self, func_name: str):
def _get_eaas_client(self):
if not self._eaas_client:
self._eaas_config = Config()
self._eaas_client = AsyncEaaSClient()
self._eaas_client.load_config(
self._eaas_config
) # The config you have created above
self._eaas_client = AsyncEaaSClient(self._eaas_config)
return self._eaas_client

def _get_true_label(self, data_point: dict):
22 changes: 20 additions & 2 deletions explainaboard/tests/test_metric.py
@@ -55,6 +55,25 @@ def test_mrr(self):
result = metric.evaluate(true, pred, conf_value=0.05)
self.assertAlmostEqual(result.value, 2.5 / 6.0)

def test_ner_f1(self):

true = [
['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'O', 'O'],
['B-PER', 'I-PER', 'O'],
]
pred = [
['O', 'O', 'B-MISC', 'I-MISC', 'B-MISC', 'I-MISC', 'O'],
['B-PER', 'I-PER', 'O'],
]

metric = explainaboard.metric.BIOF1Score(average='micro')
result = metric.evaluate(true, pred, conf_value=0.05)
self.assertAlmostEqual(result.value, 2.0 / 3.0)

metric = explainaboard.metric.BIOF1Score(average='macro')
result = metric.evaluate(true, pred, conf_value=0.05)
self.assertAlmostEqual(result.value, 3.0 / 4.0)

def _get_eaas_request(
self,
sys_output: list[dict],
@@ -92,8 +111,7 @@ def test_eaas_decomposabiltiy(self):
sys_output = list(loader.load())

# Initialize client and decide which metrics to test
eaas_client = AsyncEaaSClient()
eaas_client.load_config(Config())
eaas_client = AsyncEaaSClient(Config())
metric_names = ['rouge1', 'bleu', 'chrf']
# Uncomment the following line to test all metrics,
# but beware that it will be very slow
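
For reference on the expected values in test_ner_f1 above: the gold tags contain three spans (two MISC and one PER) and the predictions also contain three spans, of which two match exactly (the first MISC span and the PER span). Micro-averaging therefore gives precision = recall = F1 = 2/3. Per class, MISC has one match out of two gold and two predicted spans (F1 = 1/2) while PER is perfect (F1 = 1), so the macro average is (1/2 + 1) / 2 = 3/4.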
