diff --git a/scenic/projects/baselines/detr/detr_base_model.py b/scenic/projects/baselines/detr/detr_base_model.py index bdb668b11..9f7b62eb2 100644 --- a/scenic/projects/baselines/detr/detr_base_model.py +++ b/scenic/projects/baselines/detr/detr_base_model.py @@ -160,9 +160,9 @@ def compute_cost_matrix(self, predictions: ArrayDict, """ raise NotImplementedError('Subclasses must implement compute_cost_matrix.') - def matcher(self, - cost: jnp.ndarray, - n_cols: Optional[jnp.array] = None) -> jnp.ndarray: + def matcher( + self, cost: jnp.ndarray, n_cols: Optional[jnp.ndarray] = None + ) -> jnp.ndarray: """Implements a matching function. Matching function matches output detections against ground truth detections