Skip to content

Commit

Permalink
Add param_dtype to MlpBlock to support parameters in bf16.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 599337272
  • Loading branch information
jpuigcerver authored and copybara-github committed Jan 18, 2024
1 parent 10ffdeb commit 65130d2
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions vit_jax/models_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ class MlpBlock(nn.Module):

mlp_dim: int
dtype: Dtype = jnp.float32
param_dtype: Dtype = jnp.float32
out_dim: Optional[int] = None
dropout_rate: float = 0.1
kernel_init: Callable[[PRNGKey, Shape, Dtype],
Expand All @@ -80,6 +81,7 @@ def __call__(self, inputs, *, deterministic):
x = nn.Dense(
features=self.mlp_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init)( # pytype: disable=wrong-arg-types
inputs)
Expand All @@ -88,6 +90,7 @@ def __call__(self, inputs, *, deterministic):
output = nn.Dense(
features=actual_out_dim,
dtype=self.dtype,
param_dtype=self.param_dtype,
kernel_init=self.kernel_init,
bias_init=self.bias_init)( # pytype: disable=wrong-arg-types
x)
Expand Down

0 comments on commit 65130d2

Please sign in to comment.