From 35908de56ff385108e52e4bfd0858d5c88bef72b Mon Sep 17 00:00:00 2001 From: Joseph Viviano Date: Fri, 16 Feb 2024 21:15:24 -0500 Subject: [PATCH] no longer automatically casting tensors which causes unexplained errors if the environment isn't initalized with the correct device_str --- src/gfn/utils/modules.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 2ffbf54a..9820fa05 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -54,7 +54,6 @@ def __init__( else: self.torso = torso self.last_layer = nn.Linear(self.torso.hidden_dim, output_dim) - self.device = None def forward( self, preprocessed_states: TT["batch_shape", "input_dim", float] @@ -66,11 +65,6 @@ def forward( ingestion by the MLP. Returns: out, a set of continuous variables. """ - if self.device is None: - self.device = preprocessed_states.device - self.to( - self.device - ) # TODO: This is maybe fine but could result in weird errors if the model keeps bouncing between devices. out = self.torso(preprocessed_states) out = self.last_layer(out) return out