Skip to content

Commit

Permalink
[JAX] Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation.

Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Change uses of `jnp.array` to `jnp.ndarray`.

PiperOrigin-RevId: 555234490
  • Loading branch information
hawkinsp authored and Scenic Authors committed Aug 11, 2023
1 parent d579ccb commit 69b46e3
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions scenic/projects/mbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,10 +570,12 @@ def init_from_train_state(self,
restored_model_cfg,
restore_output_proj)

def loss_function(self,
logits: jnp.array,
batch: base_model.Batch,
model_params: Optional[jnp.array] = None) -> float:
def loss_function(
self,
logits: jnp.ndarray,
batch: base_model.Batch,
model_params: Optional[jnp.ndarray] = None,
) -> float:
"""Returns sigmoid cross entropy loss with an L2 penalty on the weights.
Args:
Expand Down

0 comments on commit 69b46e3

Please sign in to comment.