diff --git a/scenic/projects/mbt/model.py b/scenic/projects/mbt/model.py index 867141a32..0625441f6 100644 --- a/scenic/projects/mbt/model.py +++ b/scenic/projects/mbt/model.py @@ -570,10 +570,12 @@ def init_from_train_state(self, restored_model_cfg, restore_output_proj) - def loss_function(self, - logits: jnp.array, - batch: base_model.Batch, - model_params: Optional[jnp.array] = None) -> float: + def loss_function( + self, + logits: jnp.ndarray, + batch: base_model.Batch, + model_params: Optional[jnp.ndarray] = None, + ) -> float: """Returns sigmoid cross entropy loss with an L2 penalty on the weights. Args: