diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 2ffbf54a..9820fa05 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -54,7 +54,6 @@ def __init__( else: self.torso = torso self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim) - self.device = None def forward( self, preprocessed_states: TT["batch_shape", "input_dim", float] @@ -66,11 +65,6 @@ def forward( ingestion by the MLP. Returns: out, a set of continuous variables. """ - if self.device is None: - self.device = preprocessed_states.device - self.to( - self.device - ) # TODO: This is maybe fine but could result in weird errors if the model keeps bouncing between devices. out = self.torso(preprocessed_states) out = self.last_layer(out) return out