diff --git a/tests/test_tutorial_weather.py b/tests/test_tutorial_weather.py new file mode 100644 index 0000000..9f024d5 --- /dev/null +++ b/tests/test_tutorial_weather.py @@ -0,0 +1,32 @@ +import shutil +import unittest +from pathlib import Path + +from utils.tutorial_weather import load_data + + +class TutorialWeatherSuite(unittest.TestCase): + """ Weather data set test cases.""" + temp_test_dir = 'temp_weather_test' + + def test_data_downloading_has_correct_shape(self): + n_features = 89 + n_train_instances = 767 + n_test_instances = 329 + + X_train, X_test, y_train, y_test = load_data(self.temp_test_dir) + + assert X_train.shape == (n_train_instances, n_features) + assert X_test.shape == (n_test_instances, n_features) + assert y_train.shape == (n_train_instances,) + assert y_test.shape == (n_test_instances,) + + def setUp(self) -> None: + Path(self.temp_test_dir).mkdir() + + def tearDown(self) -> None: + shutil.rmtree(Path(self.temp_test_dir)) + + +if __name__ == '__main__': + unittest.main() diff --git a/utils/tutorial_weather.py b/utils/tutorial_weather.py new file mode 100644 index 0000000..5889161 --- /dev/null +++ b/utils/tutorial_weather.py @@ -0,0 +1,45 @@ +import typing +import urllib +from pathlib import Path + +from urllib.request import urlretrieve +import pandas as pd +from sklearn.model_selection import train_test_split + + +def load_data(path: str = '.'): + """ + Load weather dataset (10.5281/zenodo.4770936.). If it's not on the path specified, it will be downloaded. + Parameters + ---------- + path : str + The local path to the data set folder. + + Returns + ------- + X_train + X_test + y_train + y_test + """ + data_path = download_preprocessed_data(path) + data = pd.read_csv(data_path) + nr_rows = 365 * 3 + X_data = data.loc[:nr_rows].drop(columns=['DATE', 'MONTH']) + + days_ahead = 1 + y_data = data.loc[days_ahead:(nr_rows + days_ahead)]["MAASTRICHT_sunshine"] + X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=0.3, random_state=0) + + return X_train, X_test, y_train, y_test + + +def download_preprocessed_data(directory_to_extract_to: typing.Union[str, Path]): + data_path = Path(directory_to_extract_to) / 'weather' + data_path.mkdir(exist_ok=True) + data_set_light_path = data_path / 'weather_prediction_dataset_light.csv' + if not data_set_light_path.exists(): + _, _ = urllib.request.urlretrieve( + 'https://zenodo.org/record/7053722/files/weather_prediction_dataset_light.csv', + filename=data_set_light_path) + return data_set_light_path