Skip to content

Commit

Permalink
Merge pull request #855 from lisphilar/issue853
Browse files Browse the repository at this point in the history
new: _LoaderBase.collect()
  • Loading branch information
lisphilar authored Jun 26, 2021
2 parents f782b00 + 83189d8 commit 5e92641
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 9 deletions.
17 changes: 17 additions & 0 deletions covsirphy/loading/loaderbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,20 @@ def pyramid(self):
covsirphy.PopulationPyramidData: dataset regarding population pyramid
"""
raise NotImplementedError

def collect(self):
"""
Collect data for scenario analysis and return them as a dictionary.
Returns:
dict(str, object):
- jhu_data (covsirphy.JHUData)
- extras (list[covsirphy.CleaningBase]):
- covsirphy.OXCGRTData
- covsirphy.PCRData
- covsirphy.VaccineData
"""
return {
"jhu_data": self.jhu(),
"extras": [self.oxcgrt(), self.pcr(), self.vaccine()]
}
10 changes: 4 additions & 6 deletions example/scenario_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def main(country="Italy", province=None, file_prefix="ita"):
pronvince (str or None): province name or None (country level)
file_prefix (str): prefix of the filenames
"""
# This script works with version >= 2.20.3-delta
print("This script works with version >= 2.21.0-gamma")
print(cs.get_version())
# Create output directory in example directory
code_path = Path(__file__)
Expand All @@ -35,12 +35,10 @@ def main(country="Italy", province=None, file_prefix="ita"):
filer = cs.Filer(output_dir, prefix=file_prefix, numbering="01")
# Load datasets
loader = cs.DataLoader(input_dir)
jhu_data = loader.jhu()
oxcgrt_data = loader.oxcgrt()
vaccine_data = loader.vaccine()
# Start scenario analysis
data_dict = loader.collect()
# Start scenario analysis and register datasets
snl = cs.Scenario(country=country, province=province)
snl.register(jhu_data, extras=[oxcgrt_data, vaccine_data])
snl.register(**data_dict)
# Show records
record_df = snl.records(**filer.png("records"))
record_df.to_csv(**filer.csv("records", index=False))
Expand Down
13 changes: 13 additions & 0 deletions example/usage_quick.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@
"vaccine_data = loader.vaccine()"
]
},
{
"source": [
"From development version 2.21.0-gamma, we can collect datasets and get them as dictionary with `DataLoader.collect()`. This will be implemented in stable version 2.22.0.\n",
"\n",
"```Python\n",
"data_dict = loader.collect()\n",
"snl = cs.Scenario(country=\"Japan\", province=None)\n",
"snl.register(**data_dict)\n",
"```"
],
"cell_type": "markdown",
"metadata": {}
},
{
"source": [
"### Start scenario analysis\n",
Expand Down
36 changes: 33 additions & 3 deletions tests/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,32 @@

from pathlib import Path
import pytest
from covsirphy import DataLoader, LinelistData, JHUData, CountryData, PopulationData
from covsirphy import OxCGRTData, PCRData, VaccineData, COVID19DataHub
from covsirphy import DataLoader, COVID19DataHub
from covsirphy import LinelistData, JHUData, CountryData, PopulationData
from covsirphy import OxCGRTData, PCRData, VaccineData, PopulationPyramidData
from covsirphy import Scenario
from covsirphy.loading.loaderbase import _LoaderBase


class TestLoaderBase(object):
def test_not_implelemted(self):
base = _LoaderBase()
with pytest.raises(NotImplementedError):
base.jhu()
with pytest.raises(NotImplementedError):
base.population()
with pytest.raises(NotImplementedError):
base.oxcgrt()
with pytest.raises(NotImplementedError):
base.japan()
with pytest.raises(NotImplementedError):
base.linelist()
with pytest.raises(NotImplementedError):
base.pcr()
with pytest.raises(NotImplementedError):
base.vaccine()
with pytest.raises(NotImplementedError):
base.pyramid()


class TestDataLoader(object):
Expand All @@ -13,7 +37,7 @@ def test_start(self):
DataLoader(directory=0)

def test_dataloader(self, jhu_data, population_data, oxcgrt_data,
japan_data, linelist_data, pcr_data, vaccine_data):
japan_data, linelist_data, pcr_data, vaccine_data, pyramid_data):
# List of primary sources of COVID-19 Data Hub
data_loader = DataLoader()
assert data_loader.covid19dh_citation
Expand All @@ -25,12 +49,18 @@ def test_dataloader(self, jhu_data, population_data, oxcgrt_data,
assert isinstance(linelist_data, LinelistData)
assert isinstance(pcr_data, PCRData)
assert isinstance(vaccine_data, VaccineData)
assert isinstance(pyramid_data, PopulationPyramidData)
# Local file
data_loader.jhu(local_file="input/covid19dh.csv")
data_loader.population(local_file="input/covid19dh.csv")
data_loader.oxcgrt(local_file="input/covid19dh.csv")
data_loader.pcr(local_file="input/covid19dh.csv")

def test_collect(self, data_loader):
data_dict = data_loader.collect()
snl = Scenario(country="Japan")
snl.register(**data_dict)


class TestCOVID19DataHub(object):
def test_covid19dh(self):
Expand Down

0 comments on commit 5e92641

Please sign in to comment.