diff --git a/cares_reinforcement_learning/util/common.py b/cares_reinforcement_learning/util/common.py index 6cd9960..3fe0af7 100644 --- a/cares_reinforcement_learning/util/common.py +++ b/cares_reinforcement_learning/util/common.py @@ -350,7 +350,7 @@ def __init__( def forward( # type: ignore self, state: dict[str, torch.Tensor], detach_encoder: bool = False - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Detach at the CNN layer to prevent backpropagation through the encoder state_latent = self.encoder(state["image"], detach_cnn=detach_encoder) @@ -382,7 +382,7 @@ def forward( state: dict[str, torch.Tensor], action: torch.Tensor, detach_encoder: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: # Detach at the CNN layer to prevent backpropagation through the encoder state_latent = self.encoder(state["image"], detach_cnn=detach_encoder) @@ -412,7 +412,7 @@ def __init__( def forward( self, state: dict[str, torch.Tensor], detach_encoder: bool = False - ) -> torch.Tensor: + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NaSATD3 detatches the encoder at the output if self.autoencoder.ae_type == Autoencoders.BURGESS: # take the mean value for stability