Skip to content

Commit

Permalink
update according to ActionDataSample
Browse files Browse the repository at this point in the history
  • Loading branch information
cir7 committed Sep 6, 2023
1 parent 2770dbe commit fb082c0
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
4 changes: 2 additions & 2 deletions mmaction/evaluation/metrics/multimodal_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def process(self, data_batch, data_samples):
"""
for sample in data_samples:
# gt_labels in datasample is a LabelData
label = sample['gt_labels']['item'].item()
label = sample['gt_label'].item()
result = {
'pred_label': sample.get('pred_label'),
'gt_label': label,
Expand Down Expand Up @@ -428,7 +428,7 @@ def process(self, data_batch: Sequence[dict],
predictions (Sequence[dict]): A batch of outputs from the model.
"""
for data_sample in data_samples:
pred_score = data_sample['pred_scores']['item'].cpu()
pred_score = data_sample['pred_score'].cpu()
gt_label = format_label(data_sample['gt_label'])

if 'gt_score' in data_sample:
Expand Down
7 changes: 3 additions & 4 deletions tests/evaluation/metrics/test_retrieval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,16 @@ class TestRetrievalRecall(TestCase):
def test_evaluate(self):
"""Test using the metric in the same way as Evalutor."""
pred = [
ActionDataSample().set_pred_score(i).to_dict() for i in [
ActionDataSample().set_pred_score(i).set_gt_label(k).to_dict()
for i, k in zip([
torch.tensor([0.7, 0.0, 0.3]),
torch.tensor([0.5, 0.2, 0.3]),
torch.tensor([0.4, 0.5, 0.1]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
torch.tensor([0.0, 0.0, 1.0]),
]
], [[0], [0], [1], [2], [2], [0]])
]
for sample, label in zip(pred, [[0], [0], [1], [2], [2], [0]]):
sample['gt_label'] = label

# Test with score (use score instead of label if score exists)
metric = METRICS.build(dict(type='RetrievalRecall', topk=1))
Expand Down

0 comments on commit fb082c0

Please sign in to comment.