Skip to content

Commit

Permalink
no longer automatically casting tensors which causes unexplained erro…
Browse files Browse the repository at this point in the history
…rs if the environment isn't initalized with the correct device_str
  • Loading branch information
josephdviviano committed Feb 17, 2024
1 parent f41a31f commit 35908de
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions src/gfn/utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ def __init__(
else:
self.torso = torso
self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim)
self.device = None

def forward(
self, preprocessed_states: TT["batch_shape", "input_dim", float]
Expand All @@ -66,11 +65,6 @@ def forward(
ingestion by the MLP.
Returns: out, a set of continuous variables.
"""
if self.device is None:
self.device = preprocessed_states.device
self.to(
self.device
) # TODO: This is maybe fine but could result in weird errors if the model keeps bouncing between devices.
out = self.torso(preprocessed_states)
out = self.last_layer(out)
return out
Expand Down

0 comments on commit 35908de

Please sign in to comment.