Skip to content

Commit

Permalink
Fixes vit_jax_augreg Colab.
Browse files Browse the repository at this point in the history
1. Flax introduced a new format for GroupNorm weights in checkpoints (google/flax#1721) which is fixed with the added `_fix_groupnorm()`.

2. It was reported in #249 that passing a list as argument does not work anymore. This has been fixed by converting the EagerTensor to a numpy array and adding the batch dimension via fancy indexing.

PiperOrigin-RevId: 487740672
  • Loading branch information
andsteing authored and copybara-github committed Nov 11, 2022
1 parent 60104c1 commit 0d80400
Show file tree
Hide file tree
Showing 2 changed files with 781 additions and 779 deletions.
2 changes: 2 additions & 0 deletions vit_jax/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def load(path):
params = checkpoints.convert_pre_linen(recover_tree(keys, values))
if isinstance(params, flax.core.FrozenDict):
params = params.unfreeze()
if version.parse(flax.__version__) >= version.parse('0.3.6'):
params = _fix_groupnorm(params)
return params


Expand Down
Loading

0 comments on commit 0d80400

Please sign in to comment.