Skip to content

Commit

Permalink
ViViTMultiHeadClassificationModel - Add support for tasks with less t…
Browse files Browse the repository at this point in the history
…han 5 classes.

PiperOrigin-RevId: 631403898
  • Loading branch information
Scenic Authors committed May 7, 2024
1 parent 12f7e36 commit e027023
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions scenic/projects/vivit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,9 @@ def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
"""
del split # for all splits, we return the same metric functions

num_classes_in_each_head = (
self.dataset_meta_data.get('class_splits', [-1]))
minimal_num_classes = min(num_classes_in_each_head)
def classification_metrics_function(logits, batch, metrics, class_splits,
split_names):

Expand Down Expand Up @@ -947,18 +950,21 @@ def classification_metrics_function(logits, batch, metrics, class_splits,
(model_utils.joint_accuracy(logits, one_hot_targets, class_splits,
weights),
base_model_utils.num_examples(logits, one_hot_targets, weights)))
pairwise_top_five = base_model_utils.psum_metric_normalizer(
(model_utils.joint_top_k(
logits, one_hot_targets, class_splits, k=5, weights=weights),
base_model_utils.num_examples(logits, one_hot_targets, weights)))
eval_name = f'{split_names[0]}-{split_names[1]}'
evaluated_metrics[f'{eval_name}_accuracy'] = pairwise_acc
evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
if minimal_num_classes > 5:
pairwise_top_five = base_model_utils.psum_metric_normalizer(
(model_utils.joint_top_k(
logits, one_hot_targets, class_splits, k=5, weights=weights),
base_model_utils.num_examples(logits, one_hot_targets, weights)))
evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five

return evaluated_metrics

metrics = ViViT_CLASSIFICATION_METRICS
if minimal_num_classes <= 5:
metrics = ViViT_CLASSIFICATION_METRICS_BASIC
return functools.partial(
classification_metrics_function,
metrics=ViViT_CLASSIFICATION_METRICS,
metrics=metrics,
class_splits=self.class_splits,
split_names=self.split_names)

0 comments on commit e027023

Please sign in to comment.