Skip to content

Commit

Permalink
Merge pull request #53 from NLeSC/296-prepare-weather-dataset
Browse files Browse the repository at this point in the history
add weather dataset loading and tests, fixes #NLeSC/mcfly#296
  • Loading branch information
cwmeijer authored Dec 23, 2022
2 parents 9493fee + 9993adb commit 0213dca
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_tutorial_weather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import shutil
import unittest
from pathlib import Path

from utils.tutorial_weather import load_data


class TutorialWeatherSuite(unittest.TestCase):
    """Weather data set test cases.

    Every test gets an empty scratch directory that is created in ``setUp``
    and removed again in ``tearDown``.
    """
    # Scratch directory, relative to the current working directory, that is
    # passed to ``load_data`` as the data set location.
    temp_test_dir = 'temp_weather_test'

    def setUp(self) -> None:
        """Create an empty scratch directory for the test."""
        # A directory left behind by an aborted run would make mkdir() raise
        # FileExistsError, and any file cached inside it would let the test
        # pass without exercising the download. Start from a clean slate.
        shutil.rmtree(self.temp_test_dir, ignore_errors=True)
        Path(self.temp_test_dir).mkdir()

    def tearDown(self) -> None:
        """Remove the scratch directory and everything stored in it."""
        shutil.rmtree(Path(self.temp_test_dir))

    def test_data_downloading_has_correct_shape(self):
        """Check the shapes of the four objects returned by ``load_data``."""
        n_features = 89
        n_train_instances = 767
        n_test_instances = 329

        X_train, X_test, y_train, y_test = load_data(self.temp_test_dir)

        # assertEqual instead of a bare ``assert``: it is not stripped when
        # Python runs with -O and it reports both values on failure.
        self.assertEqual(X_train.shape, (n_train_instances, n_features))
        self.assertEqual(X_test.shape, (n_test_instances, n_features))
        self.assertEqual(y_train.shape, (n_train_instances,))
        self.assertEqual(y_test.shape, (n_test_instances,))


# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
45 changes: 45 additions & 0 deletions utils/tutorial_weather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import typing
import urllib
from pathlib import Path

from urllib.request import urlretrieve
import pandas as pd
from sklearn.model_selection import train_test_split


def load_data(path: str = '.'):
    """
    Return a train/test split of the weather dataset (10.5281/zenodo.4770936).

    The CSV file is obtained through ``download_preprocessed_data``, which
    fetches it when it is not yet present below ``path``. Features are all
    columns except ``DATE`` and ``MONTH``; the target is the
    ``MAASTRICHT_sunshine`` column taken one row further down, so every
    feature row is paired with the value of the row that follows it.

    Parameters
    ----------
    path : str
        The local path to the data set folder.

    Returns
    -------
    X_train
    X_test
    y_train
    y_test
    """
    shift = 1
    last_label = 365 * 3

    csv_file = download_preprocessed_data(path)
    frame = pd.read_csv(csv_file)

    # ``.loc`` slices by label and includes the end label, so both selections
    # below have the same number of rows; the target window simply starts
    # ``shift`` rows later than the feature window.
    features = frame.loc[:last_label].drop(columns=['DATE', 'MONTH'])
    target = frame.loc[shift:(last_label + shift), "MAASTRICHT_sunshine"]

    # Fixed random_state keeps the 70/30 split reproducible between calls.
    X_train, X_test, y_train, y_test = train_test_split(
        features,
        target,
        test_size=0.3,
        random_state=0,
    )
    return X_train, X_test, y_train, y_test


def download_preprocessed_data(
        directory_to_extract_to: typing.Union[str, Path],
        url: str = ('https://zenodo.org/record/7053722/files/'
                    'weather_prediction_dataset_light.csv'),
) -> Path:
    """
    Make sure the light weather data set CSV exists locally and return its path.

    The file is stored as ``weather/weather_prediction_dataset_light.csv``
    below ``directory_to_extract_to``. When that file already exists nothing
    is downloaded.

    Parameters
    ----------
    directory_to_extract_to : str or Path
        Folder in which the ``weather`` sub folder is created. Missing parent
        folders are created as well.
    url : str
        Location the CSV file is retrieved from when it is not cached yet.

    Returns
    -------
    Path
        Path of the local CSV file.

    Raises
    ------
    OSError
        If the file cannot be retrieved (``urllib.error.URLError`` is a
        subclass of ``OSError``).
    """
    data_path = Path(directory_to_extract_to) / 'weather'
    data_path.mkdir(parents=True, exist_ok=True)
    data_set_light_path = data_path / 'weather_prediction_dataset_light.csv'
    if not data_set_light_path.exists():
        # Download next to the final location and rename only on success, so
        # an interrupted transfer can never be mistaken for a cached file.
        partial_path = data_set_light_path.with_name(data_set_light_path.name + '.part')
        try:
            urlretrieve(url, filename=partial_path)
            partial_path.replace(data_set_light_path)
        finally:
            # Only still present when the download or the rename failed.
            if partial_path.exists():
                partial_path.unlink()
    return data_set_light_path

0 comments on commit 0213dca

Please sign in to comment.