Skip to content

Commit

Permalink
Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.

PiperOrigin-RevId: 559406793
  • Loading branch information
hawkinsp authored and Scenic Authors committed Aug 23, 2023
1 parent 66e58c6 commit 93a3f79
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 11 deletions.
2 changes: 1 addition & 1 deletion scenic/model_lib/layers/masked_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ class GroupNorm(nn.Module):
def __call__(
self,
x: jnp.ndarray,
spatial_shape: Optional[jnp.array] = None,
spatial_shape: Optional[jnp.ndarray] = None,
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Applies group normalization to the input (arxiv.org/abs/1803.08494).
Expand Down
14 changes: 9 additions & 5 deletions scenic/projects/mtv/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,11 @@ def eval_step(
return_logits_and_labels: bool = False,
return_confusion_matrix: bool = False,
debug: Optional[bool] = False
) -> Union[Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, jnp.array],
Tuple[Dict[str, Tuple[float, int]], jnp.ndarray],
Dict[str, Tuple[float, int]]]:
) -> Union[
Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, jnp.ndarray],
Tuple[Dict[str, Tuple[float, int]], jnp.ndarray],
Dict[str, Tuple[float, int]],
]:
"""Runs a single step of training.
Note that in this code, the buffer of the second argument (batch) is donated
Expand Down Expand Up @@ -214,8 +216,10 @@ def test_step(
return_logits_and_labels: bool = False,
softmax_logits: bool = False,
debug: bool = False
) -> Union[Dict[str, Tuple[float, int]], Tuple[Dict[str, Tuple[float, int]],
jnp.array, jnp.array]]:
) -> Union[
Dict[str, Tuple[float, int]],
Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, jnp.ndarray],
]:
"""Runs a single step of testing.
For multi-crop testing, we assume that num_crops consecutive entries in the
Expand Down
6 changes: 4 additions & 2 deletions scenic/projects/objectvivit/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,10 @@ def test_step(
debug: bool = False,
learn_token_score: bool = False,
add_boxes: bool = False,
) -> Union[Dict[str, Tuple[float, int]], Tuple[Dict[str, Tuple[float, int]],
jnp.array, jnp.array]]:
) -> Union[
Dict[str, Tuple[float, int]],
Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, jnp.ndarray],
]:
"""Runs a single step of testing.
For multi-crop testing, we assume that num_crops consecutive entries in the
Expand Down
6 changes: 3 additions & 3 deletions scenic/projects/owl_vit/matching_base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def compute_cost_matrix(self, predictions: ArrayDict,
"""
...

def matcher(self,
cost: jnp.ndarray,
n_cols: Optional[jnp.array] = None) -> jnp.ndarray:
def matcher(
self, cost: jnp.ndarray, n_cols: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
"""Implements a matching function.
Matching functions match predicted detections against ground truth
Expand Down

0 comments on commit 93a3f79

Please sign in to comment.