From 046e59c4a94fa7ec30de67719ea50b3e6b8548f7 Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Wed, 17 Jan 2024 17:30:13 -0800 Subject: [PATCH] Add param_dtype to MlpBlock to support parameters in bf16. PiperOrigin-RevId: 599337272 --- vit_jax/models_vit.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vit_jax/models_vit.py b/vit_jax/models_vit.py index f128a2c..d80700c 100644 --- a/vit_jax/models_vit.py +++ b/vit_jax/models_vit.py @@ -66,6 +66,7 @@ class MlpBlock(nn.Module): mlp_dim: int dtype: Dtype = jnp.float32 + param_dtype: Dtype = jnp.float32 out_dim: Optional[int] = None dropout_rate: float = 0.1 kernel_init: Callable[[PRNGKey, Shape, Dtype], @@ -80,6 +81,7 @@ def __call__(self, inputs, *, deterministic): x = nn.Dense( features=self.mlp_dim, dtype=self.dtype, + param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init)( # pytype: disable=wrong-arg-types inputs) @@ -88,6 +90,7 @@ def __call__(self, inputs, *, deterministic): output = nn.Dense( features=actual_out_dim, dtype=self.dtype, + param_dtype=self.param_dtype, kernel_init=self.kernel_init, bias_init=self.bias_init)( # pytype: disable=wrong-arg-types x)