diff --git a/hannah/nas/performance_prediction/features/dataset.py b/hannah/nas/performance_prediction/features/dataset.py index e976f162..318fe1f8 100644 --- a/hannah/nas/performance_prediction/features/dataset.py +++ b/hannah/nas/performance_prediction/features/dataset.py @@ -155,7 +155,7 @@ def get_features(nx_graph): if col not in df.columns: df[col] = 0 df = df.reindex(sorted(df.columns), axis=1) # Sort to have consistency - return df + return df.astype(np.float32) def get_list_columns(df): list_cols = []