diff --git a/datasets.py b/datasets.py index f31ccec..4421f1b 100644 --- a/datasets.py +++ b/datasets.py @@ -137,8 +137,9 @@ def prepare_bosch(dataset_folder, nrows): os.system("kaggle competitions download -c bosch-production-line-performance -f " + filename + " -p " + dataset_folder) - X = pd.read_csv(local_url, index_col=0, compression='zip', dtype=np.float32, - nrows=nrows) + X = pd.read_csv(local_url,compression='zip', dtype=np.float32) + X = X.set_index('Id') + X.index = X.index.astype('int64') y = X.iloc[:, -1].to_numpy(dtype=np.float32) X.drop(X.columns[-1], axis=1, inplace=True) X = X.to_numpy(dtype=np.float32)