diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 22457a03e..260bd9323 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -3,6 +3,7 @@ import dataclasses import enum import functools +from collections.abc import Iterator import jax import jax.numpy as jnp @@ -43,7 +44,7 @@ class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): @contextlib.contextmanager def switch_velocity_representation( self, velocity_representation: VelRepr - ) -> contextlib.AbstractContextManager[Self]: + ) -> Iterator[Self]: """ Context manager to temporarily switch the velocity representation.