diff --git a/scenic/projects/vivit/model.py b/scenic/projects/vivit/model.py
index b1a8ef7f..f25b7053 100644
--- a/scenic/projects/vivit/model.py
+++ b/scenic/projects/vivit/model.py
@@ -916,6 +916,9 @@ def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
     """
     del split  # for all splits, we return the same metric functions
+    num_classes_in_each_head = (
+        self.dataset_meta_data.get('class_splits', [-1]))
+    minimal_num_classes = min(num_classes_in_each_head)
 
     def classification_metrics_function(logits, batch, metrics, class_splits,
                                         split_names):
 
@@ -947,18 +950,21 @@ def classification_metrics_function(logits, batch, metrics, class_splits,
             (model_utils.joint_accuracy(logits, one_hot_targets, class_splits,
                                         weights),
              base_model_utils.num_examples(logits, one_hot_targets, weights)))
-        pairwise_top_five = base_model_utils.psum_metric_normalizer(
-            (model_utils.joint_top_k(
-                logits, one_hot_targets, class_splits, k=5, weights=weights),
-             base_model_utils.num_examples(logits, one_hot_targets, weights)))
         eval_name = f'{split_names[0]}-{split_names[1]}'
         evaluated_metrics[f'{eval_name}_accuracy'] = pairwise_acc
-        evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
+        if minimal_num_classes > 5:
+          pairwise_top_five = base_model_utils.psum_metric_normalizer(
+              (model_utils.joint_top_k(
+                  logits, one_hot_targets, class_splits, k=5, weights=weights),
+               base_model_utils.num_examples(logits, one_hot_targets, weights)))
+          evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
 
       return evaluated_metrics
-
+    metrics = ViViT_CLASSIFICATION_METRICS
+    if minimal_num_classes <= 5:
+      metrics = ViViT_CLASSIFICATION_METRICS_BASIC
     return functools.partial(
         classification_metrics_function,
-        metrics=ViViT_CLASSIFICATION_METRICS,
+        metrics=metrics,
         class_splits=self.class_splits,
         split_names=self.split_names)
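
Note on the guard: top-k accuracy is only meaningful when a head predicts more than k classes. If the smallest head has at most 5 classes, the true label is always among the top 5 logits, so `accuracy_top_5` is pinned at 100% and adds only noise to the logs. The sketch below illustrates the rationale with plain NumPy and a hypothetical `top_k_accuracy` helper; it is not Scenic's `model_utils.joint_top_k`, which operates on one-hot targets and class splits:

```python
import numpy as np


def top_k_accuracy(logits: np.ndarray, labels: np.ndarray, k: int = 5) -> float:
  """Fraction of examples whose true label is among the k largest logits."""
  # Slicing with [-k:] degrades gracefully when k >= num_classes: it keeps
  # every class, which is exactly the case where the metric becomes vacuous.
  top_k = np.argsort(logits, axis=-1)[:, -k:]
  hits = [label in candidates for label, candidates in zip(labels, top_k)]
  return float(np.mean(hits))


rng = np.random.default_rng(0)

# A head with 4 classes: the true label is always within the top 5, so the
# metric is pinned at 1.0 regardless of model quality.
logits_small = rng.normal(size=(128, 4))
labels_small = rng.integers(0, 4, size=128)
assert top_k_accuracy(logits_small, labels_small, k=5) == 1.0

# A head with 100 classes: top-5 accuracy is informative (about 0.05 for
# random logits, since 5 out of 100 classes are hit by chance 5% of the time).
logits_large = rng.normal(size=(128, 100))
labels_large = rng.integers(0, 100, size=128)
print(top_k_accuracy(logits_large, labels_large, k=5))
```

The same reasoning explains the metrics-dict swap at the end of the hunk: `ViViT_CLASSIFICATION_METRICS_BASIC` (not shown in this diff) is presumably `ViViT_CLASSIFICATION_METRICS` without the `accuracy_top_5` entry, so the per-head metric loop also skips the degenerate metric when any head has 5 or fewer classes.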