Skip to content

Commit

Permalink
Replace uses of jnp.array in types with jnp.ndarray.
Browse files Browse the repository at this point in the history
`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.

PiperOrigin-RevId: 559395698
  • Loading branch information
hawkinsp authored and Scenic Authors committed Aug 23, 2023
1 parent 50be9b0 commit 32b3b10
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions scenic/projects/baselines/detr/detr_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,9 @@ def compute_cost_matrix(self, predictions: ArrayDict,
"""
raise NotImplementedError('Subclasses must implement compute_cost_matrix.')

def matcher(self,
cost: jnp.ndarray,
n_cols: Optional[jnp.array] = None) -> jnp.ndarray:
def matcher(
self, cost: jnp.ndarray, n_cols: Optional[jnp.ndarray] = None
) -> jnp.ndarray:
"""Implements a matching function.
Matching function matches output detections against ground truth detections
Expand Down

0 comments on commit 32b3b10

Please sign in to comment.