From 32b3b102fe5ec58d082b1f5651ff54bac2aa346d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 23 Aug 2023 05:34:14 -0700 Subject: [PATCH] Replace uses of `jnp.array` in types with `jnp.ndarray`. `jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. PiperOrigin-RevId: 559395698 --- scenic/projects/baselines/detr/detr_base_model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scenic/projects/baselines/detr/detr_base_model.py b/scenic/projects/baselines/detr/detr_base_model.py index bdb668b11..9f7b62eb2 100644 --- a/scenic/projects/baselines/detr/detr_base_model.py +++ b/scenic/projects/baselines/detr/detr_base_model.py @@ -160,9 +160,9 @@ def compute_cost_matrix(self, predictions: ArrayDict, """ raise NotImplementedError('Subclasses must implement compute_cost_matrix.') - def matcher(self, - cost: jnp.ndarray, - n_cols: Optional[jnp.array] = None) -> jnp.ndarray: + def matcher( + self, cost: jnp.ndarray, n_cols: Optional[jnp.ndarray] = None + ) -> jnp.ndarray: """Implements a matching function. Matching function matches output detections against ground truth detections