diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index b00020dd..c980168d 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -31,7 +31,9 @@ class IdentityPreprocessor(Preprocessor): This is the default preprocessor used.""" def preprocess(self, states: States) -> TT["batch_shape", "input_dim"]: - return states.tensor.float() + return ( + states.tensor.float() + ) # TODO: should we typecast here? not a true identity... class EnumPreprocessor(Preprocessor):