Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
1. Flax introduced a new format for GroupNorm weights in checkpoints (google/flax#1721) which is fixed with the added `_fix_groupnorm()`. 2. It was reported in #249 that passing a list as argument does not work anymore. This has been fixed by converting the EagerTensor to a numpy array and adding the batch dimension via fancy indexing. PiperOrigin-RevId: 487760232
- Loading branch information