
Commit

No change.
PiperOrigin-RevId: 631403898
Scenic Authors committed May 7, 2024
1 parent 3d9a2a5 commit e9faec1
Showing 3 changed files with 18 additions and 12 deletions.
20 changes: 13 additions & 7 deletions scenic/projects/vivit/model.py
@@ -916,6 +916,9 @@ def get_metrics_fn(self, split: Optional[str] = None) -> base_model.MetricFn:
     """
     del split  # for all splits, we return the same metric functions
 
+    num_classes_in_each_head = (
+        self.dataset_meta_data.get('class_splits', [-1]))
+    minimal_num_classes = min(num_classes_in_each_head)
     def classification_metrics_function(logits, batch, metrics, class_splits,
                                         split_names):
 
@@ -947,18 +950,21 @@ def classification_metrics_function(logits, batch, metrics, class_splits,
             (model_utils.joint_accuracy(logits, one_hot_targets, class_splits,
                                         weights),
              base_model_utils.num_examples(logits, one_hot_targets, weights)))
-        pairwise_top_five = base_model_utils.psum_metric_normalizer(
-            (model_utils.joint_top_k(
-                logits, one_hot_targets, class_splits, k=5, weights=weights),
-             base_model_utils.num_examples(logits, one_hot_targets, weights)))
         eval_name = f'{split_names[0]}-{split_names[1]}'
         evaluated_metrics[f'{eval_name}_accuracy'] = pairwise_acc
-        evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
+        if minimal_num_classes > 5:
+          pairwise_top_five = base_model_utils.psum_metric_normalizer(
+              (model_utils.joint_top_k(
+                  logits, one_hot_targets, class_splits, k=5, weights=weights),
+               base_model_utils.num_examples(logits, one_hot_targets, weights)))
+          evaluated_metrics[f'{eval_name}_accuracy_top_5'] = pairwise_top_five
 
       return evaluated_metrics
 
+    metrics = ViViT_CLASSIFICATION_METRICS
+    if minimal_num_classes <= 5:
+      metrics = ViViT_CLASSIFICATION_METRICS_BASIC
     return functools.partial(
         classification_metrics_function,
-        metrics=ViViT_CLASSIFICATION_METRICS,
+        metrics=metrics,
         class_splits=self.class_splits,
         split_names=self.split_names)
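
Note on the model.py change above: top-5 accuracy is now reported only when every classification head has more than five classes. With five or fewer classes the top-5 set necessarily contains the true label, so the metric is constantly 1.0 and carries no information; in that case the diff also switches to ViViT_CLASSIFICATION_METRICS_BASIC, which by its name is a metric set without the top-5 entries. A minimal sketch of the degenerate case (a hypothetical standalone helper, not Scenic's joint_top_k):

    import jax.numpy as jnp

    def top_k_accuracy(logits, labels, k):
      # A label counts as a hit when it appears among the k largest logits.
      top_k = jnp.argsort(logits, axis=-1)[:, -k:]
      return jnp.mean(jnp.any(top_k == labels[:, None], axis=-1))

    logits = jnp.array([[0.1, 0.9, 0.3, 0.2, 0.5]])  # one example, 5 classes
    labels = jnp.array([3])                          # the top-1 prediction is wrong
    print(top_k_accuracy(logits, labels, k=1))       # 0.0
    print(top_k_accuracy(logits, labels, k=5))       # 1.0 whenever k >= num_classes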
1 change: 0 additions & 1 deletion scenic/projects/vivit/model_utils.py
@@ -27,7 +27,6 @@
 from scenic.common_lib import debug_utils
 from scenic.model_lib.base_models import model_utils as base_model_utils
 import scipy
-flax.config.update('flax_return_frozendict', True)
 
 
 def reshape_to_1d_factorized(x: jnp.ndarray, axis: int):
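
Note on the model_utils.py change: the deleted line mutated global Flax configuration as a side effect of merely importing the module. The trainer.py diff below re-adds the same call inside train(), so the flag is flipped only when training actually runs. A minimal sketch of the two patterns (hypothetical module names):

    # before_utils.py -- importing this module silently flips a process-wide flag:
    import flax
    flax.config.update('flax_return_frozendict', True)  # runs at import time

    # after_trainer.py -- the flag is set only when the entry point executes:
    import flax

    def train():
      flax.config.update('flax_return_frozendict', True)  # scoped to the training run
      ...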
9 changes: 5 additions & 4 deletions scenic/projects/vivit/trainer.py
@@ -21,6 +21,7 @@
 from absl import logging
 from clu import metric_writers
 from clu import periodic_actions
+import flax
 from flax import jax_utils
 import jax
 import jax.numpy as jnp
@@ -66,6 +67,7 @@ def train(
     and eval_summary which are dicts of metrics. These outputs are used for
     regression testing.
   """
+  flax.config.update('flax_return_frozendict', True)
   lead_host = jax.process_index() == 0
   # Build the loss_fn, metrics, and flax_model.
   model = model_cls(config, dataset.meta_data)
@@ -110,8 +112,9 @@
       restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
           init_checkpoint_path, train_state, assert_exist=True)
     elif checkpoint_format == 'big_vision':
-      restored_train_state = pretrain_utils.convert_big_vision_to_scenic_checkpoint(
-          init_checkpoint_path, train_state)
+      restored_train_state = (
+          pretrain_utils.convert_big_vision_to_scenic_checkpoint(
+              init_checkpoint_path, train_state))
       # Config dict in big_vision is not in the same format as scenic.
       # Therefore, make sure the config matches the config of the loaded model!
       restored_model_cfg = copy.deepcopy(config)
@@ -132,7 +135,6 @@
   # Replicate the optimizer, state, and rng.
   train_state = jax_utils.replicate(train_state)
   del params  # Do not keep a copy of the initial params.
-
   # Calculate the total number of training steps.
   total_steps, steps_per_epoch = train_utils.get_num_training_steps(
       config, dataset.meta_data)
@@ -241,7 +243,6 @@
       do_memory_defrag = True
   except RuntimeError:
     logging.warn('Memory defragmentation not possible, use the tfrt runtime')
-
   for step in range(start_step + 1, total_steps + 1):
     with jax.profiler.StepTraceAnnotation('train', step_num=step):
       train_batch = next(dataset.train_iter)
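
For context on the flag itself: when flax_return_frozendict is True, Flax module initialization returns the legacy immutable FrozenDict; when it is False (the default in recent Flax releases), it returns a plain dict. Pinning the legacy behavior before the train state is built and restored is presumably what keeps the checkpoint-handling code above working. A minimal sketch, assuming a current Flax install:

    import flax
    import flax.linen as nn
    import jax
    import jax.numpy as jnp

    flax.config.update('flax_return_frozendict', True)
    params = nn.Dense(4).init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    print(type(params))  # flax.core.frozen_dict.FrozenDict

    flax.config.update('flax_return_frozendict', False)
    params = nn.Dense(4).init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
    print(type(params))  # dict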
