From de35ae0f6a6b3334ab970b79f1c49d222eff12de Mon Sep 17 00:00:00 2001 From: JAXopt authors Date: Mon, 9 Sep 2024 10:41:00 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 672595193 --- jaxopt/_src/gradient_descent.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/jaxopt/_src/gradient_descent.py b/jaxopt/_src/gradient_descent.py index bbaf100f..51f461c8 100644 --- a/jaxopt/_src/gradient_descent.py +++ b/jaxopt/_src/gradient_descent.py @@ -79,11 +79,9 @@ def init_state(self, """ return super().init_state(init_params, None, *args, **kwargs) - def update(self, - params: Any, - state: NamedTuple, - *args, - **kwargs) -> base.OptStep: + def update( + self, params: Any, state: ProxGradState, *args, **kwargs + ) -> base.OptStep: """Performs one iteration of gradient descent. Args: