Skip to content

Commit

Permalink
type hinting
Browse files Browse the repository at this point in the history
  • Loading branch information
beardyFace committed Dec 11, 2024
1 parent aa67a7a commit 7e84c15
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions cares_reinforcement_learning/util/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def __init__(

def forward( # type: ignore
self, state: dict[str, torch.Tensor], detach_encoder: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# Detach at the CNN layer to prevent backpropagation through the encoder
state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)

Expand Down Expand Up @@ -382,7 +382,7 @@ def forward(
state: dict[str, torch.Tensor],
action: torch.Tensor,
detach_encoder: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
# Detach at the CNN layer to prevent backpropagation through the encoder
state_latent = self.encoder(state["image"], detach_cnn=detach_encoder)

Expand Down Expand Up @@ -412,7 +412,7 @@ def __init__(

def forward(
self, state: dict[str, torch.Tensor], detach_encoder: bool = False
) -> torch.Tensor:
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# NaSATD3 detatches the encoder at the output
if self.autoencoder.ae_type == Autoencoders.BURGESS:
# take the mean value for stability
Expand Down

0 comments on commit 7e84c15

Please sign in to comment.